File size: 3,653 Bytes
772e8d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import chainlit as cl 
from langchain_together import ChatTogether, TogetherEmbeddings
from langchain_core.runnables import RunnableSequence, RunnablePassthrough
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.document_loaders import YoutubeLoader
from typing import List
import langchain_core
from langchain_community.vectorstores import FAISS
from langchain.retrievers.ensemble import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_text_splitters import RecursiveCharacterTextSplitter

def create_youtube_transcription(youtube_url: str):
    """Load the transcript of a YouTube video as LangChain documents.

    Args:
        youtube_url: URL of the video whose transcript should be fetched.

    Returns:
        The list of documents produced by ``YoutubeLoader.load()``. Extra
        video metadata is not requested (``add_video_info=False``).
    """
    return YoutubeLoader.from_youtube_url(youtube_url, add_video_info=False).load()

def create_text_splitter(
    docs: List[langchain_core.documents.Document],
    *,
    chunk_size: int = 1000,
    chunk_overlap: int = 200,
) -> List[langchain_core.documents.Document]:
    """Split documents into overlapping chunks suitable for retrieval.

    Args:
        docs: Documents to split.
        chunk_size: Maximum chunk size passed to
            ``RecursiveCharacterTextSplitter`` (default 1000).
        chunk_overlap: Overlap between consecutive chunks passed to
            ``RecursiveCharacterTextSplitter`` (default 200).

    Returns:
        The chunked documents returned by ``split_documents``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return splitter.split_documents(docs)

def create_vector_store(docs: List[langchain_core.documents.Document], embedding=None):
    """Build a FAISS index over ``docs``.

    Args:
        docs: Chunked documents to embed and index.
        embedding: Embedding model to use. When omitted, the model stored
            under the "embedding" key of ``cl.user_session`` is used, which
            preserves the original calling convention.

    Returns:
        The ``FAISS`` vector store built by ``FAISS.from_documents``.

    Raises:
        RuntimeError: if no embedding model was passed and none is stored in
            the user session (otherwise FAISS would fail later with an
            opaque error on ``None``).
    """
    if embedding is None:
        embedding = cl.user_session.get("embedding")
    if embedding is None:
        raise RuntimeError(
            "No embedding model available: pass one explicitly or store it "
            "in the user session under 'embedding'."
        )
    return FAISS.from_documents(docs, embedding)

def create_bm25_vector_store(docs: List[langchain_core.documents.Document]):
    """Build a ``BM25Retriever`` over ``docs``.

    Despite the function name, the return value is a retriever, not a
    vector store; no embedding model is involved.
    """
    return BM25Retriever.from_documents(docs)

def create_ensemble_retriever(vector_db: FAISS, bm25: BM25Retriever, *, weights=(0.3, 0.7)):
    """Combine FAISS (embedding) and BM25 retrieval into one retriever.

    Args:
        vector_db: FAISS index; wrapped with ``as_retriever()`` using its
            default settings.
        bm25: BM25 retriever, presumably built over the same chunks as
            ``vector_db``.
        weights: Relative weights for (FAISS, BM25), in that order.
            Defaults to (0.3, 0.7), i.e. BM25 results are weighted higher.

    Returns:
        An ``EnsembleRetriever`` over both retrievers with the given weights.

    Raises:
        ValueError: if ``weights`` does not contain exactly two values.
    """
    # Copy into a list: the default is an immutable tuple (no shared mutable
    # default) and EnsembleRetriever declares ``weights`` as a list.
    weights = list(weights)
    if len(weights) != 2:
        raise ValueError(
            f"weights must contain exactly 2 values (FAISS, BM25), got {len(weights)}"
        )
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_db.as_retriever(), bm25], weights=weights
    )
    return ensemble_retriever

@cl.on_chat_start
async def start():
    """Initialise a chat session: load the models, then index a YouTube transcript.

    On success the following keys are set in ``cl.user_session``:
      - "llm": the ``ChatTogether`` chat model
      - "embedding": the ``TogetherEmbeddings`` model
      - "vector_db": the FAISS index over the transcript chunks
      - "ensemble_retriever": FAISS + BM25 ensemble retriever
      - "transcription": page content of the first loaded transcript document

    Failures are reported to the user as chat messages instead of being
    raised, and each phase reports its own failure message.
    """
    await cl.Message(content="my name is josh!").send()

    # Phase 1: models. Only errors from this phase are "failed to load model".
    try:
        llm = ChatTogether(model="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")
        await cl.Message(content="model is successfully loaded").send()
        cl.user_session.set("llm", llm)
        embedding = TogetherEmbeddings(model="togethercomputer/m2-bert-80M-8k-retrieval")
        cl.user_session.set("embedding", embedding)
        await cl.Message(content="embedding model loaded").send()
    except Exception as e:
        await cl.Message(content=f"failed to load model: {e}").send()
        return

    # Phase 2: fetch and index the video transcript.
    try:
        youtube_link = await cl.AskUserMessage("Please provide the YouTube video link").send()
        # AskUserMessage resolves to None when the user does not answer in time.
        if not youtube_link:
            await cl.Message(content="No YouTube link was provided; please restart the chat and try again.").send()
            return

        youtube_docs = create_youtube_transcription(youtube_link["output"])
        if not youtube_docs:
            # Guard against indexing nothing and against youtube_docs[0] below.
            await cl.Message(content="No transcript could be loaded for that video.").send()
            return

        split_docs = create_text_splitter(youtube_docs)
        vector_db = create_vector_store(split_docs)
        cl.user_session.set("vector_db", vector_db)
        bm25 = create_bm25_vector_store(split_docs)
        ensemble_retriever = create_ensemble_retriever(vector_db, bm25)
        cl.user_session.set("ensemble_retriever", ensemble_retriever)

        transcription = youtube_docs[0].page_content
        await cl.Message(content=f"youtube docs: {transcription}").send()
        cl.user_session.set("transcription", transcription)
    except Exception as e:
        await cl.Message(content=f"failed to process the YouTube video: {e}").send()


@cl.on_message
async def message(message: cl.Message):
    """Answer a user question about the loaded video with retrieval-augmented generation.

    Uses the "ensemble_retriever" and "llm" stored in ``cl.user_session`` by
    the chat-start handler. Retrieved chunks are rendered as plain text into
    the prompt, and the model's reply is sent back to the chat.
    """
    llm = cl.user_session.get("llm")
    ensemble_retriever = cl.user_session.get("ensemble_retriever")
    if llm is None or ensemble_retriever is None:
        # Session setup failed or never finished; there is nothing to build a chain from.
        await cl.Message(content="The video has not been loaded, so I cannot answer questions yet.").send()
        return

    prompt_template = ChatPromptTemplate.from_template(template="""
    You are a helpful assistant that can answer questions about the following video. Here are the relevant chunks of context: {context}.
    Answer the question: {question} but do not use any information outside of the video. Cite the source or information you used to answer the question.
    """)

    def format_docs(docs):
        # Give the LLM the chunk text rather than the Document objects' repr.
        return "\n\n".join(doc.page_content for doc in docs)

    rag_chain = (
        {"context": ensemble_retriever | format_docs, "question": RunnablePassthrough()}
        | prompt_template
        | llm
    )
    # ainvoke keeps the event loop responsive while retrieval and the LLM call run.
    response = await rag_chain.ainvoke(message.content)
    await cl.Message(content=response.content).send()