import torch
import gradio as gr
# NOTE: googletrans==4.0.0-rc1 is assumed here; the 3.x stable release is
# known to break against newer httpx versions
from googletrans import Translator
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
# End-to-end question generation: one T5 pass emits all questions for a text
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Move the model to the same device the input tensors are sent to below
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.model_type = "t5"
        # Beam-search settings used for every generate() call
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }
    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs
        )
        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # The model emits "q1 <sep> q2 <sep> ... <sep>"; the trailing <sep>
        # leaves an empty last element, which [:-1] discards
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions
    def _prepare_inputs_for_e2e_qg(self, context):
        # The valhalla/t5-base-e2e-qg checkpoint expects this task prefix
        source_text = f"generate questions: {context}"
        inputs = self._tokenize([source_text], padding=False)
        return inputs
    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        # Calling the tokenizer directly replaces the deprecated
        # batch_encode_plus / pad_to_max_length APIs
        inputs = self.tokenizer(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs
# English question generation checkpoint
qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
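# A minimal usage sketch (hypothetical input text and output):
#   E2EQGPipeline(qg_model, qg_tokenizer).generate_questions(
#       "The Eiffel Tower was completed in 1889.")
#   might return something like ["When was the Eiffel Tower completed?"]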
# Build the pipeline once instead of on every request
qg_pipeline = E2EQGPipeline(qg_model, qg_tokenizer)

def generate_questions(text):
    questions = qg_pipeline.generate_questions(text)
    # The questions come out in English; translate them to Spanish for the UI
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    # Join into one string so the Gradio Textbox shows one question per line
    return "\n".join(translated_questions)
# Summarization model; BART-large-cnn takes the raw text with no task prefix
# (the "summarize: " prefix belongs to T5-style models)
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

def generate_summary(text):
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
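# Note that anything beyond 1024 tokens is silently truncated before summarizing;
# long documents would need chunking, which this sketch does not attempt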
# QA: Spanish question answering (T5-small fine-tuned on the SQAC dataset)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = 'mrm8488/spanish-t5-small-sqac-for-qa'
qa_tokenizer = AutoTokenizer.from_pretrained(ckpt)
qa_model = T5ForConditionalGeneration.from_pretrained(ckpt).to(device)
def generate_question_response(question, context):
    # The checkpoint expects the 'question: ... context: ...' input format
    input_text = f'question: {question} context: {context}'
    features = qa_tokenizer([input_text], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    # Greedy decoding; the temperature argument was dropped because generate()
    # ignores it unless do_sample=True is set
    output = qa_model.generate(
        input_ids=features['input_ids'].to(device),
        attention_mask=features['attention_mask'].to(device)
    )
    return qa_tokenizer.decode(output[0], skip_special_tokens=True)
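# Hypothetical example:
#   generate_question_response("¿Quién escribió el Quijote?",
#                              "Miguel de Cervantes escribió el Quijote en 1605.")
#   should answer with something like "Miguel de Cervantes"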
class SummarizerAndQA:
    """Caches the last inputs so the models only rerun when something changed."""

    def __init__(self):
        self.input_text = ''
        self.question = ''
        self.summary = ''
        self.study_generated_questions = ''
        self.question_response = ''

    def process(self, text, question):
        # Rerun the summarizer and question generator only when the text changes
        if text != self.input_text:
            self.input_text = text
            self.summary = generate_summary(text)
            self.study_generated_questions = generate_questions(text)
        # Rerun QA only for a new question, and only when there is a text to answer from
        if question != self.question and text != '':
            self.question = question
            self.question_response = generate_question_response(question, text)
        return self.summary, self.study_generated_questions, self.question_response
summarizer_and_qa = SummarizerAndQA()
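# Note: a single shared instance means the cache (and its results) is shared
# across every user session of the app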
textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)
question_input = gr.Textbox(label="Pregunta sobre el texto acá:", placeholder="Pregunta...", lines=15)
summary_output = gr.Textbox(label="Resumen", lines=15)
questions_output = gr.Textbox(label="Preguntas de guía generadas", lines=5)
questions_response = gr.Textbox(label="Respuestas", lines=5)
demo = gr.Interface(
    fn=summarizer_and_qa.process,
    inputs=[textbox_input, question_input],
    outputs=[summary_output, questions_output, questions_response],
    allow_flagging="never"
)
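# launch() starts a local server; share=True would also create a temporary public URL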
demo.launch()