Spaces:
Sleeping
Sleeping
from typing import List, Tuple, TypedDict | |
from re import sub | |
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging | |
from transformers import AutoModelForQuestionAnswering, DPRReaderTokenizer, DPRReader | |
from transformers import QuestionAnsweringPipeline | |
from transformers import AutoTokenizer, PegasusXForConditionalGeneration, PegasusTokenizerFast | |
import torch | |
# Maximum answer length handed to the question-answering pipeline in ask_reader.
max_answer_len: int = 8

# Silence transformers log output below the error level.
logging.set_verbosity_error()
def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]):
    """Summarize a batch of texts with a Pegasus-X model.

    Args:
        tokenizer: Tokenizer that matches ``model``.
        model: Summarization model; the tokenized inputs are moved to the
            device the model lives on.
        input_texts: Texts to summarize. They are padded to a common length
            and truncated by the tokenizer.

    Returns:
        A list with one decoded summary per input text, in input order.
        An empty input yields an empty list.
    """
    if not input_texts:
        return []

    # Follow the model's device instead of assuming a fixed GPU index.
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True).to(model.device)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        # The batch is padded, so hand the tokenizer's attention mask to
        # generate() explicitly rather than relying on it being inferred
        # from the pad token id.
        summary_ids = model.generate(
            inputs["input_ids"], attention_mask=inputs["attention_mask"])
    return tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)
def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    """Load the summarization tokenizer and model for ``model_id``.

    Args:
        model_id: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``. The model is moved to device index 1 and
        wrapped with ``torch.compile`` before being returned.
    """
    summarizer_tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    summarizer = PegasusXForConditionalGeneration.from_pretrained(model_id)
    compiled_summarizer = torch.compile(summarizer.to(1))
    return summarizer_tokenizer, compiled_summarizer
# OpenAI reader | |
# Shape of one answer record as consumed by ask_reader: the reader's score,
# the start/end positions of the answer, and the answer text itself.
AnswerInfo = TypedDict(
    'AnswerInfo',
    {
        'score': float,
        'start': int,
        'end': int,
        'answer': str,
    },
)
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Run extractive question answering over question/context pairs.

    Args:
        tokenizer: Tokenizer that matches ``model``.
        model: Question-answering model used by the pipeline.
        questions: Questions to answer.
        ctxs: Contexts to read the answers from, passed alongside
            ``questions`` to the pipeline.

    Returns:
        A list of answer records. Each record's ``answer`` has the
        characters ``. ( ) " ' ,`` stripped. Empty input yields an empty
        list.
    """
    if not questions:
        return []

    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(
            model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
        answer_infos = pipeline(question=questions, context=ctxs)

    # The pipeline returns a bare dict (not a list) when it is given a single
    # question; normalise so the loop below never iterates over dict keys.
    if isinstance(answer_infos, dict):
        answer_infos = [answer_infos]

    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos
def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
    """Load the question-answering tokenizer and model for ``model_id``.

    The Auto classes pick the architecture recorded in the checkpoint's
    config. The default checkpoint is a Longformer QA model, so loading it
    through the DPR reader classes would not match its weights.

    Args:
        model_id: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``, with the model moved to device index 0.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model
def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
    """Encode a batch of questions with a DPR question encoder.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        tokenizer: Tokenizer that matches ``model``.
        model: DPR question encoder; inputs are moved to its device.
        questions: Question strings to encode. They are padded to a common
            length and truncated by the tokenizer.

    Returns:
        The encoder's ``pooler_output`` for the batch, computed without
        gradient tracking.
    """
    # Follow the model's device instead of assuming GPU index 0.
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True).to(model.device)
    # Inference only: skip building an autograd graph for the embeddings.
    with torch.no_grad(), torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False):
        embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    return embeddings
def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load a DPR question encoder and its tokenizer.

    https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id (str, optional): Checkpoint to load. Defaults to the NQ
            encoder; "facebook/dpr-question_encoder-multiset-base" is an
            alternative.

    Returns:
        ``(tokenizer, model)``, in that order, with the model moved to
        device index 0.
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
    return tokenizer, model