from typing import List, Tuple, TypedDict
from re import sub

from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
from transformers import QuestionAnsweringPipeline
from transformers import PegasusXForConditionalGeneration, PegasusTokenizerFast
import torch

max_answer_len = 8
logging.set_verbosity_error()


def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
                   input_texts: List[str]) -> List[str]:
    """Summarize a batch of passages with Pegasus-X (model is expected on cuda:1)."""
    inputs = tokenizer(input_texts, padding=True,
                       return_tensors='pt', truncation=True).to(1)
    # Restrict scaled-dot-product attention to the flash kernel for faster generation.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                        enable_mem_efficient=False):
        summary_ids = model.generate(inputs["input_ids"])
    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
    return summaries


def get_summarizer(model_id: str = "seonglae/resrer"
                   ) -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
    """Load the summarizer tokenizer and model onto cuda:1 and compile the model."""
    tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
    model = PegasusXForConditionalGeneration.from_pretrained(model_id).to(1)
    model = torch.compile(model)
    return tokenizer, model


# Extractive reader
class AnswerInfo(TypedDict):
    score: float
    start: int
    end: int
    answer: str


@torch.inference_mode()
def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
               questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
    """Extract an answer span for each (question, context) pair."""
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                        enable_mem_efficient=False):
        pipeline = QuestionAnsweringPipeline(model=model, tokenizer=tokenizer,
                                             device='cuda', max_answer_len=max_answer_len)
        answer_infos: List[AnswerInfo] = pipeline(question=questions, context=ctxs)
    # The pipeline returns a bare dict for a single example; normalize to a list.
    if isinstance(answer_infos, dict):
        answer_infos = [answer_infos]
    # Strip stray punctuation from the extracted spans.
    for answer_info in answer_infos:
        answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
    return answer_infos


def get_reader(model_id: str = "mrm8488/longformer-base-4096-finetuned-squadv2"
               ) -> Tuple[AutoTokenizer, AutoModelForQuestionAnswering]:
    """Load the extractive QA reader onto cuda:0.

    The default checkpoint is a Longformer, so it is loaded with the Auto
    classes rather than the DPRReader classes.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
    return tokenizer, model


def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder,
                        questions: List[str]) -> torch.FloatTensor:
    """Encode questions with a DPR question encoder.

    https://huggingface.co./docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        questions (List[str]): question strings to encode
    """
    batch_dict = tokenizer(questions, return_tensors="pt",
                           padding=True, truncation=True).to(0)
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                        enable_mem_efficient=False):
        embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
    return embeddings


def get_dpr_encoder(model_id: str = "facebook/dpr-question_encoder-single-nq-base"
                    ) -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
    """Load a DPR question encoder and its tokenizer onto cuda:0.

    https://huggingface.co./docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder

    Args:
        model_id (str, optional): defaults to the NQ checkpoint; use
            "facebook/dpr-question_encoder-multiset-base" for the multiset variant.
    """
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
    model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
    return tokenizer, model
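

# Usage sketch (not part of the original module): wires the three components
# together end to end. Assumes at least two CUDA devices, since the summarizer
# above is pinned to cuda:1 while the reader and encoder live on cuda:0; the
# question and context strings below are illustrative placeholders.
if __name__ == "__main__":
    enc_tokenizer, enc_model = get_dpr_encoder()
    reader_tokenizer, reader_model = get_reader()
    sum_tokenizer, sum_model = get_summarizer()

    questions = ["Who composed the opera Carmen?"]
    ctxs = ["Carmen is an opera in four acts by the French composer Georges Bizet."]

    # Dense question embeddings, e.g. for retrieval against a passage index.
    embeddings = encode_dpr_question(enc_tokenizer, enc_model, questions)
    print(embeddings.shape)  # (1, 768)

    # Compress the retrieved contexts before reading.
    summaries = summarize_text(sum_tokenizer, sum_model, ctxs)

    # Extract short answer spans from the summarized contexts.
    answers = ask_reader(reader_tokenizer, reader_model, questions, summaries)
    print(answers[0]['answer'])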