import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

import constants


class Conversation:
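    """Chat history container: holds the system prompt plus user/bot turns and
    renders them into a prompt string via `message_template`."""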
    def __init__(
        self,
        message_template=constants.DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=constants.DEFAULT_SYSTEM_PROMPT,
        response_template=constants.DEFAULT_RESPONSE_TEMPLATE,
    ):
        self.message_template = message_template
        self.response_template = response_template
        self.messages = [{"role": "system", "content": system_prompt}]

    def add_user_message(self, message):
        self.messages.append({"role": "user", "content": message})

    def add_bot_message(self, message):
        self.messages.append({"role": "bot", "content": message})

    def get_conversation_history(self):
        final_text = ""
        # keep the system message plus the last `constants.LAST_MESSAGES` non-system messages
        context_and_last_few_messages = [self.messages[0]] + self.messages[1:][-constants.LAST_MESSAGES :]
        for message in context_and_last_few_messages:
            message_text = self.message_template.format(**message)
            final_text += message_text
        return final_text.strip()


def source_documents(relevant_docs):
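    """Return the unique source file names (basename without extension) of the retrieved documents."""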
    source_docs = set()
    for doc in relevant_docs:
        fname = doc.metadata["source"]
        fname_base = os.path.splitext(os.path.basename(fname))[0]
        source_docs.add(fname_base)
    return list(source_docs)


def load_raw_documents():
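    """Load all .txt files from the documents directory (`constants.DOCS_PATH`)."""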
    return DirectoryLoader(constants.DOCS_PATH, glob="*.txt").load()


def build_nodes(raw_documents):
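    """Split the raw documents into overlapping text chunks ready for embedding."""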
    return RecursiveCharacterTextSplitter(
        chunk_size=constants.CHUNK_SIZE,
        chunk_overlap=constants.CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    ).split_documents(raw_documents)


def build_embeddings():
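    """Instantiate the HuggingFace embedding model used to vectorize chunks and queries."""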
    return HuggingFaceEmbeddings(model_name=constants.EMBED_MODEL_NAME, model_kwargs={"device": constants.DEVICE})


def build_db(nodes, embeddings):
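    """Index the chunks in a Chroma vector store using the given embeddings."""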
    return Chroma.from_documents(nodes, embeddings)


def build_retriever():
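    """Load, chunk, embed and index the documents, then return a configured retriever."""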
    raw_documents = load_raw_documents()
    nodes = build_nodes(raw_documents)
    embeddings = build_embeddings()
    db = build_db(nodes, embeddings)
    return db.as_retriever(search_kwargs=constants.SEARCH_KWARGS, search_type=constants.SEARCH_TYPE)


def fetch_relevant_nodes(question, retriever):
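    """Retrieve chunks relevant to the question; return the deduplicated chunk texts and their source names."""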
    relevant_docs = retriever.get_relevant_documents(question)
    context = [doc.page_content for doc in relevant_docs]
    source_docs = source_documents(relevant_docs)
    context = list(set(context))  # remove duplicated strings from context
    return context, source_docs


def ask_question(question, conversation, model, retriever):
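    """Answer a question using retrieved context, appending both the question and the answer to the conversation."""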

    context, source_docs = fetch_relevant_nodes(question, retriever)

    # add user message to conversation's context
    conversation.add_user_message(question)
    conversation_history = conversation.get_conversation_history()
    # use the conversation's configured response template (defaults to constants.DEFAULT_RESPONSE_TEMPLATE)
    prompt = (
        f"{conversation_history}\n"
        f"{context}\n"
        f"{conversation.response_template}"
    )

    answer = model.invoke(prompt)
    # add bot message to conversation's context
    conversation.add_bot_message(answer)

    return answer, source_docs
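

# Example usage (sketch only): no LLM is constructed in this module, so the model
# below is an assumption. Any LangChain LLM whose `.invoke(prompt)` returns a plain
# string should work; `Ollama` and the "llama3" model name are purely illustrative.
#
#   from langchain_community.llms import Ollama
#
#   retriever = build_retriever()
#   conversation = Conversation()
#   model = Ollama(model="llama3")
#   answer, sources = ask_question("What do the docs say about deployment?", conversation, model, retriever)
#   print(answer)
#   print("Sources:", sources)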