JoshuaKelleyDs's picture
Update app.py
f7e5fb4 verified
import chainlit as cl # handles the chat interface
from langchain_together import ChatTogether, TogetherEmbeddings # for the LLM and Embeddings
from langchain_core.runnables import RunnableSequence, RunnablePassthrough # for chain execution
from langchain_core.prompts import ChatPromptTemplate # for writing the prompt template
from langchain_community.document_loaders import YoutubeLoader # for loading the youtube video
from typing import List # for type hinting
import langchain_core # for type hinting
from langchain_community.vectorstores import FAISS # for the vector store
from langchain_community.retrievers import BM25Retriever # for the BM25 retriever
from langchain.retrievers.ensemble import EnsembleRetriever # for the ensemble retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter # for the text splitter
######## Chainlit ########
@cl.on_chat_start
async def start():
    """
    Set up a new chat session.

    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-chat-start

    Steps performed:
      1. Create the LLM and the embedding model and store them in the user session
         under the keys "llm" and "embedding".
      2. Ask the user for a YouTube link and load the video transcription.
      3. Split the transcription into chunks, build a FAISS vector store and a BM25
         retriever from them, and combine both into an ensemble retriever that is
         stored in the user session under the key "ensemble_retriever".

    Any failure is reported to the user as a chat message instead of crashing the app.
    When a step cannot be completed the function returns early, leaving
    "ensemble_retriever" unset in the session.
    """
    await cl.Message(content="Hello! I am your AI assistant. I can help you with your questions about the video you provide.").send()
    try:  # a try/except block prevents the app from crashing if we hit an error
        llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")  # initialize the LLM model
        await cl.Message(content="model is successfully loaded").send()
        cl.user_session.set("llm", llm)  # the user session lets on_message reuse what we build here

        embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")  # initialize the embedding model
        cl.user_session.set("embedding", embedding)
        await cl.Message(content="embedding model loaded").send()

        # more on ask user message: https://docs.chainlit.io/api-reference/ask/ask-for-input
        youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send()
        if not youtube_link:
            # The ask call yields nothing when the user does not answer in time;
            # indexing it would raise a confusing TypeError.
            await cl.Message(content="No YouTube link was provided. Please refresh the page to try again.").send()
            return

        # NOTE(review): the original code reads the reply from 'content'; newer Chainlit
        # releases appear to return the text under 'output' -- accept either.
        youtube_url = youtube_link.get("content") or youtube_link.get("output")
        if not youtube_url:
            await cl.Message(content="No YouTube link was provided. Please refresh the page to try again.").send()
            return
        await cl.Message(content=f"youtube link: {youtube_url}").send()  # echo the link so the user can double check it

        youtube_docs = await create_youtube_transcription(youtube_url)
        if not youtube_docs:
            # The helper returns None after reporting an exception, and an empty list
            # would leave nothing to index, so stop here either way.
            await cl.Message(content="No transcription could be loaded for this video. Please refresh the page and try another link.").send()
            return
        await cl.Message(content=f"youtube docs: {youtube_docs}").send()  # show the loaded documents so the user can verify the data

        split_docs = await create_text_splitter(youtube_docs)  # split the documents into chunks
        vector_db = await create_faiss_vector_store(split_docs)  # create the vector db
        bm25 = await create_bm25_retreiver(split_docs)  # create the BM25 retriever
        if vector_db is None or bm25 is None:
            # Each helper already sent its own error message; nothing more to build.
            return

        ensemble_retriever = await create_ensemble_retriever(vector_db, bm25)
        cl.user_session.set("ensemble_retriever", ensemble_retriever)  # used by the on_message handler
    except Exception as e:
        # This handler covers every setup step, not only model loading.
        await cl.Message(content=f"failed to set up the chat session: {e}").send()
@cl.on_message
async def message(message: cl.Message):
    """
    Answer a user message using the retrieved video chunks.

    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-message

    The ensemble retriever stored in the user session is queried with the message
    text, each retrieved chunk is shown to the user, and then a
    retrieve -> prompt -> LLM chain is run to produce the final answer.

    Accepts:
    message: cl.Message - The message sent by the user; its `content` is the question
    """
    prompt_template = ChatPromptTemplate.from_template(template="""
You are a helpful assistant that can answer questions about the following video. Here is the appropriate chunks of context: {context}.
Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question
""")  # the prompt template that the question and the retrieved context are formatted into
    llm = cl.user_session.get("llm")  # the LLM initialized in the start function
    ensemble_retriever = cl.user_session.get("ensemble_retriever")  # the retriever initialized in the start function
    if llm is None or ensemble_retriever is None:
        # Session setup did not finish (e.g. no link given or loading failed),
        # so there is nothing to query yet.
        await cl.Message(content="The video has not been loaded yet. Please refresh the page and provide a YouTube link first.").send()
        return

    # Use the async entry points so the event loop is not blocked while waiting.
    relevant_docs = await ensemble_retriever.ainvoke(message.content)
    await cl.Message(content="Displaying Relevant Docs").send()  # must be awaited, otherwise it is never sent
    for doc in relevant_docs:  # show every retrieved chunk
        await cl.Message(content=doc.page_content).send()
    await cl.Message(content="Done Displaying Relevant Docs").send()

    # question -> retrieve relevant docs -> fill the prompt template -> pass to LLM
    rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt_template | llm)
    response = await rag_chain.ainvoke(message.content)
    await cl.Message(content=f"LLM Response: {response.content}").send()  # display the answer
