danicafisher committed on
Commit
f9769ad
1 Parent(s): 1179339

Update app.py

Files changed (1)
  1. app.py +22 -40
app.py CHANGED
@@ -11,16 +11,12 @@ from qdrant_client.http.models import Distance, VectorParams
 from operator import itemgetter
 import chainlit as cl
 
-# # Load the documents
-# pdf_loader_NIST = PyMuPDFLoader("data/NIST.AI.600-1.pdf").load()
-# pdf_loader_Blueprint = PyMuPDFLoader("data/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
-# documents = pdf_loader_NIST + pdf_loader_Blueprint
 
-# List to store all the documents
+
+# Load all the documents in the directory
 documents = []
 directory = "data/"
 
-# Iterate through all the files in the directory
 for filename in os.listdir(directory):
     if filename.endswith(".pdf"):  # Check if the file is a PDF
         file_path = os.path.join(directory, filename)
@@ -37,36 +33,24 @@ text_splitter = RecursiveCharacterTextSplitter(
 )
 rag_documents = text_splitter.split_documents(documents)
 
-# Create the vector store
-# @cl.cache_resource
-@cl.on_chat_start
-async def start_chat():
-    LOCATION = ":memory:"
-    COLLECTION_NAME = "Implications of AI"
-    VECTOR_SIZE = 1536
-
+embedding = OpenAIEmbeddings(model="text-embedding-3-small")
 
-    embeddings = OpenAIEmbeddings()
-    qdrant_client = QdrantClient(location=LOCATION)
+# Create the vector store
+vectorstore = Qdrant.from_documents(
+    rag_documents,
+    embedding,
+    location=":memory:",
+    collection_name="Implications of AI",
+)
 
-    # Create the collection
-    qdrant_client.create_collection(
-        collection_name=COLLECTION_NAME,
-        vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
-    )
+retriever = vectorstore.as_retriever()
+llm = ChatOpenAI(model="gpt-4")
 
-    # Create the vector store
-    vectorstore = QdrantVectorStore(
-        client=qdrant_client,
-        collection_name=COLLECTION_NAME,
-        embedding=embeddings
-    )
 
-    # Load and add documents
-    vectorstore.add_documents(rag_documents)
-    retriever = vectorstore.as_retriever()
+# @cl.cache_resource
+@cl.on_chat_start
+async def start_chat():
 
-
     template = """
     Use the provided context to answer the user's query.
     You may not answer the user's query unless there is specific context in the following text.
@@ -79,25 +63,23 @@ async def start_chat():
     """
 
     prompt = ChatPromptTemplate.from_template(template)
-    base_llm = ChatOpenAI(model_name="gpt-4", temperature=0)
 
-    retrieval_augmented_qa_chain = (
+    base_chain = (
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
-        | RunnablePassthrough.assign(context=itemgetter("context"))
-        | {"response": prompt | base_llm, "context": itemgetter("context")}
+        | prompt | llm | StrOutputParser()
     )
 
-    cl.user_session.set("chain", retrieval_augmented_qa_chain)
+    cl.user_session.set("chain", base_chain)
 
 
 @cl.on_message
 async def main(message):
     chain = cl.user_session.get("chain")
+    result = await chain.invoke({"question":message.content})
 
-    msg = cl.Message(content="")
-    result = await chain.invoke(message.content)
+    msg = cl.Message(content=result)
 
-    async for stream_resp in result["response"]:
-        await msg.stream_token(stream_resp)
+    # async for stream_resp in result["response"]:
+    #     await msg.stream_token(stream_resp)
 
     await msg.send()
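
Two details in the new handler are worth flagging. First, chain.invoke is LCEL's synchronous entry point, so awaiting its return value (a plain string once StrOutputParser is the last step) raises a TypeError; the async counterpart is chain.ainvoke. Second, the streaming loop this commit comments out can be restored with the chain's astream method. A minimal sketch of a handler along those lines, not part of the commit, assuming the same chain stored in the session above:

# Hypothetical rewrite of main(), not part of this commit.
@cl.on_message
async def main(message):
    chain = cl.user_session.get("chain")
    msg = cl.Message(content="")

    # astream yields output chunks as they are produced; with
    # StrOutputParser at the end of the chain, each chunk is a str.
    async for chunk in chain.astream({"question": message.content}):
        await msg.stream_token(chunk)

    await msg.send()

Keeping the non-streaming shape instead, result = await chain.ainvoke({"question": message.content}) would make the await correct with no other changes.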
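
Separately, the commented-out @cl.cache_resource looks like a Streamlit name (st.cache_resource); Chainlit's own caching decorator appears to be @cl.cache, which could be applied if the module-level vector-store build is ever moved behind a function.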