BioMedIA / app.py
avacaondata's picture
añadidas descripciones de campos en español, hagámoslo más accesible!
9f01174
raw
history blame
7.79 kB
from datasets import load_dataset
from transformers import (
DPRQuestionEncoder,
DPRQuestionEncoderTokenizer,
MT5ForConditionalGeneration,
AutoTokenizer,
AutoModelForCTC,
Wav2Vec2Tokenizer,
)
from general_utils import (
embed_questions,
transcript,
remove_chars_to_tts,
parse_final_answer,
)
from typing import List
import gradio as gr
from article_app import article, description, examples
from haystack.nodes import DensePassageRetriever
from haystack.document_stores import InMemoryDocumentStore
import numpy as np
from sentence_transformers import SentenceTransformer, util, CrossEncoder
# Number of nearest passages requested from the FAISS index in query_index.
topk = 21
# Corpus rows whose "text" has <= minchars characters are filtered out when
# the dataset is loaded.
minchars = 200
# Retrieved passages with <= this many whitespace-separated words are discarded
# by query_index.
min_snippet_length = 20
# Device the tokenized generation inputs are moved to in generate_answer.
# NOTE(review): the mT5 model itself is never moved with .to(device) in this
# file, so this only stays consistent while it is "cpu".
device = "cpu"
# If any word of the (lowercased) question contains one of these substrings,
# search_and_answer skips retrieval and returns a canned reply.
covidterms = ["covid19", "covid", "coronavirus", "covid-19", "sars-cov-2"]
# Speech-to-text checkpoint used by search_and_answer to transcribe the
# question when no text was typed. Tokenizer and CTC model share one repo id.
_S2T_CHECKPOINT = "IIC/wav2vec2-spanish-multilibrispeech"

# Registry of speech-to-text components, keyed by a short model alias.
models = {
    "wav2vec2-iic": {
        "processor": Wav2Vec2Tokenizer.from_pretrained(_S2T_CHECKPOINT),
        "model": AutoModelForCTC.from_pretrained(_S2T_CHECKPOINT),
    },
}
# Spanish text-to-speech model, loaded through gr.Interface.load from the
# "huggingface/" model path. search_and_answer calls it on the cleaned answer
# text when TTS is requested and returns its output as the audio answer.
tts_es = gr.Interface.load("huggingface/facebook/tts_transformer-es-css10")
# Default keyword arguments for mt5_lfqa.generate() in generate_answer.
# Per-request values from the UI are combined with these in generate_answer.
# Decoding is deterministic beam search (do_sample=False), so temperature /
# top_k / top_p are effectively inert unless sampling is turned on.
params_generate = dict(
    min_length=50,
    max_length=250,
    do_sample=False,
    early_stopping=True,
    num_beams=8,
    temperature=1.0,
    top_k=None,
    top_p=None,
    no_repeat_ngram_size=3,
    num_return_sequences=1,
)
# Dense Passage Retriever. In this file only its question encoder is used
# (dpr.embed_queries in query_index). Passage vectors come from the precomputed
# FAISS index attached to `dataset` further down, and no documents are written
# to this empty InMemoryDocumentStore in the visible code.
# NOTE(review): the passage encoder is still loaded here, presumably because
# DensePassageRetriever requires both encoders; confirm before dropping it.
dpr = DensePassageRetriever(
    document_store=InMemoryDocumentStore(),
    query_embedding_model="IIC/dpr-spanish-question_encoder-allqa-base",
    passage_embedding_model="IIC/dpr-spanish-passage_encoder-allqa-base",
    max_seq_len_query=64,
    max_seq_len_passage=256,
    batch_size=512,
    use_gpu=False,
)
# mT5 long-form question-answering generator and its tokenizer, used by
# generate_answer.
mt5_tokenizer = AutoTokenizer.from_pretrained("IIC/mt5-base-lfqa-es")
mt5_lfqa = MT5ForConditionalGeneration.from_pretrained("IIC/mt5-base-lfqa-es")
# Sentence-embedding model and cross-encoder that sort_on_similarity combines
# to rerank retrieved passages. Both are pinned to CPU.
similarity_model = SentenceTransformer(
    "distiluse-base-multilingual-cased", device="cpu"
)
crossencoder = CrossEncoder("IIC/roberta-base-bne-ranker", device="cpu")
# Biomedical passage corpus (train split). Rows whose "text" has <= minchars
# characters are dropped. A prebuilt FAISS index is then attached under the
# name "embeddings", which query_index searches.
# NOTE(review): presumably the .faiss file was built over this exact filtered
# dataset (row ids must line up with the filtered rows); confirm whenever
# minchars or the corpus changes.
dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
dataset = dataset.filter(lambda example: len(example["text"]) > minchars)
dataset.load_faiss_index(
    "embeddings",
    "dpr_index_bio_newdpr.faiss",
)
def query_index(question: str):
    """Retrieve candidate passages for *question* from the FAISS index.

    The question is embedded with the DPR question encoder, and the ``topk``
    nearest rows of ``dataset`` are fetched from its ``"embeddings"`` index.

    Args:
        question: Natural-language question.

    Returns:
        Texts of the retrieved passages, in the order returned by
        ``get_nearest_examples``. Only passages with more than
        ``min_snippet_length`` whitespace-separated words are kept, so the
        list may be empty.
    """
    question_embedding = dpr.embed_queries([question])[0]
    _scores, closest_passages = dataset.get_nearest_examples(
        "embeddings", question_embedding, k=topk
    )
    return [
        context
        for context in closest_passages["text"]
        if len(context.split()) > min_snippet_length
    ]
