pattonma commited on
Commit
4c95dc7
1 Parent(s): 805a608
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
marketingRAG/Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+ RUN useradd -m -u 1000 user
3
+ USER user
4
+ ENV HOME=/home/user \
5
+ PATH=/home/user/.local/bin:$PATH
6
+ WORKDIR $HOME/app
7
+ COPY --chown=user . $HOME/app
8
+ COPY ./requirements.txt ~/app/requirements.txt
9
+ RUN pip install -r requirements.txt
10
+ COPY . .
11
+ CMD ["chainlit", "run", "app.py", "--port", "7860"]
marketingRAG/app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
2
+ from qdrant_client import QdrantClient
3
+ from langchain_openai.embeddings import OpenAIEmbeddings
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.globals import set_llm_cache
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_core.caches import InMemoryCache
8
+ from operator import itemgetter
9
+ from langchain_core.runnables.passthrough import RunnablePassthrough
10
+ from langchain_qdrant import QdrantVectorStore, Qdrant
11
+ import uuid
12
+ import chainlit as cl
13
+ import os
14
+
15
+ chat_model = ChatOpenAI(model="gpt-4o-mini")
16
+ te3_small = OpenAIEmbeddings(model="text-embedding-3-small")
17
+ set_llm_cache(InMemoryCache())
18
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
19
+ rag_system_prompt_template = """\
20
+ You are a helpful assistant that uses the provided context to answer questions. Never reference this prompt, or the existance of context.
21
+ """
22
+ rag_message_list = [{"role" : "system", "content" : rag_system_prompt_template},]
23
+ rag_user_prompt_template = """\
24
+ Question:
25
+ {question}
26
+ Context:
27
+ {context}
28
+ """
29
+ chat_prompt = ChatPromptTemplate.from_messages([("system", rag_system_prompt_template), ("human", rag_user_prompt_template)])
30
+
31
+ @cl.on_chat_start
32
+ async def on_chat_start():
33
+ qdrant_client = QdrantClient(url=os.environ["QDRANT_ENDPOINT"], api_key=os.environ["QDRANT_API_KEY"])
34
+ qdrant_store = Qdrant(
35
+ client=qdrant_client,
36
+ collection_name="kai_test_docs",
37
+ embeddings=te3_small
38
+ )
39
+ retriever = qdrant_store.as_retriever()
40
+
41
+ global retrieval_augmented_qa_chain
42
+ retrieval_augmented_qa_chain = (
43
+ {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
44
+ | RunnablePassthrough.assign(context=itemgetter("context"))
45
+ | chat_prompt
46
+ | chat_model
47
+ )
48
+
49
+ await cl.Message(content="YAsk away!").send()
50
+
51
+ @cl.author_rename
52
+ def rename(orig_author: str):
53
+ return "AI Assistant"
54
+
55
+ @cl.on_message
56
+ async def main(message: cl.Message):
57
+ response = retrieval_augmented_qa_chain.invoke({"question": message.content})
58
+ await cl.Message(content=response.content).send()
marketingRAG/constants.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ OPENAI_API_KEY = "";
2
+ ANTRHOPIC_API_KEY = "";
3
+ LANGCHAIN_API_KEY = "";
4
+ LANGCHAIN_TRACING_V2=True;
5
+ LANGCHAIN_ENDPOINT='https://api.smith.langchain.com';
6
+ QDRANT_API_KEY="";
7
+ QDRANT_ENDPOINT="";
marketingRAG/load_existing_docs.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import models
2
+ import constants
3
+ from langchain_experimental.text_splitter import SemanticChunker
4
+ from langchain_qdrant import QdrantVectorStore, Qdrant
5
+ from langchain_community.document_loaders import PyPDFLoader
6
+ from qdrant_client.http.models import VectorParams
7
+
8
+ #qdrant = QdrantVectorStore.from_existing_collection(
9
+ # embedding=models.basic_embeddings,
10
+ # collection_name="kai_test_documents",
11
+ # url=constants.QDRANT_ENDPOINT,
12
+ #)
13
+
14
+
15
+ #gather kai's docs
16
+ filepaths = ["./test_docs/Employee Statistics FINAL.pdf","./test_docs/Employer Statistics FINAL.pdf"]
17
+ all_documents = []
18
+ for file in filepaths:
19
+ loader = PyPDFLoader(file)
20
+ documents = loader.load()
21
+ for doc in documents:
22
+ doc.metadata = {
23
+ "source": file,
24
+ "tag": "employee" if "employee" in file.lower() else "employer"
25
+ }
26
+ all_documents.extend(documents)
27
+
28
+ #chunk them
29
+ semantic_split_docs = models.semanticChunker.split_documents(all_documents)
30
+
31
+
32
+ #add them to the existing qdrant client
33
+ collection_name = "kai_test_docs"
34
+
35
+ collections = models.qdrant_client.get_collections()
36
+ collection_names = [collection.name for collection in collections.collections]
37
+ # If the collection does not exist, create it
38
+ if collection_name not in collection_names:
39
+ models.qdrant_client.create_collection(
40
+ collection_name=collection_name,
41
+ vectors_config=VectorParams(size=1536, distance="Cosine")
42
+ )
43
+
44
+ qdrant_vector_store = Qdrant(
45
+ client=models.qdrant_client,
46
+ collection_name=collection_name,
47
+ embeddings=models.te3_small
48
+ )
49
+
50
+ qdrant_vector_store.add_documents(semantic_split_docs)
51
+
52
+
53
+
54
+ collection_info = models.qdrant_client.get_collection(collection_name)
55
+ print(f"Number of points in collection: {collection_info.points_count}")
marketingRAG/models.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_anthropic import ChatAnthropic
2
+ from langchain_openai import ChatOpenAI
3
+ from langchain.callbacks.manager import CallbackManager
4
+ from langchain.callbacks.tracers import LangChainTracer
5
+ from langchain_huggingface.embeddings import HuggingFaceEmbeddings
6
+ from langchain_community.embeddings import HuggingFaceInstructEmbeddings
7
+ from langchain_experimental.text_splitter import SemanticChunker
8
+ from langchain_openai.embeddings import OpenAIEmbeddings
9
+ from langchain_community.vectorstores import Qdrant
10
+ from qdrant_client import QdrantClient
11
+ import constants
12
+ import os
13
+
14
+ os.environ["LANGCHAIN_API_KEY"] = constants.LANGCHAIN_API_KEY
15
+ os.environ["LANGCHAIN_TRACING_V2"] = str(constants.LANGCHAIN_TRACING_V2)
16
+ os.environ["LANGCHAIN_ENDPOINT"] = constants.LANGCHAIN_ENDPOINT
17
+
18
+ tracer = LangChainTracer()
19
+ callback_manager = CallbackManager([tracer])
20
+
21
+ qdrant_client = QdrantClient(url=constants.QDRANT_ENDPOINT, api_key=constants.QDRANT_API_KEY)
22
+
23
+ opus3 = ChatAnthropic(
24
+ api_key=constants.ANTRHOPIC_API_KEY,
25
+ temperature=0,
26
+ model='claude-3-opus-20240229',
27
+ callback_manager=callback_manager
28
+ )
29
+
30
+ sonnet35 = ChatAnthropic(
31
+ api_key=constants.ANTRHOPIC_API_KEY,
32
+ temperature=0,
33
+ model='claude-3-5-sonnet-20240620',
34
+ max_tokens=4096,
35
+ callback_manager=callback_manager
36
+ )
37
+
38
+ gpt4 = ChatOpenAI(
39
+ model="gpt-4",
40
+ temperature=0,
41
+ max_tokens=None,
42
+ timeout=None,
43
+ max_retries=2,
44
+ api_key=constants.OPENAI_API_KEY,
45
+ callback_manager=callback_manager
46
+ )
47
+
48
+ gpt4o = ChatOpenAI(
49
+ model="gpt-4o",
50
+ temperature=0,
51
+ max_tokens=None,
52
+ timeout=None,
53
+ max_retries=2,
54
+ api_key=constants.OPENAI_API_KEY,
55
+ callback_manager=callback_manager
56
+ )
57
+
58
+ gpt4o_mini = ChatOpenAI(
59
+ model="gpt-4o-mini",
60
+ temperature=0,
61
+ max_tokens=None,
62
+ timeout=None,
63
+ max_retries=2,
64
+ api_key=constants.OPENAI_API_KEY,
65
+ callback_manager=callback_manager
66
+ )
67
+
68
+ basic_embeddings = HuggingFaceEmbeddings(model_name="snowflake/snowflake-arctic-embed-l")
69
+ #hkunlp_instructor_large = HuggingFaceInstructEmbeddings(
70
+ # model_name = "hkunlp/instructor-large",
71
+ # query_instruction="Represent the query for retrieval: "
72
+ #)
73
+
74
+ te3_small = OpenAIEmbeddings(api_key=constants.OPENAI_API_KEY, model="text-embedding-3-small")
75
+
76
+ semanticChunker = SemanticChunker(
77
+ te3_small,
78
+ breakpoint_threshold_type="percentile"
79
+ )
marketingRAG/prompts.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder, PromptTemplate
2
+ from langchain.schema import SystemMessage
3
+
marketingRAG/requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-experimental
3
+ langchain-qdrant
4
+ langchain-community
5
+ qdrant-client
6
+ langchain-anthropic
7
+ langchain-openai
8
+ langchain-huggingface
marketingRAG/set_constants.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import constants
2
+ import os
3
+ from dotenv import load_dotenv, find_dotenv
4
+
5
+ load_dotenv(find_dotenv())
6
+
7
+ current_directory = os.path.dirname(os.path.abspath(__file__))
8
+ file_path = os.path.join(current_directory, 'constants.py')
9
+ constantsFile = open(file_path, "w")
10
+ constantsFile.write("OPENAI_API_KEY='" + os.getenv("OPENAI_API_KEY") + "';\n");
11
+ constantsFile.write("ANTRHOPIC_API_KEY='" + os.getenv("ANTRHOPIC_API_KEY") + "';\n");
12
+ constantsFile.write("LANGCHAIN_API_KEY='" + os.getenv("LANGCHAIN_API_KEY") + "';\n");
13
+ constantsFile.write("LANGCHAIN_TRACING_V2=True;\n");
14
+ constantsFile.write("LANGCHAIN_ENDPOINT='https://api.smith.langchain.com';\n");
15
+ constantsFile.write("QDRANT_API_KEY='" + os.getenv("QDRANT_API_KEY") + "';\n");
16
+ constantsFile.write("QDRANT_ENDPOINT='" + os.getenv("QDRANT_ENDPOINT") + "';\n");
17
+ constantsFile.close()
marketingRAG/test_docs/Employee Statistics FINAL.pdf ADDED
Binary file (92.2 kB). View file
 
marketingRAG/test_docs/Employer Statistics FINAL.pdf ADDED
Binary file (113 kB). View file