seonglae committed on
Commit
75148a1
1 Parent(s): c2a71e9

feat: huggingface space pipeline with resrer model

Files changed (3)
  1. app.py +97 -2
  2. model.py +86 -0
  3. requirements.txt +3 -0
app.py CHANGED
@@ -1,2 +1,97 @@
- x = st.slider('Select a value')
- st.write(x, 'squared is', x * x)
+ import os
+
+ import streamlit as st
+ from pymilvus import MilvusClient
+
+ from model import encode_dpr_question, get_dpr_encoder
+ from model import summarize_text, get_summarizer
+ from model import ask_reader, get_reader
+
+
+ TITLE = 'ReSRer: Retriever-Summarizer-Reader'
+ INITIAL = "What is the population of NYC?"
+
+ st.set_page_config(page_title=TITLE)
+ st.header(TITLE)
+ st.markdown('''
+ ### Ask a short-answer question that can be found in Wikipedia data.
+ ''', unsafe_allow_html=True)
+
+
+ @st.cache_resource
+ def load_models():
+     models = {}
+     models['encoder'] = get_dpr_encoder()
+     models['summarizer'] = get_summarizer()
+     models['reader'] = get_reader()
+     return models
+
+
+ @st.cache_resource
+ def load_client():
+     client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
+                           uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')
+     return client
+
+
+ client = load_client()
+ models = load_models()
+
+ styl = """
+ <style>
+   .StatusWidget-enter-done{
+     position: fixed;
+     left: 50%;
+     top: 50%;
+     transform: translate(-50%, -50%);
+   }
+   .StatusWidget-enter-done button{
+     display: none;
+   }
+ </style>
+ """
+ st.markdown(styl, unsafe_allow_html=True)
+
+
+ question = st.text_area("Question", INITIAL, height=400)
+
+
+ def main(question: str):
+     if question in st.session_state:
+         print("Cache hit!")
+         ctx, summary, answer = st.session_state[question]
+     else:
+         print(f"Input: {question}")
+         # Embedding
+         question_vectors = encode_dpr_question(
+             models['encoder'][0], models['encoder'][1], [question])
+         query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
+
+         # Retriever
+         results = client.search(collection_name='dpr_nq', data=[
+             query_vector], limit=10, output_fields=['title', 'text'])
+         texts = [result['entity']['text'] for result in results[0]]
+         ctx = '\n'.join(texts)
+
+         # Summarize the retrieved passages, then read the answer from the summary
+         summary = summarize_text(models['summarizer'][0],
+                                  models['summarizer'][1], [ctx])[0]
+         answers = ask_reader(models['reader'][0],
+                              models['reader'][1], [question], [summary])
+         answer = answers[0]['answer']
+         print(f"\nAnswer: {answer}")
+
+         st.session_state[question] = (ctx, summary, answer)
+
+     # Render answer, summary, and retrieved context
+     st.markdown(answer)
+     st.write("## Summary")
+     st.markdown(
+         f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
+     st.markdown(ctx)
+
+     st.write(f"{question}", unsafe_allow_html=True)
+
+
+ if question:
+     main(question)
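
For reference, the retrieve-summarize-read flow above can be exercised outside Streamlit. A minimal console sketch (not part of the commit), assuming MILVUS_HOST and MILVUS_PW are set, the dpr_nq collection exists in the psgs_w100 database, and two CUDA devices are available (get_summarizer loads onto cuda:1):

    import os
    from pymilvus import MilvusClient
    from model import (ask_reader, encode_dpr_question, get_dpr_encoder,
                       get_reader, get_summarizer, summarize_text)

    client = MilvusClient(user='resrer', password=os.environ['MILVUS_PW'],
                          uri=f"http://{os.environ['MILVUS_HOST']}:19530", db_name='psgs_w100')
    encoder, summarizer, reader = get_dpr_encoder(), get_summarizer(), get_reader()

    question = "What is the population of NYC?"
    vector = encode_dpr_question(*encoder, [question]).detach().cpu().numpy().tolist()[0]
    hits = client.search(collection_name='dpr_nq', data=[vector], limit=10,
                         output_fields=['title', 'text'])
    ctx = '\n'.join(hit['entity']['text'] for hit in hits[0])
    summary = summarize_text(*summarizer, [ctx])[0]
    print(ask_reader(*reader, [question], [summary])[0]['answer'])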
model.py ADDED
@@ -0,0 +1,86 @@
+ from typing import List, Tuple, TypedDict
+ from re import sub
+
+ from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer, logging
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
+ from transformers import QuestionAnsweringPipeline
+ from transformers import PegasusXForConditionalGeneration, PegasusTokenizerFast
+ import torch
+
+ max_answer_len = 8
+ logging.set_verbosity_error()
+
+
+ def summarize_text(tokenizer: PegasusTokenizerFast, model: PegasusXForConditionalGeneration,
+                    input_texts: List[str]) -> List[str]:
+     inputs = tokenizer(input_texts, padding=True,
+                        return_tensors='pt', truncation=True).to(1)  # cuda:1; use 0 on a single-GPU Space
+     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+         summary_ids = model.generate(inputs["input_ids"])
+     summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True,
+                                        clean_up_tokenization_spaces=False)
+     return summaries
+
+
+ def get_summarizer(model_id="seonglae/resrer") -> Tuple[PegasusTokenizerFast, PegasusXForConditionalGeneration]:
+     tokenizer = PegasusTokenizerFast.from_pretrained(model_id)
+     model = PegasusXForConditionalGeneration.from_pretrained(model_id).to(1)
+     model = torch.compile(model)
+     return tokenizer, model
+
+
+ # Reader
+
+
+ class AnswerInfo(TypedDict):
+     score: float
+     start: int
+     end: int
+     answer: str
+
+
+ @torch.inference_mode()
+ def ask_reader(tokenizer: AutoTokenizer, model: AutoModelForQuestionAnswering,
+                questions: List[str], ctxs: List[str]) -> List[AnswerInfo]:
+     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+         pipeline = QuestionAnsweringPipeline(
+             model=model, tokenizer=tokenizer, device='cuda', max_answer_len=max_answer_len)
+         results = pipeline(question=questions, context=ctxs)
+     answer_infos: List[AnswerInfo] = results if isinstance(results, list) else [results]  # a single example returns a bare dict
+     for answer_info in answer_infos:
+         answer_info['answer'] = sub(r'[.\(\)"\',]', '', answer_info['answer'])
+     return answer_infos
+
+
+ def get_reader(model_id="mrm8488/longformer-base-4096-finetuned-squadv2"):
+     tokenizer = AutoTokenizer.from_pretrained(model_id)  # Longformer QA checkpoint, so Auto classes rather than DPRReader
+     model = AutoModelForQuestionAnswering.from_pretrained(model_id).to(0)
+     return tokenizer, model
+
+
+ def encode_dpr_question(tokenizer: DPRQuestionEncoderTokenizer, model: DPRQuestionEncoder, questions: List[str]) -> torch.FloatTensor:
+     """Encode questions with the DPR question encoder.
+     https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
+
+     Args:
+         questions (List[str]): question strings to encode
+         tokenizer, model: the pair returned by get_dpr_encoder
+     """
+     batch_dict = tokenizer(questions, return_tensors="pt",
+                            padding=True, truncation=True).to(0)
+     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
+         embeddings: torch.FloatTensor = model(**batch_dict).pooler_output
+     return embeddings
+
+
+ def get_dpr_encoder(model_id="facebook/dpr-question_encoder-single-nq-base") -> Tuple[DPRQuestionEncoderTokenizer, DPRQuestionEncoder]:
+     """Load the DPR question encoder and its tokenizer.
+     https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DPRQuestionEncoder
+
+     Args:
+         model_id (str, optional): defaults to the NQ encoder;
+             "facebook/dpr-question_encoder-multiset-base" also works
+     """
+     tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(model_id)
+     model = DPRQuestionEncoder.from_pretrained(model_id).to(0)
+     return tokenizer, model
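
Each loader in model.py returns a (tokenizer, model) pair, so the stages can be smoke-tested in isolation. A hypothetical reader-only check (assumes a CUDA device, since get_reader moves the model to GPU 0):

    from model import ask_reader, get_reader

    tokenizer, model = get_reader()
    infos = ask_reader(tokenizer, model, ["Who wrote Hamlet?"],
                       ["Hamlet is a tragedy written by William Shakespeare."])
    print(infos[0]['answer'])  # answers come back stripped of punctuation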
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ transformers
+ torch
+ pymilvus
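
The requirements are unpinned and track the latest releases; torch.compile and the sdp_kernel context in model.py need torch 2.x, and MilvusClient needs a recent pymilvus. A pinned sketch (version floors are assumptions, not tested against this Space):

    transformers>=4.24
    torch>=2.0
    pymilvus>=2.3

streamlit itself is installed by the Space's Streamlit SDK, so it does not need to be listed here.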