def sort_on_similarity(question, contexts, include_rank: int = 5):
    """Rerank *contexts* for *question* and keep the best ``include_rank``.

    Each context gets a combined score: the cosine similarity between the
    ``similarity_model`` embeddings of question and context, multiplied by
    the ``crossencoder`` score for the ``[question, context]`` pair.

    Args:
        question: The user question.
        contexts: Candidate passages, e.g. the output of ``query_index``.
        include_rank: Maximum number of contexts to return.

    Returns:
        Up to ``include_rank`` contexts ordered from highest to lowest
        combined score. An empty ``contexts`` yields an empty list.
    """
    if not contexts:
        # query_index may filter everything out; don't encode/predict on an
        # empty batch.
        return []
    question_encoded = similarity_model.encode([question])[0]
    ctxs_encoded = similarity_model.encode(contexts)
    # util.cos_sim returns a 2-D (1, 1) tensor for each pair. Collapse it to a
    # float so the scores form a flat (n,) array. Stacking the raw tensors
    # gives shape (n, 1, 1), which broadcasts against the (n,) cross-encoder
    # scores into (n, 1, n) and breaks the ranking below.
    embedding_scores = np.array(
        [
            float(util.cos_sim(question_encoded, ctx_encoded))
            for ctx_encoded in ctxs_encoded
        ]
    )
    text_pairs = [[question, ctx] for ctx in contexts]
    cross_scores = np.asarray(crossencoder.predict(text_pairs)).reshape(-1)
    combined_scores = embedding_scores * cross_scores
    # argsort is ascending; reverse it for a best-first ranking.
    ranking = np.argsort(combined_scores)[::-1]
    return [contexts[idx] for idx in ranking[:include_rank]]
def create_context(contexts: List):
    """Join passages into one string, each one preceded by a ``<p>`` marker.

    Note that an empty list still produces a lone ``"<p>"``.
    """
    marker = "<p>"
    return f"{marker}{marker.join(contexts)}"
def create_model_input(question: str, context: str):
    """Build the generator prompt in ``question: <q> context: <ctx>`` form."""
    prompt = f"question: {question} context: {context}"
    return prompt
