Spaces:
Sleeping
Sleeping
danicafisher
commited on
Commit
•
eb52945
1
Parent(s):
a7f2408
Update app.py
Browse files
app.py
CHANGED
@@ -1,100 +1,80 @@
|
|
1 |
-
from
|
2 |
-
from aimakerspace.text_utils import CharacterTextSplitter, PDFFileLoader
|
3 |
-
from aimakerspace.openai_utils.prompts import (
|
4 |
-
UserRolePrompt,
|
5 |
-
SystemRolePrompt
|
6 |
-
)
|
7 |
-
from aimakerspace.vectordatabase import VectorDatabase
|
8 |
-
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
|
9 |
-
from langchain_community.embeddings import OpenAIEmbeddings
|
10 |
-
from langchain_community.vectorstores import Chroma
|
11 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
12 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
import chainlit as cl
|
14 |
-
import nest_asyncio
|
15 |
-
nest_asyncio.apply()
|
16 |
-
|
17 |
-
# # pdf_loader_NIST = PDFFileLoader("data/NIST.AI.600-1.pdf")
|
18 |
-
# # pdf_loader_Blueprint = PDFFileLoader("data/Blueprint-for-an-AI-Bill-of-Rights.pdf")
|
19 |
-
# # documents_NIST = pdf_loader_NIST.load_documents()
|
20 |
-
# # documents_Blueprint = pdf_loader_Blueprint.load_documents()
|
21 |
-
|
22 |
-
# text_splitter = CharacterTextSplitter()
|
23 |
-
# # text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
|
24 |
-
# split_documents_NIST = text_splitter.split_texts(documents_NIST)
|
25 |
-
# split_documents_Blueprint = text_splitter.split_texts(documents_Blueprint)
|
26 |
-
loader = PDFFileLoader("data/")
|
27 |
-
loader.load()
|
28 |
-
splitter = CharacterTextSplitter()
|
29 |
-
chunks = splitter.split_texts(loader.documents)
|
30 |
-
|
31 |
-
# rag_documents = split_documents_NIST + split_documents_Blueprint
|
32 |
-
|
33 |
-
RAG_PROMPT_TEMPLATE = """ \
|
34 |
-
Use the provided context to answer the user's query.
|
35 |
-
You may not answer the user's query unless there is specific context in the following text.
|
36 |
-
If you do not know the answer, or cannot answer, please respond with "I don't know".
|
37 |
-
"""
|
38 |
-
|
39 |
-
rag_prompt = SystemRolePrompt(RAG_PROMPT_TEMPLATE)
|
40 |
-
|
41 |
-
USER_PROMPT_TEMPLATE = """ \
|
42 |
-
Context:
|
43 |
-
{context}
|
44 |
-
Question:
|
45 |
-
{question}
|
46 |
-
"""
|
47 |
-
|
48 |
-
user_prompt = UserRolePrompt(USER_PROMPT_TEMPLATE)
|
49 |
-
|
50 |
-
class RetrievalAugmentedQAPipeline:
|
51 |
-
def __init__(self, llm: ChatOpenAI(), vector_db_retriever: Chroma) -> None:
|
52 |
-
self.llm = llm
|
53 |
-
self.vector_db_retriever = vector_db_retriever
|
54 |
-
|
55 |
-
async def arun_pipeline(self, question: str):
|
56 |
-
context_list = self.vector_db_retriever.search_by_text(question, k=4)
|
57 |
-
|
58 |
-
context_prompt = ""
|
59 |
-
for context in context_list:
|
60 |
-
context_prompt += context[0] + "\n"
|
61 |
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
-
async def generate_response():
|
67 |
-
async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
|
68 |
-
yield chunk
|
69 |
|
70 |
-
|
|
|
71 |
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
-
#
|
|
|
|
|
|
|
|
|
|
|
74 |
|
|
|
|
|
|
|
75 |
|
76 |
-
@cl.on_chat_start
|
77 |
-
async def start_chat():
|
78 |
-
settings = {
|
79 |
-
"model": "gpt-4o-mini"
|
80 |
-
}
|
81 |
-
cl.user_session.set("settings", settings)
|
82 |
-
|
83 |
-
# Create a vector store
|
84 |
-
# vector_db = VectorDatabase()
|
85 |
-
# vector_db = await vector_db.abuild_from_list(split_documents_NIST)
|
86 |
-
# vector_db = await vector_db.abuild_from_list(split_documents_Blueprint)
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
)
|
96 |
|
97 |
-
cl.user_session.set("chain",
|
98 |
|
99 |
|
100 |
@cl.on_message
|
@@ -102,7 +82,7 @@ async def main(message):
|
|
102 |
chain = cl.user_session.get("chain")
|
103 |
|
104 |
msg = cl.Message(content="")
|
105 |
-
result = await chain.
|
106 |
|
107 |
async for stream_resp in result["response"]:
|
108 |
await msg.stream_token(stream_resp)
|
|
|
1 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
3 |
+
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
|
4 |
+
from langchain_qdrant import QdrantVectorStore
|
5 |
+
from langchain.prompts import ChatPromptTemplate
|
6 |
+
from langchain_core.output_parsers import StrOutputParser
|
7 |
+
from langchain_core.runnables import RunnablePassthrough
|
8 |
+
from qdrant_client import QdrantClient
|
9 |
+
from qdrant_client.http.models import Distance, VectorParams
|
10 |
+
from operator import itemgetter
|
11 |
import chainlit as cl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
# Load the documents
|
14 |
+
pdf_loader_NIST = PyMuPDFLoader(file_path="data/NIST.AI.600-1.pdf").load()
|
15 |
+
pdf_loader_Blueprint = PyMuPDFLoader(file_path="data/Blueprint-for-an-AI-Bill-of-Rights.pdf").load()
|
16 |
+
documents = pdf_loader_NIST + pdf_loader_Blueprint
|
17 |
+
|
18 |
+
# Split the documents
|
19 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
20 |
+
chunk_size=500,
|
21 |
+
chunk_overlap=40,
|
22 |
+
length_function=len,
|
23 |
+
is_separator_regex=False
|
24 |
+
)
|
25 |
+
rag_documents = text_splitter.split_documents(documents)
|
26 |
|
27 |
+
# Create the vector store
|
28 |
+
# @cl.cache_resource
|
29 |
+
@cl.on_chat_start
|
30 |
+
async def start_chat():
|
31 |
+
LOCATION = ":memory:"
|
32 |
+
COLLECTION_NAME = "Implications of AI"
|
33 |
+
VECTOR_SIZE = 1536
|
34 |
|
|
|
|
|
|
|
35 |
|
36 |
+
embeddings = OpenAIEmbeddings()
|
37 |
+
qdrant_client = QdrantClient(location=LOCATION)
|
38 |
|
39 |
+
# Create the collection
|
40 |
+
qdrant_client.create_collection(
|
41 |
+
collection_name=COLLECTION_NAME,
|
42 |
+
vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
|
43 |
+
)
|
44 |
|
45 |
+
# Create the vector store
|
46 |
+
vectorstore = QdrantVectorStore(
|
47 |
+
client=qdrant_client,
|
48 |
+
collection_name=COLLECTION_NAME,
|
49 |
+
embedding=embeddings
|
50 |
+
)
|
51 |
|
52 |
+
# Load and add documents
|
53 |
+
vectorstore.add_documents(rag_documents)
|
54 |
+
retriever = vectorstore.as_retriever()
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
+
template = """
|
58 |
+
Use the provided context to answer the user's query.
|
59 |
+
You may not answer the user's query unless there is specific context in the following text.
|
60 |
+
If you do not know the answer, or cannot answer, please respond with "I don't know".
|
61 |
+
Question:
|
62 |
+
{question}
|
63 |
+
Context:
|
64 |
+
{context}
|
65 |
+
Answer:
|
66 |
+
"""
|
67 |
+
|
68 |
+
prompt = ChatPromptTemplate.from_template(template)
|
69 |
+
base_llm = ChatOpenAI(model_name="gpt-4", temperature=0)
|
70 |
+
|
71 |
+
retrieval_augmented_qa_chain = (
|
72 |
+
{"context": itemgetter("question") | retriever, "question": itemgetter("question")}
|
73 |
+
| RunnablePassthrough.assign(context=itemgetter("context"))
|
74 |
+
| {"response": prompt | base_llm, "context": itemgetter("context")}
|
75 |
)
|
76 |
|
77 |
+
cl.user_session.set("chain", retrieval_augmented_qa_chain)
|
78 |
|
79 |
|
80 |
@cl.on_message
|
|
|
82 |
chain = cl.user_session.get("chain")
|
83 |
|
84 |
msg = cl.Message(content="")
|
85 |
+
result = await chain.invoke(message.content)
|
86 |
|
87 |
async for stream_resp in result["response"]:
|
88 |
await msg.stream_token(stream_resp)
|