import chainlit as cl # handles the chat interface
from langchain_together import ChatTogether, TogetherEmbeddings # for the LLM and Embeddings
from langchain_core.runnables import RunnableSequence, RunnablePassthrough # for chain execution
from langchain_core.prompts import ChatPromptTemplate # for writing the prompt template
from langchain_community.document_loaders import YoutubeLoader # for loading the youtube video
from typing import List # for type hinting
import langchain_core.documents # so the langchain_core.documents.Document type hints below resolve
from langchain_community.vectorstores import FAISS # for the vector store
from langchain_community.retrievers import BM25Retriever # for the BM25 retriever
from langchain.retrievers.ensemble import EnsembleRetriever # for the ensemble retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter # for the text splitter
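
# NOTE: ChatTogether and TogetherEmbeddings read your API key from the TOGETHER_API_KEY environment
# variable, so export it before launching the app with `chainlit run app.py` (assuming this file is named app.py)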



######## Chainlit ########
@cl.on_chat_start
async def start():
    """
    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-chat-start
    This function is called when the chat starts. Under the hood it handles all the complicated stuff for loading the UI. 
    We explicitly load the model, embeddings, and retrievers.
    Asks the user to provide the YouTube video link and loads the transcription.
    With the transcription, it creates a vector store and a BM25 vector store. That is used to create an ensemble retriever combining the two.
    """
    await cl.Message(content="Hello! I am your AI assistant. I can help you with your questions about the video you provide.").send()
    try: # a try/except block prevents the app from crashing if an error occurs
        llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo") # initialize the LLM model
        await cl.Message(content="model successfully loaded").send() # we can send messages to be displayed with cl.Message().send()
        cl.user_session.set("llm", llm) # we can store variables in a special memory called the user session, so we can use them in our on message function and more
        embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval") # initialize the embedding model
        cl.user_session.set("embedding", embedding) # store the embedding model in the user session
        await cl.Message(content="embedding model loaded").send()
        youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send() # We can ask the user for input using cl.AskUserMessage().send() which does not affect cl.on_message()
        # more on ask user message: https://docs.chainlit.io/api-reference/ask/ask-for-input
        
        await cl.Message(content=f"youtube link: {youtube_link['content']}").send() # display and double check to make sure the link is correct
        youtube_docs = await create_youtube_transcription(youtube_link['content']) # load the youtube transcription as documents
        transcription = youtube_docs[0].page_content # get the transcription of the first document
        await cl.Message(content=f"youtube docs: {transcription}").send() # display the transcription of the first document to show that we have the correct data
        split_docs = await create_text_splitter(youtube_docs) # split the documents into chunks
        vector_db = await create_faiss_vector_store(split_docs) # create the vector db
        bm25 = await create_bm25_retriever(split_docs) # create the BM25 retriever
        ensemble_retriever = await create_ensemble_retriever(vector_db, bm25) # create the ensemble retriever
        cl.user_session.set("ensemble_retriever", ensemble_retriever) # store the ensemble retriever in the user session for our on message function
    except Exception as e:
        await cl.Message(content=f"failed to load model: {e}").send() # display the error if we failed to load the model

@cl.on_message
async def message(message: cl.Message):
    """
    More info: https://docs.chainlit.io/api-reference/lifecycle-hooks/on-message
    This function is called when the user sends a message. It uses the ensemble retriever to find the most relevant documents and feeds them into the LLM.
    We can then display the answer and the relevant documents to the user.
    """
    prompt_template = ChatPromptTemplate.from_template(template="""
        You are a helpful assistant that can answer questions about the following video. Here are the relevant chunks of context: {context}.
        Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question.
    """) # we create a prompt template that we will use to format our prompt
    llm = cl.user_session.get("llm") # we get the LLM model we initialized in the start function
    ensemble_retriever = cl.user_session.get("ensemble_retriever") # we get the ensemble retriever we initialized in the start function
    relevant_docs = ensemble_retriever.invoke(message.content) # we use the ensemble retriever to find the most relevant documents
    cl.Message(content=f"Displaying Relevant Docs").send() # we display the relevant documents to the user
    for doc in relevant_docs: # loop through the relevant documents and display each one!
        await cl.Message(content=doc.page_content).send() 
    await cl.Message(content="Done Displaying Relevant Docs").send()
    # question -> retrieve relevant docs -> format the question and context and add it to the prompt template -> pass to LLM 
    rag_chain = RunnableSequence({"context": ensemble_retriever, "question": RunnablePassthrough()} | prompt_template | llm) 
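    # NOTE: the dict above is coerced into a RunnableParallel, so the user's question is sent to both
    # branches at once: the retriever fills {context} and the passthrough fills {question}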
    response = rag_chain.invoke(message.content) # we invoke the rag chain with the user's message
    await cl.Message(content=f"LLM Response: {response.content}").send() # we display the response to the user

######## Youtube ########

async def create_youtube_transcription(youtube_url: str) -> List[langchain_core.documents.Document]:
    """
    Create a youtube transcription from a youtube url
    More Info: https://python.langchain.com/docs/integrations/document_loaders/youtube_transcript/
    Accepts:
        youtube_url: str - The url of the youtube video
    Returns:
        List[langchain_core.documents.Document]: A list of documents containing the youtube transcription
    """
    try:
        loader = YoutubeLoader.from_youtube_url(
            youtube_url, add_video_info=False
        ) # to load multiple videos at once, create one loader per url and combine the returned documents
        youtube_docs = loader.load() # this loads the transcript
        return youtube_docs
    except Exception as e:
        await cl.Message(content=f"Error: {e} Please refresh the page").send() # display the error if we failed to load the youtube video


######## RAG ########

async def create_text_splitter(docs: List[langchain_core.documents.Document]) -> List[langchain_core.documents.Document]:
    """
    Create a text splitter from a list of documents
    More Info: ument_transformers/recursive_text_splitter/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to split
    Returns:
        List[langchain_core.documents.Document]: A list of documents containing the text split
    """
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) # without an overlap, context might get cut off
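    # e.g. with chunk_size=1000 and chunk_overlap=200, consecutive chunks share roughly 200 characters,
    # so a sentence that straddles a chunk boundary still appears intact in at least one chunk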
    docs = text_splitter.split_documents(docs) # split the documents into chunks
    return docs

async def create_faiss_vector_store(docs: List[langchain_core.documents.Document]) -> FAISS:
    """
    Create a FAISS vector store or vector database from a list of documents
    More Info: https://python.langchain.com/docs/integrations/vectorstores/faiss/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to store
    Returns:
        FAISS: A vector store containing the documents
    """
    try:
        embedding = cl.user_session.get("embedding") # we can get the embedding model from the user session or pass it as a parameter too!
        vector_db = FAISS.from_documents(docs, embedding) # create the vector store
        return vector_db # k (how many documents to fetch) is set on the retriever in create_ensemble_retriever
    except Exception as e:
        await cl.Message(content=f"failed to create vector db: {e}").send() # display the error if we failed to create the vector db

async def create_bm25_retriever(docs: List[langchain_core.documents.Document]) -> BM25Retriever:
    """
    Create a BM25 retriever from a list of documents
    More Info: https://python.langchain.com/docs/integrations/retrievers/bm25/
    Accepts:
        docs: List[langchain_core.documents.Document] - A list of documents to store
    Returns:
        BM25Retriever: A BM25 retriever containing the documents
    """
    try:
        bm25 = BM25Retriever.from_documents(docs) # we don't need embeddings for BM25, as it uses keyword matching!
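        # BM25 ranks chunks by how often the query's terms appear in them, discounted by how common
        # those terms are across all chunks (classic tf-idf-style keyword scoring)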
        bm25.k = 5 # we set k to 5, so we get 5 documents back
        return bm25
    except Exception as e:
        await cl.Message(content=f"failed to create BM25 retreiver: {e}").send() # display the error if we failed to create the BM25 retreiver

async def create_ensemble_retriever(vector_db: FAISS, bm25: BM25Retriever) -> EnsembleRetriever:
    """
    Create an ensemble retriever from a vector db and a BM25 retriever
    More Info: https://python.langchain.com/docs/how_to/ensemble_retriever/
    Accepts:
        vector_db: FAISS - A vector db
        bm25: BM25Retriever - A BM25 retriever
    Returns:
        EnsembleRetriever: An ensemble retriever containing the vector db and the BM25 retriever
    """
    try:
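        # NOTE: EnsembleRetriever fuses the two ranked lists with weighted Reciprocal Rank Fusion,
        # so the weights skew the fused ranking rather than taking a literal 30/70 document split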
        ensemble_retriever = EnsembleRetriever(retrievers=[vector_db.as_retriever(search_kwargs={"k": 5}), bm25], weights=[0.3, 0.7]) # 30% semantic, 70% keyword retrieval; fetch 5 docs from the FAISS retriever
        return ensemble_retriever
    except Exception as e:
        await cl.Message(content=f"failed to create ensemble retriever: {e}").send() # display the error if we failed to create the ensemble retriever