import chainlit as cl  # handles the chat interface
from langchain_together import ChatTogether, TogetherEmbeddings  # for the LLM and embeddings
from langchain_core.runnables import RunnableSequence, RunnablePassthrough  # for chain execution
from langchain_core.prompts import ChatPromptTemplate  # for writing the prompt template
from langchain_community.document_loaders import YoutubeLoader  # for loading the YouTube video
from typing import List  # for type hinting
import langchain_core  # for type hinting
from langchain_community.vectorstores import FAISS  # for the vector store
from langchain_community.retrievers import BM25Retriever  # for the BM25 retriever
from langchain.retrievers.ensemble import EnsembleRetriever  # for the ensemble retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter  # for the text splitter
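
# NOTE: ChatTogether and TogetherEmbeddings read the Together AI API key from the
# TOGETHER_API_KEY environment variable (it can also be passed explicitly via the
# api_key argument). For example, before launching the app:
#
#   export TOGETHER_API_KEY="..."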
######## Chainlit ########
@cl.on_chat_start
async def start():
    """
    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-chat-start
    This function is called when the chat starts. Under the hood, Chainlit handles all the complicated work of loading the UI.
    We explicitly load the model, embeddings, and retrievers.
    It asks the user for a YouTube video link and loads the transcription.
    From the transcription it builds a FAISS vector store and a BM25 retriever, then combines the two into an ensemble retriever.
    """
await cl.Message(content="Hello! I am your AI assistant. I can help you with your questions about the video you provide.").send() | |
try: # a try catch block prevents the app from crashing if have an error | |
llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo") # initialize the LLM model | |
await cl.Message(content=f"model is successfully loaded").send() # we can send messages to be displayed with cl.Message().send() | |
cl.user_session.set("llm", llm) # we can store variables in a special memory called the user session, so we can use them in our on message function and more | |
embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval") # initialize the embedding model | |
cl.user_session.set("embedding", embedding) # store the embedding model in the user session | |
await cl.Message(content="embedding model loaded").send() | |
youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send() # We can ask the user for input using cl.AskUserMessage().send() which does not affect cl.on_message() | |
# more on ask user message: https://docs.chainlit.io/api-reference/ask/ask-for-input | |
await cl.Message(content=f"youtube link: {youtube_link['content']}").send() # display and double check to make sure the link is correct | |
youtube_docs = await create_youtube_transcription(youtube_link['content']) # create the youtube transcription | |
transcription = youtube_docs # get the transcription of the first document | |
await cl.Message(content=f"youtube docs: {transcription}").send() # display the transcription of the first document to show that we have the correct data | |
split_docs = await create_text_splitter(youtube_docs) # split the documents into chunks | |
vector_db = await create_faiss_vector_store(split_docs) # create the vector db | |
bm25 = await create_bm25_retreiver(split_docs) # create the BM25 retreiver | |
ensemble_retriever = await create_ensemble_retriever(vector_db, bm25) # create the ensemble retriever | |
cl.user_session.set("ensemble_retriever", ensemble_retriever) # store the ensemble retriever in the user session for our on message function | |
except Exception as e: | |
await cl.Message(content=f"failed to load model: {e}").send() # display the error if we failed to load the model | |
@cl.on_message
async def message(message: cl.Message):
    """
    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-message
    This function is called when the user sends a message. It uses the ensemble retriever to find the most relevant documents and feeds them into the LLM.
    We then display the answer and the relevant documents to the user.
    """
    prompt_template = ChatPromptTemplate.from_template(template="""
    You are a helpful assistant that can answer questions about the following video. Here are the relevant chunks of context: {context}.
    Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question.
    """)  # we create a prompt template that we will use to format our prompt
    llm = cl.user_session.get("llm")  # get the LLM we initialized in the start function
    ensemble_retriever = cl.user_session.get("ensemble_retriever")  # get the ensemble retriever we initialized in the start function
    relevant_docs = ensemble_retriever.invoke(message.content)  # use the ensemble retriever to find the most relevant documents
    await cl.Message(content="Displaying Relevant Docs").send()  # show the relevant documents to the user
    for doc in relevant_docs:  # loop through the relevant documents and display each one!
        await cl.Message(content=doc.page_content).send()
    await cl.Message(content="Done Displaying Relevant Docs").send()
    # question -> retrieve relevant docs -> format the question and context into the prompt template -> pass to the LLM
    rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt_template | llm)
    response = rag_chain.invoke(message.content)  # invoke the RAG chain with the user's message
    await cl.Message(content=f"LLM Response: {response.content}").send()  # display the response to the user
######## Youtube ########
async def create_youtube_transcription(youtube_url: str) -> List[langchain_core.documents.Document]:
    """
    Create a YouTube transcription from a YouTube url.
    More info: https://python.langchain.com/docs/integrations/document_loaders/youtube_transcript/
    Accepts:
        youtube_url: str - The url of the YouTube video
    Returns:
        List[langchain_core.documents.Document]: A list of documents containing the YouTube transcription
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            youtube_url, add_video_info=False
        )
        youtube_docs = loader.load()  # this loads the transcript
        return youtube_docs
    except Exception as e:
        await cl.Message(content=f"failed to load youtube video: {e} Please refresh the page").send()  # display the error if we failed to load the video
######## RAG ########
async def create_text_splitter(docs: List[langchain_core.documents.Document]) -> List[langchain_core.documents.Document]:
    """
    Split a list of documents into chunks.
    More info: https://python.langchain.com/docs/modules/data_connection/document_transformers/recursive_text_splitter/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to split
    Returns:
        List[langchain_core.documents.Document]: A list of chunked documents
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)  # without an overlap, context might get cut off at chunk boundaries
    docs = text_splitter.split_documents(docs)  # split the documents into chunks
    return docs
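
# For example, with chunk_size=1000 and chunk_overlap=200, consecutive chunks
# share up to ~200 characters, so a sentence that straddles a boundary still
# appears intact in at least one chunk. A minimal sketch of the behavior:
#
#   splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5)
#   splitter.split_text("a" * 50)  # -> pieces of at most 20 chars whose edges overlap by up to 5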
async def create_faiss_vector_store(docs: List[langchain_core.documents.Document]) -> FAISS:
    """
    Create a FAISS vector store (vector database) from a list of documents.
    More info: https://python.langchain.com/docs/integrations/vectorstores/faiss/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to store
    Returns:
        FAISS: A vector store containing the documents
    """
    try:
        embedding = cl.user_session.get("embedding")  # we can get the embedding model from the user session, or pass it as a parameter instead
        vector_db = FAISS.from_documents(docs, embedding)  # embed the chunks and index them
        # note: the number of documents returned per query is set later, in create_ensemble_retriever, via as_retriever(search_kwargs={"k": 5})
        return vector_db
    except Exception as e:
        await cl.Message(content=f"failed to create vector db: {e}").send()  # display the error if we failed to create the vector db
async def create_bm25_retriever(docs: List[langchain_core.documents.Document]) -> BM25Retriever:
    """
    Create a BM25 retriever from a list of documents.
    More info: https://python.langchain.com/docs/integrations/retrievers/bm25/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to store
    Returns:
        BM25Retriever: A BM25 retriever containing the documents
    """
    try:
        bm25 = BM25Retriever.from_documents(docs)  # we don't need embeddings for BM25, as it uses keyword matching!
        bm25.k = 5  # return the top 5 documents per query (BM25Retriever exposes k directly)
        return bm25
    except Exception as e:
        await cl.Message(content=f"failed to create BM25 retriever: {e}").send()  # display the error if we failed to create the BM25 retriever
async def create_ensemble_retriever(vector_db: FAISS, bm25: BM25Retriever) -> EnsembleRetriever:
    """
    Create an ensemble retriever from a vector store and a BM25 retriever.
    More info: https://python.langchain.com/docs/how_to/ensemble_retriever/
    Accepts:
        vector_db: FAISS - A FAISS vector store
        bm25: BM25Retriever - A BM25 retriever
    Returns:
        EnsembleRetriever: An ensemble retriever combining the vector store and the BM25 retriever
    """
    try:
        faiss_retriever = vector_db.as_retriever(search_kwargs={"k": 5})  # the vector store itself is not a retriever; as_retriever() wraps it and sets how many documents to return
        ensemble_retriever = EnsembleRetriever(retrievers=[faiss_retriever, bm25], weights=[0.3, 0.7])  # 30% semantic, 70% keyword retrieval
        return ensemble_retriever
    except Exception as e:
        await cl.Message(content=f"failed to create ensemble retriever: {e}").send()  # display the error if we failed to create the ensemble retriever