minhdang14902 committed
Commit 931f63c (verified)
1 Parent(s): 5f236ec

Update app.py

Files changed (1):
  1. app.py +78 -2
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 import torch
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+import pytorch_lightning as pl
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, T5Tokenizer, T5ForConditionalGeneration
 import nltk
 from transformers.models.roberta.modeling_roberta import *
 from transformers import RobertaForQuestionAnswering
@@ -354,6 +355,81 @@ def extract_answer(inputs, outputs, tokenizer):
     })
     return plain_result
 
+# ---- T5 fallback: generate an answer when extractive QA finds none ----
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+INPUT_MAX_LEN = 128  # Adjusted input length
+OUTPUT_MAX_LEN = 256  # Adjusted output length
+
+MODEL_NAME = "VietAI/vit5-base"
+
+tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
+
+class T5Model(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
+
+    def forward(self, input_ids, attention_mask, labels=None):
+        output = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels
+        )
+        return output.loss, output.logits
+
+    def training_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"].to(DEVICE)
+        attention_mask = batch["attention_mask"].to(DEVICE)
+        labels = batch["target"].to(DEVICE)
+        loss, logits = self(input_ids, attention_mask, labels)
+        self.log("train_loss", loss, prog_bar=True, logger=True)
+        return {'loss': loss}
+
+    def validation_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"].to(DEVICE)
+        attention_mask = batch["attention_mask"].to(DEVICE)
+        labels = batch["target"].to(DEVICE)
+        loss, logits = self(input_ids, attention_mask, labels)
+        self.log("val_loss", loss, prog_bar=True, logger=True)
+        return {'val_loss': loss}
+
+    def configure_optimizers(self):
+        return AdamW(self.parameters(), lr=0.0001)
+
+train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
+train_model.freeze()
+
+
+def generate_question(question):
+    inputs_encoding = tokenizer(
+        question,
+        add_special_tokens=True,
+        max_length=INPUT_MAX_LEN,
+        padding='max_length',
+        truncation='only_first',
+        return_attention_mask=True,
+        return_tensors="pt"
+    ).to(DEVICE)
+
+    generate_ids = train_model.model.generate(
+        input_ids=inputs_encoding["input_ids"],
+        attention_mask=inputs_encoding["attention_mask"],
+        max_length=INPUT_MAX_LEN,
+        num_beams=4,
+        num_return_sequences=1,
+        no_repeat_ngram_size=2,
+        early_stopping=True,
+    )
+
+    preds = [
+        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        for gen_id in generate_ids
+    ]
+
+    response = " ".join(preds[0].split())
+    return response
+
 # st.title("Chatbot Roberta")
 # st.write("Hi! Tôi là trợ lý của bạn trong việc trả lời các câu hỏi.")
 # text = st.text_input("User: ", key="input")
@@ -398,7 +474,7 @@ def get_response(text):
     answer, context = chatRoberta(text)
     result = answer[0]['answer']
     if result == "":
-        return "Xin lỗi, tôi không thể tìm được đáp án phù hợp cho câu hỏi này ... Hãy thử trả lời bằng câu hỏi khác!"
+        return generate_question(text)
     return result
 
 st.title("General Law Chatbot")