JoshuaKelleyDs committed on
Commit f8ae825
1 Parent(s): 7492d51

Update app.py

Files changed (1):
  app.py +111 -53
app.py CHANGED
@@ -1,75 +1,133 @@
- import chainlit as cl
- from langchain_together import ChatTogether, TogetherEmbeddings
- from langchain_core.runnables import RunnableSequence, RunnablePassthrough
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_community.document_loaders import YoutubeLoader
- from typing import List
- import langchain_core
- from langchain_community.vectorstores import FAISS
- from langchain.retrievers.ensemble import EnsembleRetriever
- from langchain_community.retrievers import BM25Retriever
- from langchain_text_splitters import RecursiveCharacterTextSplitter

- def create_youtube_transcription(youtube_url: str):
      loader = YoutubeLoader.from_youtube_url(
-         youtube_url, add_video_info=False
-     )
-     youtube_docs = loader.load()
      return youtube_docs

  def create_text_splitter(docs: List[langchain_core.documents.Document]):
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-     docs = text_splitter.split_documents(docs)
      return docs

- def create_vector_store(docs: List[langchain_core.documents.Document]):
-     embedding = cl.user_session.get("embedding")
-     vector_db = FAISS.from_documents(docs, embedding)
      return vector_db

- def create_bm25_vector_store(docs: List[langchain_core.documents.Document]):
-     bm25 = BM25Retriever.from_documents(docs)
      return bm25

- def create_ensemble_retriever(vector_db:FAISS, bm25:BM25Retriever):
-     ensemble_retreiver = EnsembleRetriever(retrievers=[vector_db.as_retriever(), bm25], weights=[.3, .7])
      return ensemble_retreiver

  @cl.on_chat_start
  async def start():
-     await cl.Message(content="my name is josh!").send()
-     try:
-         llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
-         await cl.Message(content=f"model is successfully loaded").send()
-         cl.user_session.set("llm", llm)
-         embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
-         cl.user_session.set("embedding", embedding)
          await cl.Message(content="embedding model loaded").send()
-         youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send()
-         youtube_docs = create_youtube_transcription(youtube_link['output'])
-         split_docs = create_text_splitter(youtube_docs)
-         vector_db = create_vector_store(split_docs)
-         bm25 = create_bm25_vector_store(split_docs)
-         ensemble_retriever = create_ensemble_retriever(vector_db, bm25)
-         cl.user_session.set("ensemble_retriever", ensemble_retriever)
-         transcription = youtube_docs[0].page_content
-         await cl.Message(content=f"youtube docs: {transcription}").send()
-         cl.user_session.set("transcription", transcription)
      except Exception as e:
-         await cl.Message(content=f"failed to load model: {e}").send()


  @cl.on_message
  async def message(message: cl.Message):
      prompt_template = ChatPromptTemplate.from_template(template="""
-     You are a helpful assistant that can answer questions about the following video. Here is the appropriate chunks of context: {context}.
-     Answer the question: {question} but do not use any information outside of the video. Site the source or information you used to answer the question
-     """)
-     llm = cl.user_session.get("llm")
-     vector_db = cl.user_session.get("vector_db")
-     transcription = cl.user_session.get("transcription")
-     ensemble_retriever = cl.user_session.get("ensemble_retriever")
-     rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()}, prompt_template | llm)
-     response = rag_chain.invoke(message.content)
-     await cl.Message(content=response.content).send()
-
+ import chainlit as cl # handles the chat interface
+ from langchain_together import ChatTogether, TogetherEmbeddings # for the LLM and Embeddings
+ from langchain_core.runnables import RunnableSequence, RunnablePassthrough # for chain execution
+ from langchain_core.prompts import ChatPromptTemplate # for writing the prompt template
+ from langchain_community.document_loaders import YoutubeLoader # for loading the youtube video
+ from typing import List # for type hinting
+ import langchain_core # for type hinting
+ from langchain_community.vectorstores import FAISS # for the vector store
+ from langchain_community.retrievers import BM25Retriever # for the BM25 retriever
+ from langchain.retrievers.ensemble import EnsembleRetriever # for the ensemble retriever
+ from langchain_text_splitters import RecursiveCharacterTextSplitter # for the text splitter

+ def create_youtube_transcription(youtube_url: str) -> List[langchain_core.documents.Document]:
+     """
+     Create a youtube transcription from a youtube url
+     More Info: https://python.langchain.com/docs/integrations/document_loaders/youtube_transcript/
+     Accepts:
+         youtube_url: str - The url of the youtube video
+     Returns:
+         List[langchain_core.documents.Document]: A list of documents containing the youtube transcription
+     """
      loader = YoutubeLoader.from_youtube_url(
+         youtube_url, add_video_info=True
+     ) # we can also pass an array of youtube urls to load multiple videos at once!
+     youtube_docs = loader.load() # this loads the transcript
      return youtube_docs
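For reference, a minimal standalone sketch of what this loader call does (the URL is a placeholder; note that add_video_info=True additionally requires the pytube package for the metadata):

    from langchain_community.document_loaders import YoutubeLoader

    docs = YoutubeLoader.from_youtube_url(
        "https://www.youtube.com/watch?v=VIDEO_ID",  # placeholder url
        add_video_info=True,
    ).load()
    print(docs[0].metadata)            # title, author, etc. from add_video_info
    print(docs[0].page_content[:200])  # start of the transcript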

  def create_text_splitter(docs: List[langchain_core.documents.Document]):
+     """
+     Create a text splitter from a list of documents
+     More Info: https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter/
+     Accepts:
+         docs: List[langchain_core.documents.Document] - A list of documents to split
+     Returns:
+         List[langchain_core.documents.Document]: A list of documents containing the text split
+     """
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) # without an overlap, context might get cut off
+     docs = text_splitter.split_documents(docs) # split the documents into chunks
      return docs
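As a toy illustration of what chunk_size and chunk_overlap do, here is a sketch with the values shrunk so the splitting is visible (1000/200 would leave a short string in a single chunk):

    from langchain_text_splitters import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
    print(splitter.split_text("The quick brown fox jumps over the lazy dog"))
    # every chunk is at most 20 characters; consecutive chunks may share up to 5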

+ def create_faiss_vector_store(docs: List[langchain_core.documents.Document]) -> FAISS:
+     """
+     Create a FAISS vector store or vector database from a list of documents
+     More Info: https://python.langchain.com/docs/integrations/vectorstores/faiss/
+     Accepts:
+         docs: List[langchain_core.documents.Document] - A list of documents to store
+     Returns:
+         FAISS: A vector store containing the documents
+     """
+     embedding = cl.user_session.get("embedding") # we can get the embedding model from the user session or pass it as a parameter too!
+     vector_db = FAISS.from_documents(docs, embedding) # create the vector store
+     # note: the FAISS store has no k attribute of its own; we request 5 documents later via as_retriever(search_kwargs={"k": 5})
      return vector_db
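Since the store itself carries no k, the number of hits is chosen per query or on the retriever wrapper; a quick sketch, assuming the vector_db built above:

    hits = vector_db.similarity_search("what is the video about?", k=5)  # k per call
    retriever = vector_db.as_retriever(search_kwargs={"k": 5})           # or k on the retriever
    hits = retriever.invoke("what is the video about?")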

+ def create_bm25_retriever(docs: List[langchain_core.documents.Document]) -> BM25Retriever:
+     """
+     Create a BM25 retriever from a list of documents
+     More Info: https://python.langchain.com/docs/integrations/retrievers/bm25/
+     Accepts:
+         docs: List[langchain_core.documents.Document] - A list of documents to store
+     Returns:
+         BM25Retriever: A BM25 retriever containing the documents
+     """
+     bm25 = BM25Retriever.from_documents(docs) # we don't need embeddings for BM25, as it uses keyword matching!
+     bm25.k = 5 # we set k to 5, so we get 5 documents back
      return bm25
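Unlike the FAISS store, BM25Retriever really does expose k as an attribute, and its matching is purely lexical, so exact keywords matter; a quick sketch, assuming the split_docs from this pipeline (the query is a placeholder):

    bm25 = BM25Retriever.from_documents(split_docs)
    bm25.k = 5
    results = bm25.invoke("ensemble retriever")  # ranks chunks by keyword overlap, no embeddings involved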

+ def create_ensemble_retriever(vector_db: FAISS, bm25: BM25Retriever) -> EnsembleRetriever:
+     """
+     Create an ensemble retriever from a vector db and a BM25 retriever
+     More Info: https://python.langchain.com/docs/how_to/ensemble_retriever/
+     Accepts:
+         vector_db: FAISS - A vector db
+         bm25: BM25Retriever - A BM25 retriever
+     Returns:
+         EnsembleRetriever: An ensemble retriever containing the vector db and the BM25 retriever
+     """
+     ensemble_retriever = EnsembleRetriever(retrievers=[vector_db.as_retriever(search_kwargs={"k": 5}), bm25], weights=[.3, .7]) # 30% semantic, 70% keyword retrieval
      return ensemble_retriever
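The ensemble then behaves like any single retriever: each sub-retriever produces its own ranking, and EnsembleRetriever fuses them with weighted Reciprocal Rank Fusion (0.3 semantic, 0.7 keyword here). A usage sketch:

    docs = ensemble_retriever.invoke("what is the main topic of the video?")
    for doc in docs:  # a fused, de-duplicated ranking drawn from both retrievers
        print(doc.page_content[:100])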

  @cl.on_chat_start
  async def start():
+     """
+     More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-chat-start
+     This function is called when the chat starts. Under the hood it handles all the complicated stuff for loading the UI.
+     We explicitly load the model, embeddings, and retrievers.
+     Asks the user to provide the YouTube video link and loads the transcription.
+     With the transcription, it creates a FAISS vector store and a BM25 retriever, then combines the two into an ensemble retriever.
+     """
+     await cl.Message(content="Hello! I am your AI assistant. I can help you with your questions about the video you provide.").send()
+     try: # a try/except block prevents the app from crashing if we hit an error
+         llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo") # initialize the LLM
+         await cl.Message(content="model is successfully loaded").send() # we can send messages to be displayed with cl.Message().send()
+         cl.user_session.set("llm", llm) # we can store variables in a special memory called the user session, so we can use them in our on_message function and more
+         embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval") # initialize the embedding model
+         cl.user_session.set("embedding", embedding) # store the embedding model in the user session
          await cl.Message(content="embedding model loaded").send()
+         youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send() # we can ask the user for input using cl.AskUserMessage().send(), which does not trigger cl.on_message()
+         # more on ask user message: https://docs.chainlit.io/api-reference/ask/ask-for-input
+         await cl.Message(content=f"youtube link: {youtube_link}").send() # display the link so we can double check that it is correct
+         youtube_docs = create_youtube_transcription(youtube_link['content']) # create the youtube transcription
+         split_docs = create_text_splitter(youtube_docs) # split the documents into chunks
+         vector_db = create_faiss_vector_store(split_docs) # create the vector db
+         bm25 = create_bm25_retriever(split_docs) # create the BM25 retriever
+         ensemble_retriever = create_ensemble_retriever(vector_db, bm25) # create the ensemble retriever
+         cl.user_session.set("ensemble_retriever", ensemble_retriever) # store the ensemble retriever in the user session for our on_message function
+         transcription = youtube_docs[0].page_content # get the transcription of the first document
+         await cl.Message(content=f"youtube docs: {transcription}").send() # display the transcription of the first document to show that we have the correct data
      except Exception as e:
+         await cl.Message(content=f"failed to load model: {e}").send() # display the error if anything above failed


  @cl.on_message
  async def message(message: cl.Message):
+     """
+     More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-message
+     This function is called when the user sends a message. It uses the ensemble retriever to find the most relevant documents and feeds them into the LLM.
+     We can then display the answer and the relevant documents to the user.
+     """
      prompt_template = ChatPromptTemplate.from_template(template="""
+     You are a helpful assistant that can answer questions about the following video. Here are the appropriate chunks of context: {context}.
+     Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question.
+     """) # we create a prompt template that we will use to format our prompt
+     llm = cl.user_session.get("llm") # we get the LLM we initialized in the start function
+     ensemble_retriever = cl.user_session.get("ensemble_retriever") # we get the ensemble retriever we initialized in the start function
+     relevant_docs = ensemble_retriever.invoke(message.content) # we use the ensemble retriever to find the most relevant documents
+     await cl.Message(content="Displaying Relevant Docs").send() # note the await: .send() is a coroutine
+     for doc in relevant_docs: # loop through the relevant documents and display each one!
+         await cl.Message(content=doc.page_content).send()
+     await cl.Message(content="Done Displaying Relevant Docs").send()
+     rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt_template | llm)
+     response = rag_chain.invoke(message.content) # we invoke the rag chain with the user's message
+     await cl.Message(content=f"LLM Response: {response.content}").send() # we display the response to the user
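
To try the commit locally, the usual Chainlit entry point should apply: export a valid TOGETHER_API_KEY (both ChatTogether and TogetherEmbeddings read it) and start the app with chainlit run app.py.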