def generate_answer(model_input, update_params=None):
    """Generate answer text with the mT5 LFQA model.

    Args:
        model_input: Prompt text, e.g. as built by ``create_model_input``.
            It is truncated to 1024 tokens.
        update_params: Optional generation kwargs that override the
            module-level ``params_generate`` defaults for this call only.

    Returns:
        A list of ``{"generated_text": str}`` dicts, one per decoded sequence.
    """
    encoded = mt5_tokenizer(
        model_input, truncation=True, padding=True, return_tensors="pt", max_length=1024
    )
    # Merge into a fresh dict rather than calling params_generate.update().
    # Mutating the module-level defaults would leak one request's settings
    # into every later call.
    generation_kwargs = {**params_generate, **(update_params or {})}
    answers_encoded = mt5_lfqa.generate(
        input_ids=encoded["input_ids"].to(device),
        attention_mask=encoded["attention_mask"].to(device),
        **generation_kwargs,
    )
    answers = mt5_tokenizer.batch_decode(
        answers_encoded, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return [{"generated_text": answer} for answer in answers]
def search_and_answer(
    question,
    audio_file,
    audio_array,
    min_length_answer,
    num_beams,
    no_repeat_ngram_size,
    temperature,
    max_answer_length,
    do_tts,
):
    """Answer a typed or spoken question end to end.

    Pipeline: optional speech-to-text -> COVID guard -> dense retrieval ->
    reranking -> mT5 generation -> optional text-to-speech.

    Args:
        question: Typed question. When empty, the question is transcribed
            from ``audio_file`` / ``audio_array`` with the wav2vec2 model.
        audio_file: Uploaded audio, passed through to ``transcript``.
        audio_array: Microphone recording, passed through to ``transcript``.
        min_length_answer, num_beams, no_repeat_ngram_size, temperature,
        max_answer_length: Generation settings coming from the UI sliders.
        do_tts: Whether to synthesize the answer with ``tts_es``.

    Returns:
        ``(answer, documents, audio)``. ``answer`` and ``documents`` come from
        ``parse_final_answer``. ``audio`` is the TTS output, or the placeholder
        ``"audio_troll.flac"`` when TTS is off. Questions mentioning COVID get
        a fixed canned triple instead.
    """

    def _as_int(value):
        # transformers' min-length / n-gram processors reject non-int values,
        # and slider values may arrive as floats. None is passed through.
        return None if value is None else int(value)

    update_params = {
        "min_length": _as_int(min_length_answer),
        "max_length": _as_int(max_answer_length),
        "num_beams": _as_int(num_beams),
        "temperature": temperature,
        "no_repeat_ngram_size": _as_int(no_repeat_ngram_size),
    }
    if not question:
        s2t = models["wav2vec2-iic"]
        question = transcript(
            audio_file, audio_array, processor=s2t["processor"], model=s2t["model"]
        )
        print(f"Transcripted question: *** {question} ****")
    if any(term in word.lower() for word in question.split(" ") for term in covidterms):
        return (
            "Del COVID no queremos saber ya más nada, lo sentimos, pregúntame sobre otra cosa :P ",
            "ni contexto ni contexta.",
            "audio_troll.flac",
        )
    contexts = sort_on_similarity(question, query_index(question))
    model_input = create_model_input(question, create_context(contexts))
    final_answer = generate_answer(model_input, update_params)[0]["generated_text"]
    # TTS runs on the raw generated text, before parse_final_answer rewrites it.
    # NOTE(review): with TTS off this returns the same placeholder file as the
    # COVID reply; confirm whether "no audio" (None) was intended instead.
    audio_answer = (
        tts_es(remove_chars_to_tts(final_answer)) if do_tts else "audio_troll.flac"
    )
    final_answer, documents = parse_final_answer(final_answer, contexts)
    return final_answer, documents, audio_answer
if __name__ == "__main__":
    # Launch the Gradio UI. The order of `inputs` must match the positional
    # parameters of search_and_answer.
    gr.Interface(
        search_and_answer,
        inputs=[
            gr.inputs.Textbox(
                lines=2,
                label="Pregunta",
                placeholder="Escribe aquí tu pregunta",
                optional=True,
            ),
            # Feeds `audio_file`, i.e. the spoken *question*, so the label must
            # ask for "tu pregunta" (not "tu respuesta").
            gr.inputs.Audio(
                source="upload",
                type="filepath",
                label="Sube un audio con tu pregunta aquí si quieres.",
                optional=True,
            ),
            gr.inputs.Audio(
                source="microphone",
                type="numpy",
                label="Graba aquí un audio con tu pregunta.",
                optional=True,
            ),
            gr.inputs.Slider(
                minimum=10,
                maximum=200,
                default=50,
                label="Minimum size for the answer",
                step=1,
            ),
            gr.inputs.Slider(
                minimum=4, maximum=12, default=8, label="number of beams", step=1
            ),
            gr.inputs.Slider(
                minimum=2, maximum=5, default=3, label="no repeat n-gram size", step=1
            ),
            gr.inputs.Slider(
                minimum=0.8, maximum=2.0, default=1.0, label="temperature", step=0.1
            ),
            gr.inputs.Slider(
                minimum=220,
                maximum=360,
                default=250,
                label="maximum answer length",
                step=1,
            ),
            gr.inputs.Checkbox(default=False, label="Text to Speech", optional=True),
        ],
        outputs=[
            gr.outputs.HTML(label="Respuesta generada."),
            gr.outputs.HTML(label="Documentos utilizados."),
            gr.outputs.Audio(label="Respuesta en audio."),
        ],
        description=description,
        examples=examples,
        theme="grass",
        article=article,
        thumbnail="IIC_logoP.png",
        # NOTE(review): "[email protected]" looks like an e-mail-obfuscation
        # artifact that replaced a "bootstrap@<version>" path segment; restore
        # the intended URL. Also confirm whether this Gradio version treats
        # `css` as a URL, a file path, or raw CSS text.
        css="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css",
    ).launch()