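"""Chainlit RAG app: answers questions grounded in the NIST AI 600-1 profile and
the Blueprint for an AI Bill of Rights, using LangChain, OpenAI, and an
in-memory Qdrant vector store."""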
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from langchain.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from operator import itemgetter
import chainlit as cl

# Load the source PDFs (each .load() call returns a list of Documents)
nist_documents = PyMuPDFLoader(file_path="data/NIST.AI.600-1.pdf").load()
blueprint_documents = PyMuPDFLoader(file_path="data/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
documents = nist_documents + blueprint_documents

# Split the documents
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=40,
    length_function=len,
    is_separator_regex=False
)
rag_documents = text_splitter.split_documents(documents)

# Build the vector store and RAG chain when a chat session starts
@cl.on_chat_start
async def start_chat():
    LOCATION = ":memory:"
    COLLECTION_NAME = "Implications of AI"
    VECTOR_SIZE = 1536  # output dimension of OpenAI's default embedding model

    embeddings = OpenAIEmbeddings()
    qdrant_client = QdrantClient(location=LOCATION)

    # Create the collection
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
    )

    # Create the vector store
    vectorstore = QdrantVectorStore(
        client=qdrant_client,
        collection_name=COLLECTION_NAME,
        embedding=embeddings
    )

    # Index the pre-split documents and expose the store as a retriever
    vectorstore.add_documents(rag_documents)
    retriever = vectorstore.as_retriever()

    template = """
    Use the provided context to answer the user's query.
    You may not answer the user's query unless there is specific context in the following text.
    If you do not know the answer, or cannot answer, please respond with "I don't know".
    Question:
    {question}
    Context:
    {context}
    Answer:
    """

    prompt = ChatPromptTemplate.from_template(template)
    base_llm = ChatOpenAI(model="gpt-4", temperature=0)

    # Retrieve context for the question, then send both through the prompt and
    # LLM, carrying the retrieved context alongside the model's response.
    retrieval_augmented_qa_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        | {"response": prompt | base_llm, "context": itemgetter("context")}
    )
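    # Invocation shape, for reference (the question string here is illustrative):
    #   retrieval_augmented_qa_chain.invoke({"question": "..."})
    #   -> {"response": AIMessage(...), "context": [Document, ...]}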

    cl.user_session.set("chain", retrieval_augmented_qa_chain)


@cl.on_message
async def main(message: cl.Message):
    chain = cl.user_session.get("chain")

    msg = cl.Message(content="")

    # Stream the chain output: each chunk is a partial dict, and the
    # "response" entries are AIMessageChunks whose content we forward.
    async for chunk in chain.astream({"question": message.content}):
        if "response" in chunk:
            await msg.stream_token(chunk["response"].content)

    await msg.send()
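
# To run (assuming this file is saved as app.py and OPENAI_API_KEY is set):
#   chainlit run app.py -w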