# Spaces: Build error
from datasets import load_dataset | |
from transformers import ( | |
DPRQuestionEncoder, | |
DPRQuestionEncoderTokenizer, | |
MT5ForConditionalGeneration, | |
AutoTokenizer, | |
AutoModelForCTC, | |
Wav2Vec2Tokenizer, | |
) | |
from general_utils import ( | |
embed_questions, | |
transcript, | |
remove_chars_to_tts, | |
parse_final_answer, | |
) | |
from typing import List | |
import gradio as gr | |
from article_app import article, description, examples | |
from haystack.nodes import DensePassageRetriever | |
from haystack.document_stores import InMemoryDocumentStore | |
import numpy as np | |
from sentence_transformers import SentenceTransformer, util, CrossEncoder | |
# --- Retrieval / filtering settings -------------------------------------
# Number of nearest passages requested from the FAISS index per question.
topk = 21
# Corpus documents with `minchars` characters or fewer are filtered out.
minchars = 200
# Retrieved passages must have more than this many whitespace-separated words.
min_snippet_length = 20
# Device the MT5 input tensors are moved to before generation.
device = "cpu"
# A question is refused when any of its space-separated words contains one of
# these terms (compared against the lowercased word).
covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]

# --- Speech-to-text ------------------------------------------------------
# Used to transcribe the question when no text question is given.
models = {
    "wav2vec2-iic": dict(
        processor=Wav2Vec2Tokenizer.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
        model=AutoModelForCTC.from_pretrained(
            "IIC/wav2vec2-spanish-multilibrispeech"
        ),
    ),
}

# --- Text-to-speech ------------------------------------------------------
tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")

# --- Generation defaults -------------------------------------------------
# Default keyword arguments for the MT5 `generate` call; the per-request
# values chosen in the UI are applied on top of these.
params_generate = dict(
    min_length=50,
    max_length=250,
    do_sample=False,
    early_stopping=True,
    num_beams=8,
    temperature=1.0,
    top_k=None,
    top_p=None,
    no_repeat_ngram_size=3,
    num_return_sequences=1,
)
# Dense passage retriever. Only its query encoder (`embed_queries`) is called
# below; the in-memory document store is an unused placeholder.
dpr = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=512,
    # Follow the module-wide `device` setting instead of hard-coding CPU.
    use_gpu=device != "cpu",
)

# Long-form QA generator (question + retrieved context -> answer).
mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es")

# Re-ranking models: bi-encoder cosine similarity and a cross-encoder ranker.
similarity_model = SentenceTransformer(
    "distiluse-base-multilingual-cased", device=device
)
crossencoder = CrossEncoder("IIC/roberta-base-bne-ranker", device=device)

# Passage corpus, keeping only documents longer than `minchars` characters.
dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
# Attach the prebuilt FAISS index under the name "embeddings".
# NOTE(review): index row ids must line up with the *filtered* dataset rows;
# presumably the index was built after the same `minchars` filter — verify.
dataset.load_faiss_index(
    "embeddings",
    "dpr_index_bio_newdpr.faiss",
)
def query_index(question: str):
    """Retrieve candidate passages for ``question`` from the FAISS index.

    The question is embedded with the DPR query encoder and the ``topk``
    nearest rows of ``dataset`` (FAISS index ``"embeddings"``) are fetched.

    Args:
        question: Natural-language question.

    Returns:
        Passage texts in the order returned by the index, keeping only those
        with more than ``min_snippet_length`` whitespace-separated words.
    """
    question_embedding = dpr.embed_queries([question])[0]
    # Distance scores are discarded; passages are re-ranked downstream.
    _scores, closest_passages = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=topk
    )
    return [
        passage
        for passage in closest_passages["text"]
        if len(passage.split()) > min_snippet_length
    ]
def sort_on_similarity(question, contexts, include_rank: int = 5):
    """Re-rank ``contexts`` by relevance to ``question`` and keep the best.

    Relevance is the product of the bi-encoder cosine similarity
    (``similarity_model``) and the cross-encoder score (``crossencoder``).

    Args:
        question: The user question.
        contexts: Candidate passages (strings).
        include_rank: Maximum number of passages to return.

    Returns:
        Up to ``include_rank`` passages, most relevant first. An empty
        ``contexts`` returns ``[]`` without invoking either model.
    """
    if not contexts:
        return []

    question_encoded = similarity_model.encode([question])[0]
    ctxs_encoded = similarity_model.encode(contexts)
    # For a single query vector cos_sim returns a (1, n) matrix; its only row
    # is a flat (n,) score vector. (Stacking n separate (1, 1) results instead
    # gives shape (n, 1, 1), which broadcasts against the (n,) cross-encoder
    # scores into (n, 1, n) and makes the index lookup below fail.)
    cos_scores = util.cos_sim(question_encoded, ctxs_encoded)[0].cpu().numpy()

    text_pairs = [[question, ctx] for ctx in contexts]
    # Assumes one score per pair (single-label ranker); ravel makes that (n,).
    cross_scores = np.ravel(crossencoder.predict(text_pairs))

    combined_scores = cos_scores * cross_scores
    # Highest combined score first.
    ranking = np.flip(np.argsort(combined_scores))
    return [contexts[int(idx)] for idx in ranking[:include_rank]]
def create_context(contexts: List):
    """Join passages into one string, prefixing each with a ``<p>`` marker.

    An empty list still yields a single leading ``"<p>"``.
    """
    marker = "<p>"
    joined = marker.join(contexts)
    return f"{marker}{joined}"
