File size: 2,100 Bytes
c62a3d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from typing import List
from operator import itemgetter

from langchain_openai import ChatOpenAI
from langchain_core.vectorstores import VectorStoreRetriever
from langchain.docstore.document import Document
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

import streamlit as st
import sys
sys.path.append('..')
from policy_rag.vectorstore_utils import QdrantVectorstoreHelper
from policy_rag.chains import get_qa_chain
from policy_rag.app_utils import get_embedding_model

# --- Qdrant collection and embedding-model configuration ---
# Name of the Qdrant collection that holds the policy-document embeddings.
QDRANT_COLLECTION = 'policy-embed-te3-large-plus'
# Embedding model spec consumed by get_embedding_model(); 'vector_size'
# must match the dimensionality of the vectors stored in the collection.
VECTORSTORE_MODEL = {
    'model_source': 'openai',
    'model_name': 'text-embedding-3-large',
    'vector_size': 3072,
}
# Number of documents the retriever returns per query (top-k).
K = 5

# Build the RAG QA chain once per browser session and cache it in
# Streamlit's session state, so script reruns reuse the same
# retriever + chain instead of reconstructing them on every interaction.
if 'rag_qa_chain' not in st.session_state:
    embedding_model = get_embedding_model(VECTORSTORE_MODEL)
    retriever = QdrantVectorstoreHelper().get_retriever(
        collection_name=QDRANT_COLLECTION,
        embedding_model=embedding_model,
        k=K,
    )
    st.session_state.rag_qa_chain = get_qa_chain(
        retriever=retriever,
        streaming=True,
    )

# --- UI: question input and streamed answer display ---
st.title("AI Policy QA System")
user_input = st.text_input("Ask a question about AI policy:")

if st.button("Submit"):
    rag_qa_chain = st.session_state.rag_qa_chain

    if user_input:
        st.write(f"**You asked:** {user_input}")

        # Stream the answer synchronously. BUG FIX: the original called
        # .astream(), which returns an *async* generator and cannot be
        # consumed by a plain `for` loop (it raises TypeError at runtime).
        # Streamlit scripts execute synchronously, so .stream() is correct.
        with st.spinner("Generating answer..."):
            answer_text = ""
            answer_placeholder = st.empty()  # rewritten in place as tokens arrive
            for chunk in rag_qa_chain.stream({"question": user_input}):
                if "answer" in chunk:
                    # Accumulate streamed tokens and update one widget,
                    # instead of emitting a new st.write line per chunk.
                    answer_text += chunk["answer"].content
                    answer_placeholder.write(f"**Answer:** {answer_text}")
                if "contexts" in chunk:
                    st.write("**Sources:**")
                    for doc in chunk["contexts"]:
                        # NOTE(review): assumes every retrieved doc carries
                        # 'page' and 'title' metadata — raises KeyError
                        # otherwise; confirm against the ingestion pipeline.
                        st.write(f"Page {doc.metadata['page']} - {doc.metadata['title']}")