import os import streamlit as st from pymilvus import MilvusClient from model import encode_dpr_question, get_dpr_encoder from model import summarize_text, get_summarizer from model import ask_reader, get_reader TITLE = 'ReSRer: Retriever-Summarizer-Reader' INITIAL = "What is the population of NYC" st.set_page_config(page_title=TITLE) st.header(TITLE) st.markdown(''' ### Ask short-answer question that can be find in Wikipedia data. ''', unsafe_allow_html=True) @st.cache_resource def load_models(): models = {} models['encoder'] = get_dpr_encoder() models['summarizer'] = get_summarizer() models['reader'] = get_reader() return models @st.cache_resource def load_client(): client = MilvusClient(user='resrer', password=os.env['MILVUS_PW'], uri=f"http://{os.env['MILVUS_HOST']}:19530", db_name='psgs_w100') return client client = load_client() models = load_models() styl = """ """ st.markdown(styl, unsafe_allow_html=True) question = st.text_area("Text to summarize", INITIAL, height=400) def main(question: str): if question in st.session_state: print("Cache hit!") ctx, summary, answer = st.session_state[question] else: print(f"Input: {question}") # Embedding question_vectors = encode_dpr_question( models['encoder'][0], models['encoder'][1], [question]) query_vector = question_vectors.detach().cpu().numpy().tolist()[0] # Retriever results = client.search(collection_name='dpr_nq', data=[ query_vector], limit=10, output_fields=['title', 'text']) texts = [result['entity']['text'] for result in results[0]] ctx = '\n'.join(texts) # Reader summary = summarize_text(models['summarizer'][0], models['summarizer'][1], [summary]) answers = ask_reader(models['reader'][0], models['reader'][1], [question], [ctx]) answer = answers[0]['answer'] print(f"\nAnswer: {answer}") st.session_state[question] = (ctx, summary, answer) # Summary st.markdown(answer) st.write("## Summary") st.markdown( f"