def create_model_input(question: str, context: str):
    """Build the MT5 prompt in the form ``question: <q> context: <c>``."""
    template = "question: {} context: {}"
    return template.format(question, context)
def generate_answer(model_input, update_params=None):
    """Generate answer text with the MT5 long-form QA model.

    Args:
        model_input: Prompt text to encode (truncated to 1024 tokens).
        update_params: Optional generation overrides, merged over
            ``params_generate`` for this call only.

    Returns:
        List of ``{"generated_text": str}`` dicts, one per decoded sequence.
    """
    encoded = mt5_tokenizer(
        model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
    )
    # Merge into a fresh dict. Updating the module-level `params_generate` in
    # place leaked one request's settings into later requests and raced
    # between concurrent ones.
    generation_kwargs = {**params_generate, **(update_params or {})}
    answers_encoded = mt5_lfqa.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        **generation_kwargs,
    )
    answers = mt5_tokenizer.batch_decode(
        answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return [{"generated_text": answer} for answer in answers]
def search_and_answer(
    question,
    audio_file,
    audio_array,
    min_length_answer,
    num_beams,
    no_repeat_ngram_size,
    temperature,
    max_answer_length,
    do_tts,
):
    """Answer a question end to end; this is the Gradio callback.

    Pipeline: optional speech-to-text -> COVID-topic refusal -> dense
    retrieval -> re-ranking -> MT5 generation -> optional text-to-speech.

    Args:
        question: Typed question. When empty, the question is transcribed
            from ``audio_file`` / ``audio_array`` instead.
        audio_file: Path of an uploaded audio file (may be None).
        audio_array: Microphone recording (may be None).
        min_length_answer: Passed to generation as ``min_length``.
        num_beams: Beam-search width.
        no_repeat_ngram_size: Passed to generation as ``no_repeat_ngram_size``.
        temperature: Passed to generation as ``temperature``.
        max_answer_length: Passed to generation as ``max_length``.
        do_tts: Whether to synthesize the answer as audio.

    Returns:
        Tuple ``(answer, documents, audio)``. ``audio`` is the TTS output when
        ``do_tts`` is set and the placeholder ``"audio_troll.flac"`` otherwise.
    """
    # UI slider values may arrive as floats; the generation length and n-gram
    # settings must be integers, so cast them all (num_beams already was).
    update_params = {
        "min_length": int(min_length_answer),
        "max_length": int(max_answer_length),
        "num_beams": int(num_beams),
        "temperature": temperature,
        "no_repeat_ngram_size": int(no_repeat_ngram_size),
    }

    if not question:
        speech_to_text = models["wav2vec2-iic"]
        question = transcript(
            audio_file,
            audio_array,
            processor=speech_to_text["processor"],
            model=speech_to_text["model"],
        )
        print(f"Transcripted question: *** {question} ****")

    # Refuse COVID questions: any space-separated word containing a listed term.
    if any(
        term in word.lower() for word in question.split(" ") for term in covidterms
    ):
        return (
            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
            "ni contexto ni contexta.",
            "audio_troll.flac",
        )

    contexts = sort_on_similarity(question, query_index(question))
    model_input = create_model_input(question, create_context(contexts))
    final_answer = generate_answer(model_input, update_params)[0]["generated_text"]

    # Placeholder clip unless TTS is requested, so the value is always bound.
    audio_answer = "audio_troll.flac"
    if do_tts:
        audio_answer = tts_es(remove_chars_to_tts(final_answer))

    final_answer, documents = parse_final_answer(final_answer, contexts)
    return final_answer, documents, audio_answer
if __name__ == "__main__":
    # The inputs are passed positionally to `search_and_answer`, so their order
    # must match that function's parameter list.
    gr.Interface(
        search_and_answer,
        inputs=[
            gr.inputs.Textbox(
                lines=2,
                label="Pregunta",
                placeholder="Escribe aquí tu pregunta",
                optional=True,
            ),
            # This audio is transcribed into the *question*, so the label asks
            # for "tu pregunta" (previously it said "tu respuesta", i.e. answer).
            gr.inputs.Audio(
                source="upload",
                type="filepath",
                label="Sube un audio con tu pregunta aquí si quieres.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="microphone",
                type="numpy",
                label="Graba aquí un audio con tu pregunta.",
                optional=True,
            ),
            gr.inputs.Slider(
                minimum=10,
                maximum=200,
                default=50,
                label="Minimum size for the answer",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=4, maximum=12, default=8, label="number of beams", step=1
            ),
            gr.inputs.Slider(
                minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
            ),
            gr.inputs.Slider(
                minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
            ),
            gr.inputs.Slider(
                minimum=220,
                maximum=360,
                default=250,
                label="maximum answer length",
                step=1,
            ),
            gr.inputs.Checkbox(
                default=False, label="Text to Speech", optional=True),
        ],
        outputs=[
            gr.outputs.HTML(
                label="Respuesta generada."
            ),
            gr.outputs.HTML(
                label="Documentos utilizados."
            ),
            gr.outputs.Audio(label="Respuesta en audio."),
        ],
        description=description,
        examples=examples,
        theme="grass",
        article=article,
        thumbnail="IIC_logoP.png",
        # NOTE(review): the package spec in this URL looks mangled by page
        # e-mail obfuscation ("[email protected]" instead of
        # "bootstrap@<version>"); restore the intended version. Also confirm
        # that this gradio version accepts a URL for `css` rather than inline
        # CSS or a local file path.
        css="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css",
    ).launch()