Spaces:
Sleeping
Sleeping
import os | |
from typing import List | |
from operator import itemgetter | |
from langchain_openai import ChatOpenAI | |
from langchain_core.vectorstores import VectorStoreRetriever | |
from langchain.docstore.document import Document | |
from langchain.chains import create_history_aware_retriever, create_retrieval_chain | |
from langchain.chains.combine_documents import create_stuff_documents_chain | |
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage | |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
import streamlit as st | |
import sys | |
sys.path.append('..') | |
from policy_rag.vectorstore_utils import QdrantVectorstoreHelper | |
from policy_rag.chains import get_qa_chain | |
from policy_rag.app_utils import get_embedding_model | |
# --- Vector store and embedding model configuration -------------------------
# Qdrant collection that is queried for policy document chunks.
QDRANT_COLLECTION = "policy-embed-te3-large-plus"

# Embedding model spec handed to get_embedding_model(). vector_size must match
# the dimensionality of the vectors stored in QDRANT_COLLECTION
# (3072 is the native output size of OpenAI's text-embedding-3-large).
VECTORSTORE_MODEL = dict(
    model_source="openai",
    model_name="text-embedding-3-large",
    vector_size=3072,
)

# Passed as `k` to the retriever. Presumably this is the number of chunks
# retrieved per question; confirm in QdrantVectorstoreHelper.get_retriever.
K = 5
# Initialize session state
# Build the retrieval-augmented QA chain only once per Streamlit session; later
# runs of this script find it already stored in st.session_state and reuse it.
if "rag_qa_chain" not in st.session_state:
    vectorstore_helper = QdrantVectorstoreHelper()
    qdrant_retriever = vectorstore_helper.get_retriever(
        collection_name=QDRANT_COLLECTION,
        embedding_model=get_embedding_model(VECTORSTORE_MODEL),
        k=K,
    )
    # streaming=True is forwarded to the project's chain factory. Presumably it
    # enables token streaming on the LLM; confirm in policy_rag.chains.
    st.session_state.rag_qa_chain = get_qa_chain(
        retriever=qdrant_retriever,
        streaming=True,
    )
def _answer_text(piece) -> str:
    """Return the text carried by one streamed ``answer`` piece.

    Chat-model chunks expose their text via ``.content``. If the chain instead
    yields plain values (e.g. strings), the piece itself is used. Non-string
    content is converted with ``str()`` so it can be appended to the running answer.
    """
    text = getattr(piece, "content", piece)
    return text if isinstance(text, str) else str(text)


# User input for question
st.title("AI Policy QA System")
user_input = st.text_input("Ask a question about AI policy:")
if st.button("Submit"):
    rag_qa_chain = st.session_state.rag_qa_chain
    if not user_input:
        st.warning("Please enter a question before submitting.")
    else:
        st.write(f"**You asked:** {user_input}")
        # A single placeholder is overwritten with the growing answer. The answer
        # then renders as one block (above the sources) instead of one
        # "**Answer:**" line per streamed piece.
        answer_slot = st.empty()
        msg = ""
        with st.spinner("Generating answer..."):
            # This script is synchronous. .astream() returns an async iterator
            # that a plain `for` loop cannot consume (TypeError), so use the
            # Runnable's synchronous .stream() instead.
            # NOTE(review): assumes the chain streams dicts keyed by "answer"
            # and/or "contexts"; confirm against policy_rag.chains.get_qa_chain.
            for chunk in rag_qa_chain.stream({"question": user_input}):
                if "answer" in chunk:
                    msg += _answer_text(chunk["answer"])
                    answer_slot.markdown(f"**Answer:** {msg}")
                if "contexts" in chunk:
                    st.write("**Sources:**")
                    for doc in chunk["contexts"]:
                        # Tolerate documents indexed without page/title metadata
                        # instead of aborting the whole response with KeyError.
                        page = doc.metadata.get("page", "?")
                        title = doc.metadata.get("title", "Untitled")
                        st.write(f"Page {page} - {title}")