danicafisher committed on
Commit
eb52945
1 Parent(s): a7f2408

Update app.py

Files changed (1)
  1. app.py +65 -85
app.py CHANGED
@@ -1,100 +1,80 @@
- from typing import List
- from aimakerspace.text_utils import CharacterTextSplitter, PDFFileLoader
- from aimakerspace.openai_utils.prompts import (
-     UserRolePrompt,
-     SystemRolePrompt
- )
- from aimakerspace.vectordatabase import VectorDatabase
- from aimakerspace.openai_utils.chatmodel import ChatOpenAI
- from langchain_community.embeddings import OpenAIEmbeddings
- from langchain_community.vectorstores import Chroma
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
  import chainlit as cl
- import nest_asyncio
- nest_asyncio.apply()
-
- # # pdf_loader_NIST = PDFFileLoader("data/NIST.AI.600-1.pdf")
- # # pdf_loader_Blueprint = PDFFileLoader("data/Blueprint-for-an-AI-Bill-of-Rights.pdf")
- # # documents_NIST = pdf_loader_NIST.load_documents()
- # # documents_Blueprint = pdf_loader_Blueprint.load_documents()
-
- # text_splitter = CharacterTextSplitter()
- # # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
- # split_documents_NIST = text_splitter.split_texts(documents_NIST)
- # split_documents_Blueprint = text_splitter.split_texts(documents_Blueprint)
- loader = PDFFileLoader("data/")
- loader.load()
- splitter = CharacterTextSplitter()
- chunks = splitter.split_texts(loader.documents)
-
- # rag_documents = split_documents_NIST + split_documents_Blueprint
-
- RAG_PROMPT_TEMPLATE = """ \
- Use the provided context to answer the user's query.
- You may not answer the user's query unless there is specific context in the following text.
- If you do not know the answer, or cannot answer, please respond with "I don't know".
- """
-
- rag_prompt = SystemRolePrompt(RAG_PROMPT_TEMPLATE)
-
- USER_PROMPT_TEMPLATE = """ \
- Context:
- {context}
- Question:
- {question}
- """
-
- user_prompt = UserRolePrompt(USER_PROMPT_TEMPLATE)
-
- class RetrievalAugmentedQAPipeline:
-     def __init__(self, llm: ChatOpenAI(), vector_db_retriever: Chroma) -> None:
-         self.llm = llm
-         self.vector_db_retriever = vector_db_retriever
-
-     async def arun_pipeline(self, question: str):
-         context_list = self.vector_db_retriever.search_by_text(question, k=4)
-
-         context_prompt = ""
-         for context in context_list:
-             context_prompt += context[0] + "\n"
 
-         formatted_system_prompt = rag_prompt.create_message()
 
-         formatted_user_prompt = user_prompt.create_message(user_query=user_query, context=context_prompt)
 
-         async def generate_response():
-             async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
-                 yield chunk
 
-         return {"response": generate_response(), "context": context_list}
 
 
- # ------------------------------------------------------------
 
 
- @cl.on_chat_start
- async def start_chat():
-     settings = {
-         "model": "gpt-4o-mini"
-     }
-     cl.user_session.set("settings", settings)
-
-     # Create a vector store
-     # vector_db = VectorDatabase()
-     # vector_db = await vector_db.abuild_from_list(split_documents_NIST)
-     # vector_db = await vector_db.abuild_from_list(split_documents_Blueprint)
 
-     embeddings = OpenAIEmbeddings()
-     vector_db = Chroma.from_texts(chunks, embeddings)
-
-     # Create a chain
-     retrieval_augmented_qa_pipeline = RetrievalAugmentedQAPipeline(
-         vector_db_retriever=vector_db,
-         llm=chat_openai
      )
 
-     cl.user_session.set("chain", retrieval_augmented_qa_pipeline)
 
 
  @cl.on_message
@@ -102,7 +82,7 @@ async def main(message):
      chain = cl.user_session.get("chain")
 
      msg = cl.Message(content="")
-     result = await chain.arun_pipeline(message.content)
 
      async for stream_resp in result["response"]:
          await msg.stream_token(stream_resp)
 
+ from langchain_community.document_loaders import PyMuPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
+ from langchain_qdrant import QdrantVectorStore
+ from langchain.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+ from qdrant_client import QdrantClient
+ from qdrant_client.http.models import Distance, VectorParams
+ from operator import itemgetter
  import chainlit as cl
 
+ # Load the documents
+ pdf_loader_NIST = PyMuPDFLoader(file_path="data/NIST.AI.600-1.pdf").load()
+ pdf_loader_Blueprint = PyMuPDFLoader(file_path="data/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
+ documents = pdf_loader_NIST + pdf_loader_Blueprint
+
+ # Split the documents
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=500,
+     chunk_overlap=40,
+     length_function=len,
+     is_separator_regex=False
+ )
+ rag_documents = text_splitter.split_documents(documents)
 
+ # Create the vector store
+ # @cl.cache_resource
+ @cl.on_chat_start
+ async def start_chat():
+     LOCATION = ":memory:"
+     COLLECTION_NAME = "Implications of AI"
+     VECTOR_SIZE = 1536
 
+     embeddings = OpenAIEmbeddings()
+     qdrant_client = QdrantClient(location=LOCATION)
 
+     # Create the collection
+     qdrant_client.create_collection(
+         collection_name=COLLECTION_NAME,
+         vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
+     )
 
+     # Create the vector store
+     vectorstore = QdrantVectorStore(
+         client=qdrant_client,
+         collection_name=COLLECTION_NAME,
+         embedding=embeddings
+     )
 
+     # Load and add documents
+     vectorstore.add_documents(rag_documents)
+     retriever = vectorstore.as_retriever()
 
+     template = """
+     Use the provided context to answer the user's query.
+     You may not answer the user's query unless there is specific context in the following text.
+     If you do not know the answer, or cannot answer, please respond with "I don't know".
+     Question:
+     {question}
+     Context:
+     {context}
+     Answer:
+     """
+
+     prompt = ChatPromptTemplate.from_template(template)
+     base_llm = ChatOpenAI(model_name="gpt-4", temperature=0)
+
+     retrieval_augmented_qa_chain = (
+         {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
+         | RunnablePassthrough.assign(context=itemgetter("context"))
+         | {"response": prompt | base_llm, "context": itemgetter("context")}
      )
 
+     cl.user_session.set("chain", retrieval_augmented_qa_chain)
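For reference, a minimal usage sketch of the chain registered above (not part of the commit; it assumes OPENAI_API_KEY is set and the Qdrant collection has been populated as in start_chat). The chain expects a dict with a "question" key and returns both the model response and the retrieved context; the example question string is purely illustrative:

    # Hypothetical usage sketch of the LCEL chain defined in start_chat.
    result = await retrieval_augmented_qa_chain.ainvoke(
        {"question": "What does the Blueprint say about data privacy?"}
    )
    print(result["response"].content)   # AIMessage produced by the gpt-4 model
    print(len(result["context"]))       # list of retrieved Documents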
 
 
  @cl.on_message
 
      chain = cl.user_session.get("chain")
 
      msg = cl.Message(content="")
+     result = await chain.invoke(message.content)
 
      async for stream_resp in result["response"]:
          await msg.stream_token(stream_resp)
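The handler above streams from result["response"], but the added line awaits chain.invoke with a bare string; for the LCEL chain above, invoke is the synchronous entry point, the input needs to be a {"question": ...} dict, and the "response" value is a single AIMessage rather than an async token stream. A hedged sketch of a streaming variant under those assumptions:

    # Sketch only, not part of the commit: stream tokens from the LCEL chain.
    # astream on a chain that ends in a dict of runnables yields partial
    # dicts such as {"response": AIMessageChunk(...)} as tokens arrive.
    @cl.on_message
    async def main(message: cl.Message):
        chain = cl.user_session.get("chain")
        msg = cl.Message(content="")

        async for chunk in chain.astream({"question": message.content}):
            if "response" in chunk:
                await msg.stream_token(chunk["response"].content)

        await msg.send()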