JoshuaKelleyDs committed f8ae825 (1 parent: 7492d51): Update app.py

app.py CHANGED
@@ -1,75 +1,133 @@
-import chainlit as cl
-from langchain_together import ChatTogether, TogetherEmbeddings
-from langchain_core.runnables import RunnableSequence, RunnablePassthrough
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_community.document_loaders import YoutubeLoader
-from typing import List
-import langchain_core
-from langchain_community.vectorstores import FAISS
-from …
-from …
-from langchain_text_splitters import RecursiveCharacterTextSplitter

-def create_youtube_transcription(youtube_url: str):
    loader = YoutubeLoader.from_youtube_url(
-        …
-    )
-    youtube_docs = loader.load()
    return youtube_docs

def create_text_splitter(docs: List[langchain_core.documents.Document]):
-    …
    return docs

-def …
-    …
    return vector_db

-def …
-    …
    return bm25

-def create_ensemble_retriever(vector_db:FAISS, bm25:BM25Retriever):
-    …
    return ensemble_retreiver

@cl.on_chat_start
async def start():
-    …
        await cl.Message(content="embedding model loaded").send()
-        youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send()
-    …
    except Exception as e:
-        await cl.Message(content=f"failed to load model: {e}").send()


@cl.on_message
async def message(message: cl.Message):
    prompt_template = ChatPromptTemplate.from_template(template="""
-    …
-    """)
-    llm = cl.user_session.get("llm")
-    …
+import chainlit as cl  # handles the chat interface
+from langchain_together import ChatTogether, TogetherEmbeddings  # for the LLM and embeddings
+from langchain_core.runnables import RunnableSequence, RunnablePassthrough  # for chain execution
+from langchain_core.prompts import ChatPromptTemplate  # for writing the prompt template
+from langchain_community.document_loaders import YoutubeLoader  # for loading the youtube video
+from typing import List  # for type hinting
+import langchain_core  # for type hinting
+from langchain_community.vectorstores import FAISS  # for the vector store
+from langchain_community.retrievers import BM25Retriever  # for the BM25 retriever
+from langchain.retrievers.ensemble import EnsembleRetriever  # for the ensemble retriever
+from langchain_text_splitters import RecursiveCharacterTextSplitter  # for the text splitter

+def create_youtube_transcription(youtube_url: str) -> List[langchain_core.documents.Document]:
+    """
+    Create a youtube transcription from a youtube url
+    More Info: https://python.langchain.com/docs/integrations/document_loaders/youtube_transcript/
+    Accepts:
+        youtube_url: str - The url of the youtube video
+    Returns:
+        List[langchain_core.documents.Document]: A list of documents containing the youtube transcription
+    """
    loader = YoutubeLoader.from_youtube_url(
+        youtube_url, add_video_info=True
+    )  # we can also pass an array of youtube urls to load multiple videos at once!
+    youtube_docs = loader.load()  # this loads the transcript
    return youtube_docs

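A quick way to sanity-check the loader step on its own (a sketch, not part of the commit; the video URL is a placeholder, and add_video_info=True additionally needs the pytube package alongside youtube-transcript-api):

from langchain_community.document_loaders import YoutubeLoader

loader = YoutubeLoader.from_youtube_url(
    "https://www.youtube.com/watch?v=dQw4w9WgXcQ",  # placeholder URL, any public video with captions works
    add_video_info=True,
)
docs = loader.load()  # one Document, with the transcript in page_content
print(docs[0].metadata.get("title"), len(docs[0].page_content))
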
def create_text_splitter(docs: List[langchain_core.documents.Document]):
+    """
+    Create a text splitter from a list of documents
+    More Info: https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter/
+    Accepts:
+        docs: List[langchain_core.documents.Document] - A list of documents to split
+    Returns:
+        List[langchain_core.documents.Document]: A list of documents containing the text split
+    """
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)  # without an overlap, context might get cut off
+    docs = text_splitter.split_documents(docs)  # split the documents into chunks
    return docs

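To see what the splitter actually does, here is a minimal offline sketch (toy chunk sizes rather than the commit's 1000/200):

from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
chunks = splitter.split_text("The quick brown fox jumps over the lazy dog near the river bank.")
for chunk in chunks:
    print(repr(chunk))  # consecutive chunks share up to 5 characters of overlap
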
+def create_faiss_vector_store(docs: List[langchain_core.documents.Document]) -> FAISS:
+    """
+    Create a FAISS vector store or vector database from a list of documents
+    More Info: https://python.langchain.com/docs/integrations/vectorstores/faiss/
+    Accepts:
+        docs: List[langchain_core.documents.Document] - A list of documents to store
+    Returns:
+        FAISS: A vector store containing the documents
+    """
+    embedding = cl.user_session.get("embedding")  # we can get the embedding model from the user session or pass it as a parameter too!
+    vector_db = FAISS.from_documents(docs, embedding)  # create the vector store
+    # k (how many documents come back) is applied when the store is turned into a retriever below
    return vector_db

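A self-contained sketch of this step; DeterministicFakeEmbedding stands in for TogetherEmbeddings so it runs offline (faiss-cpu must be installed):

from langchain_core.embeddings import DeterministicFakeEmbedding
from langchain_community.vectorstores import FAISS

embedding = DeterministicFakeEmbedding(size=256)  # offline stand-in for TogetherEmbeddings
vector_db = FAISS.from_texts(["cats purr", "dogs bark", "birds sing"], embedding)
print(vector_db.similarity_search("what do dogs do?", k=2))  # the two nearest chunks by vector distance
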
+def create_bm25_retreiver(docs: List[langchain_core.documents.Document]) -> BM25Retriever:
+    """
+    Create a BM25 retriever from a list of documents
+    More Info: https://python.langchain.com/docs/integrations/retrievers/bm25/
+    Accepts:
+        docs: List[langchain_core.documents.Document] - A list of documents to store
+    Returns:
+        BM25Retriever: A BM25 retriever containing the documents
+    """
+    bm25 = BM25Retriever.from_documents(docs)  # we don't need embeddings for BM25, as it uses keyword matching!
+    bm25.k = 5  # we set k to 5, so we get 5 documents back
    return bm25

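BM25 needs no embeddings or API key, so this sketch runs as-is once the rank_bm25 package is installed:

from langchain_community.retrievers import BM25Retriever

bm25 = BM25Retriever.from_texts(["cats purr loudly", "dogs bark at night", "the stock market fell"])
bm25.k = 2  # k really is a field on BM25Retriever, unlike on the FAISS store above
print(bm25.invoke("bark dogs"))  # keyword overlap ranks the dog sentence first
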
+def create_ensemble_retriever(vector_db: FAISS, bm25: BM25Retriever) -> EnsembleRetriever:
+    """
+    Create an ensemble retriever from a vector db and a BM25 retriever
+    More Info: https://python.langchain.com/docs/how_to/ensemble_retriever/
+    Accepts:
+        vector_db: FAISS - A vector db
+        bm25: BM25Retriever - A BM25 retriever
+    Returns:
+        EnsembleRetriever: An ensemble retriever containing the vector db and the BM25 retriever
+    """
+    ensemble_retreiver = EnsembleRetriever(retrievers=[vector_db.as_retriever(search_kwargs={"k": 5}), bm25], weights=[.3, .7])  # 30% semantic, 70% keyword retrieval, 5 docs from each side
    return ensemble_retreiver

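Wiring the two together, continuing the offline sketches above (EnsembleRetriever merges the two ranked lists with reciprocal rank fusion; the weights below favor BM25 at 0.7, matching the commit):

from langchain.retrievers.ensemble import EnsembleRetriever

ensemble = EnsembleRetriever(
    retrievers=[vector_db.as_retriever(search_kwargs={"k": 2}), bm25],  # vector_db and bm25 from the sketches above
    weights=[0.3, 0.7],
)
print(ensemble.invoke("dogs barking"))  # a deduplicated, fusion-ranked union of both result lists
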
@cl.on_chat_start
async def start():
+    """
+    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-chat-start
+    This function is called when the chat starts. Under the hood it handles all the complicated stuff for loading the UI.
+    We explicitly load the model, embeddings, and retrievers.
+    Asks the user to provide the YouTube video link and loads the transcription.
+    With the transcription, it creates a FAISS vector store and a BM25 retriever, which are combined into an ensemble retriever.
+    """
+    await cl.Message(content="Hello! I am your AI assistant. I can help you with your questions about the video you provide.").send()
+    try:  # a try/except block prevents the app from crashing if we hit an error
+        llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")  # initialize the LLM
+        await cl.Message(content="model loaded successfully").send()  # we can send messages to be displayed with cl.Message().send()
+        cl.user_session.set("llm", llm)  # we can store variables in a special memory called the user session, so we can use them in our on message function and more
+        embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")  # initialize the embedding model
+        cl.user_session.set("embedding", embedding)  # store the embedding model in the user session
        await cl.Message(content="embedding model loaded").send()
+        youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send()  # we can ask the user for input with cl.AskUserMessage().send(), which does not trigger cl.on_message()
+        # more on ask user message: https://docs.chainlit.io/api-reference/ask/ask-for-input
+        await cl.Message(content=f"youtube link: {youtube_link}").send()  # display the link so the user can double-check it
+        youtube_docs = create_youtube_transcription(youtube_link['content'])  # create the youtube transcription
+        split_docs = create_text_splitter(youtube_docs)  # split the documents into chunks
+        vector_db = create_faiss_vector_store(split_docs)  # create the vector db
+        bm25 = create_bm25_retreiver(split_docs)  # create the BM25 retriever
+        ensemble_retriever = create_ensemble_retriever(vector_db, bm25)  # create the ensemble retriever
+        cl.user_session.set("ensemble_retriever", ensemble_retriever)  # store the ensemble retriever in the user session for our on message function
+        transcription = youtube_docs[0].page_content  # get the transcription of the first document
+        await cl.Message(content=f"youtube docs: {transcription}").send()  # display the transcription to show that we have the correct data
    except Exception as e:
+        await cl.Message(content=f"failed to load model: {e}").send()  # display the error if setup failed

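The session-plus-ask pattern used above, reduced to a minimal hypothetical app (run with chainlit run app.py; mirroring the commit, the dict returned by AskUserMessage carries the user's reply under 'content'):

import chainlit as cl

@cl.on_chat_start
async def setup():
    answer = await cl.AskUserMessage("What's your name?").send()  # returns a dict, or None on timeout
    cl.user_session.set("name", answer["content"] if answer else "friend")

@cl.on_message
async def reply(message: cl.Message):
    name = cl.user_session.get("name")  # anything stored at chat start is visible here
    await cl.Message(content=f"{name}, you said: {message.content}").send()
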

@cl.on_message
async def message(message: cl.Message):
+    """
+    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-message
+    This function is called when the user sends a message. It uses the ensemble retriever to find the most relevant documents and feeds them into the LLM.
+    We can then display the answer and the relevant documents to the user.
+    """
    prompt_template = ChatPromptTemplate.from_template(template="""
+    You are a helpful assistant that can answer questions about the following video. Here are the appropriate chunks of context: {context}.
+    Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question.
+    """)  # we create a prompt template that we will use to format our prompt
+    llm = cl.user_session.get("llm")  # we get the LLM we initialized in the start function
+    ensemble_retriever = cl.user_session.get("ensemble_retriever")  # we get the ensemble retriever we initialized in the start function
+    relevant_docs = ensemble_retriever.invoke(message.content)  # we use the ensemble retriever to find the most relevant documents
+    await cl.Message(content="Displaying Relevant Docs").send()  # .send() is a coroutine, so it must be awaited
+    for doc in relevant_docs:  # loop through the relevant documents and display each one!
+        await cl.Message(content=doc.page_content).send()
+    await cl.Message(content="Done Displaying Relevant Docs").send()
+    rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt_template | llm)
+    response = rag_chain.invoke(message.content)  # we invoke the rag chain with the user's message
+    await cl.Message(content=f"LLM Response: {response.content}").send()  # we display the response to the user
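
And the chain composition on its own, reusing ensemble from the sketches above (TOGETHER_API_KEY must be set; the dict step runs the retriever and RunnablePassthrough in parallel to fill {context} and {question}, and the extra RunnableSequence wrapper in the commit is optional, since | already builds one):

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_together import ChatTogether

prompt = ChatPromptTemplate.from_template("Context: {context}\nQuestion: {question}")
llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
rag_chain = {"context": ensemble, "question": RunnablePassthrough()} | prompt | llm
print(rag_chain.invoke("What do the dogs do at night?").content)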