Spaces:
Sleeping
minhdang14902
committed on
Update app.py
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 import torch
-
+import pytorch_lightning as pl
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, T5Tokenizer, T5ForConditionalGeneration
 import nltk
 from transformers.models.roberta.modeling_roberta import *
 from transformers import RobertaForQuestionAnswering
@@ -354,6 +355,82 @@ def extract_answer(inputs, outputs, tokenizer):
     })
     return plain_result
 
+# ---- T5 fallback: generate an answer when extractive QA finds nothing ----
+from torch.optim import AdamW  # missing in the original commit; needed by configure_optimizers below
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+INPUT_MAX_LEN = 128   # adjusted input length
+OUTPUT_MAX_LEN = 256  # adjusted output length
+
+MODEL_NAME = "VietAI/vit5-base"
+
+tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=INPUT_MAX_LEN)
+
+class T5Model(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
+
+    def forward(self, input_ids, attention_mask, labels=None):
+        output = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=labels
+        )
+        return output.loss, output.logits
+
+    def training_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"].to(DEVICE)
+        attention_mask = batch["attention_mask"].to(DEVICE)
+        labels = batch["target"].to(DEVICE)
+        loss, logits = self(input_ids, attention_mask, labels)
+        self.log("train_loss", loss, prog_bar=True, logger=True)
+        return {'loss': loss}
+
+    def validation_step(self, batch, batch_idx):
+        input_ids = batch["input_ids"].to(DEVICE)
+        attention_mask = batch["attention_mask"].to(DEVICE)
+        labels = batch["target"].to(DEVICE)
+        loss, logits = self(input_ids, attention_mask, labels)
+        self.log("val_loss", loss, prog_bar=True, logger=True)
+        return {'val_loss': loss}
+
+    def configure_optimizers(self):
+        return AdamW(self.parameters(), lr=0.0001)
+
+train_model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
+train_model.freeze()  # inference only
+
+
+def generate_question(question):
+    inputs_encoding = tokenizer(
+        question,
+        add_special_tokens=True,
+        max_length=INPUT_MAX_LEN,
+        padding='max_length',
+        truncation='only_first',
+        return_attention_mask=True,
+        return_tensors="pt"
+    ).to(DEVICE)
+
+    generate_ids = train_model.model.generate(
+        input_ids=inputs_encoding["input_ids"],
+        attention_mask=inputs_encoding["attention_mask"],
+        max_length=OUTPUT_MAX_LEN,  # was INPUT_MAX_LEN in the commit; generation should use the output limit
+        num_beams=4,
+        num_return_sequences=1,
+        no_repeat_ngram_size=2,
+        early_stopping=True,
+    )
+
+    preds = [
+        tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        for gen_id in generate_ids
+    ]
+
+    response = " ".join(preds[0].split())
+    return response
+
 # st.title("Chatbot Roberta")
 # st.write("Hi! Tôi là trợ lý của bạn trong việc trả lời các câu hỏi.")
 # text = st.text_input("User: ", key="input")
@@ -398,7 +475,7 @@ def get_response(text):
     answer, context = chatRoberta(text)
     result = answer[0]['answer']
     if result == "":
-        return
+        return generate_question(text)  # fall back to T5 generation when Roberta returns no answer
     return result
 
 st.title("General Law Chatbot")
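A note on the new block above: Streamlit re-executes app.py on every user interaction, so the module-level T5Model.load_from_checkpoint call repeats on each rerun. A minimal sketch of caching that load with st.cache_resource follows; the wrapper name load_t5 is mine, and none of this is part of the commit:

# Sketch only (not in this commit): cache the checkpoint load across Streamlit reruns.
# Assumes the same ./data-law/law-model-v1.ckpt path and the T5Model class from app.py.
import streamlit as st

@st.cache_resource
def load_t5():
    model = T5Model.load_from_checkpoint('./data-law/law-model-v1.ckpt')
    model.freeze()  # inference only, no gradients needed
    return model

train_model = load_t5()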
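Taken together, get_response is now two-stage: the Roberta extractor is tried first, and an empty extractive answer routes the question through the T5 generator instead of silently returning None as before. A hypothetical smoke test of that fallback path, assuming app.py's chatRoberta and generate_question have loaded; the sample questions are made up:

# Hypothetical smoke test, not part of app.py.
if __name__ == "__main__":
    questions = [
        "Người lao động được nghỉ phép bao nhiêu ngày mỗi năm?",  # "How many days of leave per year?"
        "một câu hỏi không có trong dữ liệu",                     # "a question not covered by the data"
    ]
    for q in questions:
        print(q, "->", get_response(q))  # falls back to T5 when the extractive answer is empty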