# policy-rag/app.py
import sys

import streamlit as st

sys.path.append('..')  # assumes the policy_rag package sits one directory up
from policy_rag.vectorstore_utils import QdrantVectorstoreHelper
from policy_rag.chains import get_qa_chain
from policy_rag.app_utils import get_embedding_model
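# `QdrantVectorstoreHelper.get_retriever` lives in policy_rag.vectorstore_utils (not
# shown here). A minimal sketch of what it is assumed to do, using langchain_qdrant;
# the class body and env variable names are illustrative assumptions, not the actual
# policy_rag code:
#
#     import os
#     from langchain_qdrant import QdrantVectorStore
#     from qdrant_client import QdrantClient
#
#     class QdrantVectorstoreHelper:
#         def get_retriever(self, collection_name, embedding_model, k):
#             client = QdrantClient(url=os.environ['QDRANT_URL'],
#                                   api_key=os.environ['QDRANT_API_KEY'])
#             store = QdrantVectorStore(client=client,
#                                       collection_name=collection_name,
#                                       embedding=embedding_model)
#             return store.as_retriever(search_kwargs={'k': k})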
# Qdrant and model setup
QDRANT_COLLECTION = 'policy-embed-te3-large-plus'
VECTORSTORE_MODEL = {
    'model_source': 'openai',
    'model_name': 'text-embedding-3-large',
    'vector_size': 3072,
}
K = 5  # number of context documents to retrieve per question
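# `get_embedding_model` comes from policy_rag.app_utils (not shown here). A minimal
# sketch of what it is assumed to do with the VECTORSTORE_MODEL config above; this is
# illustrative, not the actual implementation:
#
#     from langchain_openai import OpenAIEmbeddings
#
#     def get_embedding_model(config: dict):
#         if config['model_source'] == 'openai':
#             return OpenAIEmbeddings(model=config['model_name'])
#         raise ValueError(f"Unsupported model source: {config['model_source']}")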
# Build the QA chain once and cache it in session state so it survives Streamlit reruns
if 'rag_qa_chain' not in st.session_state:
    qdrant_retriever = QdrantVectorstoreHelper().get_retriever(
        collection_name=QDRANT_COLLECTION,
        embedding_model=get_embedding_model(VECTORSTORE_MODEL),
        k=K,
    )
    st.session_state.rag_qa_chain = get_qa_chain(retriever=qdrant_retriever, streaming=True)
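# `get_qa_chain` (policy_rag.chains) is assumed to return an LCEL runnable whose
# streamed chunks are dicts: partial {"answer": AIMessageChunk} pieces plus one
# {"contexts": [Document, ...]} chunk, matching how the loop below consumes it. A
# rough illustrative sketch (prompt, model name, and wiring are assumptions, not the
# real policy_rag code):
#
#     from operator import itemgetter
#     from langchain_openai import ChatOpenAI
#     from langchain_core.prompts import ChatPromptTemplate
#     from langchain_core.runnables import RunnableLambda, RunnablePassthrough
#
#     def get_qa_chain(retriever, streaming=False):
#         prompt = ChatPromptTemplate.from_template(
#             "Answer using only the context below.\n\n{context}\n\nQuestion: {question}"
#         )
#         llm = ChatOpenAI(model='gpt-4o-mini', streaming=streaming)
#         join = RunnableLambda(lambda docs: '\n\n'.join(d.page_content for d in docs))
#         return RunnablePassthrough.assign(
#             contexts=itemgetter('question') | retriever
#         ) | RunnablePassthrough.assign(
#             answer={'context': itemgetter('contexts') | join,
#                     'question': itemgetter('question')} | prompt | llm
#         )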
# User input for question
st.title("AI Policy QA System")
user_input = st.text_input("Ask a question about AI policy:")
if st.button("Submit"):
    rag_qa_chain = st.session_state.rag_qa_chain
    if user_input:
        st.write(f"**You asked:** {user_input}")
        msg = ""  # accumulates the streamed answer text
        # Stream the answer and sources. Note `.stream()` replaces `.astream()`:
        # `.astream()` returns an async generator, which cannot be consumed by a
        # plain `for` loop.
        answer_box = st.empty()
        with st.spinner("Generating answer..."):
            for chunk in rag_qa_chain.stream({"question": user_input}):
                if "answer" in chunk:
                    msg += chunk["answer"].content
                    answer_box.write(f"**Answer:** {msg}")
                if "contexts" in chunk:
                    st.write("**Sources:**")
                    for doc in chunk["contexts"]:
                        st.write(f"Page {doc.metadata['page']} - {doc.metadata['title']}")
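# To run locally (assuming OPENAI_API_KEY plus whatever Qdrant connection settings
# QdrantVectorstoreHelper expects are set in the environment):
#
#     streamlit run app.py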