Lautaro Cardarelli committed
Commit e9e44ae
1 Parent(s): f555fb0

add spanish qa answer

Files changed (1): app.py (+15 -12)
app.py CHANGED

@@ -96,22 +96,25 @@ def generate_summary(text):
     summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
 
+
 # QA
-# Load the question-answering model
-qa_model_name = "MaRiOrOsSi/t5-base-finetuned-question-answering"
-qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
-qa_model = AutoModelForSeq2SeqLM.from_pretrained(qa_model_name)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ckpt = 'mrm8488/spanish-t5-small-sqac-for-qa'
+qa_tokenizer = AutoTokenizer.from_pretrained(ckpt)
+qa_model = T5ForConditionalGeneration.from_pretrained(ckpt).to(device)
 
 
 def generate_question_response(question, context):
-    # Build the input for the question-answering model
-    input_text = f"question: {question} context: {context}"
-    encoded_input = qa_tokenizer(input_text, return_tensors='pt', max_length=1024, truncation=True)
-    output = qa_model.generate(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
-    response_en = qa_tokenizer.decode(output[0], skip_special_tokens=True)
-    translator = Translator()
-    translated_response = translator.translate(response_en, dest='es').text
-    return f'Respuesta: {translated_response}'
+    input_text = 'question: %s context: %s' % (question, context)
+    features = qa_tokenizer([input_text], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
+    output = qa_model.generate(
+        input_ids=features['input_ids'].to(device),
+        attention_mask=features['attention_mask'].to(device),
+        max_length=200,  # allow longer answers
+        temperature=1.0  # adjust the temperature
+    )
+
+    return qa_tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 class SummarizerAndQA:
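
The swapped-in checkpoint, mrm8488/spanish-t5-small-sqac-for-qa, is a T5 model fine-tuned for question answering on the Spanish SQAC dataset, so it answers in Spanish directly and the googletrans round-trip from the removed code is no longer needed. A minimal usage sketch, assuming the import block of app.py (not shown in this hunk) already pulls in torch, AutoTokenizer, and T5ForConditionalGeneration from transformers; the question/context pair below is invented:

# Hypothetical call to the updated function; context and question are made-up examples.
context = 'El Amazonas es el río más caudaloso del mundo y atraviesa Perú, Colombia y Brasil.'
question = '¿Qué países atraviesa el Amazonas?'

answer = generate_question_response(question, context)
print(answer)  # expected: a Spanish answer span from the context, e.g. 'Perú, Colombia y Brasil'

One note on the generation arguments: in transformers, temperature only takes effect when do_sample=True is passed, so with the default greedy decoding used here it is inert; the max_length=200 cap is the change that actually affects the output.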