######## Youtube ########
async def create_youtube_transcription(youtube_url: str) -> List[langchain_core.documents.Document]:
    """
    Load the transcript of a YouTube video as LangChain documents.

    More Info: https://python.langchain.com/docs/integrations/document_loaders/youtube_transcript/

    Accepts:
    youtube_url: str - The url of the youtube video
    Returns:
    List[langchain_core.documents.Document]: The documents produced by the loader.
    If loading raises, the error is sent to the chat and None is returned.
    """
    try:
        transcript_loader = YoutubeLoader.from_youtube_url(youtube_url, add_video_info=False)
        return transcript_loader.load()  # fetches the transcript
    except Exception as err:
        # Report the problem in the chat rather than propagating it.
        failure_notice = f"Error: {err} Please refresh the page"
        await cl.Message(content=failure_notice).send()
    return None
######## RAG ########
async def create_text_splitter(docs: List[langchain_core.documents.Document]) -> List[langchain_core.documents.Document]:
    """
    Split documents into overlapping chunks.

    More Info: LangChain documentation for RecursiveCharacterTextSplitter

    Accepts:
    docs: List[langchain_core.documents.Document] - The documents to split
    Returns:
    List[langchain_core.documents.Document]: The chunked documents
    """
    # chunk_overlap keeps some shared text between neighbouring chunks so that
    # context is not cut off at a chunk boundary
    chunker = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    return chunker.split_documents(docs)
async def create_faiss_vector_store(docs: List[langchain_core.documents.Document]) -> FAISS:
    """
    Build a FAISS vector store from documents.

    More Info: https://python.langchain.com/docs/integrations/vectorstores/faiss/

    The embedding model is read from the user session key "embedding".

    Accepts:
    docs: List[langchain_core.documents.Document] - The documents to index
    Returns:
    FAISS: The vector store holding the documents.
    If building the store raises, the error is sent to the chat and None is returned.
    """
    try:
        embedder = cl.user_session.get("embedding")  # stored in the session during chat start
        store = FAISS.from_documents(docs, embedder)
        # Record the intended number of results on the store.
        # NOTE(review): verify that the retriever built from this store reads `k`.
        store.k = 5
    except Exception as err:
        await cl.Message(content=f"failed to create vector db: {err}").send()
        return None
    return store
async def create_bm25_retreiver(docs: List[langchain_core.documents.Document]) -> BM25Retriever:
    """
    Build a BM25 (keyword based) retriever from documents.

    More Info: https://python.langchain.com/docs/integrations/retrievers/bm25/

    Accepts:
    docs: List[langchain_core.documents.Document] - The documents to index
    Returns:
    BM25Retriever: The retriever over the documents, with `k` set to 5.
    If building the retriever raises, the error is sent to the chat and None is returned.
    """
    try:
        keyword_retriever = BM25Retriever.from_documents(docs)  # no embedding model is passed here
        keyword_retriever.k = 5  # return 5 documents per query
    except Exception as err:
        await cl.Message(content=f"failed to create BM25 retreiver: {err}").send()
        return None
    return keyword_retriever
async def create_ensemble_retriever(vector_db:FAISS, bm25:BM25Retriever) -> EnsembleRetriever:
    """
    Combine a FAISS vector store and a BM25 retriever into one ensemble retriever.

    More Info: https://python.langchain.com/docs/how_to/ensemble_retriever/

    Accepts:
    vector_db: FAISS - A vector db
    bm25: BM25Retriever - A BM25 retriever
    Returns:
    EnsembleRetriever: An ensemble retriever weighting the vector db results at 0.3
    and the BM25 results at 0.7.
    If building the retriever raises, the error is sent to the chat and None is returned.
    """
    try:
        # A `k` attribute placed on the vector store is not passed along by a bare
        # as_retriever() call, so forward it explicitly through search_kwargs.
        # Fall back to 4 (the vector store search default) when no `k` was set.
        semantic_k = getattr(vector_db, "k", 4)
        semantic_retriever = vector_db.as_retriever(search_kwargs={"k": semantic_k})
        # 30% semantic, 70% keyword retrieval
        return EnsembleRetriever(retrievers=[semantic_retriever, bm25], weights=[.3, .7])
    except Exception as e:
        await cl.Message(content=f"failed to create ensemble retriever: {e}").send()
        return None