File size: 1,779 Bytes
dd87c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393ef8e
 
dd87c4b
242ea7d
 
dd87c4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# chain.py

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate

def init_conversational_chain(llm, retriever):
    """
    Build a ConversationalRetrievalChain wired with chat-history memory
    and a custom question-answering prompt.

    Args:
        llm: Language model used by the chain (question condensing and answering).
        retriever: Retriever that supplies relevant documentation passages.

    Returns:
        A configured ConversationalRetrievalChain instance.
    """
    # Prompt filled in by the combine-docs step with the retrieved {context}
    # and the (possibly condensed) user {question}.
    prompt_text = (
        "You are LangChat, a knowledgeable assistant for the LangChain Python Library. "
        "Given the following context from the documentation, provide a helpful answer to the user's question. \n\n"
        "Context:\n{context}\n\n"
        "You can ignore the context if the question is a simple chat like Hi, hello, and just respond in a normal manner as LangChat, otherwise use the context to answer the query."
        "If you can't find the answer from the sources, mention that clearly instead of making up an answer.\n\n"
        "Question: {question}\n\n"
        "Answer:"
    )
    qa_prompt = PromptTemplate(
        template=prompt_text,
        input_variables=["context", "question"],
    )

    # Buffer that accumulates the conversation under "chat_history".
    # output_key="answer" tells the memory which chain output to store,
    # since the chain also returns source documents.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    # Assemble the retrieval chain; sources are returned alongside the answer.
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        memory=chat_memory,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": qa_prompt},
        verbose=False,
    )