import os

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

import constants


class Conversation:
    def __init__(
        self,
        message_template=constants.DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=constants.DEFAULT_SYSTEM_PROMPT,
        response_template=constants.DEFAULT_RESPONSE_TEMPLATE,
    ):
        self.message_template = message_template
        self.response_template = response_template
        self.messages = [{"role": "system", "content": system_prompt}]

    def add_user_message(self, message):
        self.messages.append({"role": "user", "content": message})

    def add_bot_message(self, message):
        self.messages.append({"role": "bot", "content": message})

    def get_conversation_history(self):
        final_text = ""
        # system message first, plus the last few user/bot messages (the system message is excluded from the tail)
        context_and_last_few_messages = [self.messages[0]] + self.messages[1:][-constants.LAST_MESSAGES:]
        for message in context_and_last_few_messages:
            message_text = self.message_template.format(**message)
            final_text += message_text
        return final_text.strip()


def source_documents(relevant_docs):
    # collect the unique base filenames (without extension) of the retrieved documents
    source_docs = set()
    for doc in relevant_docs:
        fname = doc.metadata["source"]
        fname_base = os.path.splitext(os.path.basename(fname))[0]
        source_docs.add(fname_base)
    return list(source_docs)


def load_raw_documents():
    return DirectoryLoader(constants.DOCS_PATH, glob="*.txt").load()


def build_nodes(raw_documents):
    return RecursiveCharacterTextSplitter(
        chunk_size=constants.CHUNK_SIZE,
        chunk_overlap=constants.CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    ).split_documents(raw_documents)


def build_embeddings():
    return HuggingFaceEmbeddings(
        model_name=constants.EMBED_MODEL_NAME,
        model_kwargs={"device": constants.DEVICE},
    )


def build_db(nodes, embeddings):
    return Chroma.from_documents(nodes, embeddings)


def build_retriever():
    raw_documents = load_raw_documents()
    nodes = build_nodes(raw_documents)
    embeddings = build_embeddings()
    db = build_db(nodes, embeddings)
    return db.as_retriever(search_kwargs=constants.SEARCH_KWARGS, search_type=constants.SEARCH_TYPE)


def fetch_relevant_nodes(question, retriever):
    relevant_docs = retriever.get_relevant_documents(question)
    context = [doc.page_content for doc in relevant_docs]
    source_docs = source_documents(relevant_docs)
    context = list(set(context))  # remove duplicated strings from context
    return context, source_docs


def ask_question(question, conversation, model, retriever):
    context, source_docs = fetch_relevant_nodes(question, retriever)

    # add user message to conversation's context
    conversation.add_user_message(question)
    conversation_history = conversation.get_conversation_history()

    # join the retrieved chunks so the prompt contains plain text rather than a Python list repr
    context_text = "\n".join(context)
    prompt = f"{conversation_history}\n{context_text}\n{constants.DEFAULT_RESPONSE_TEMPLATE}"
    answer = model.invoke(prompt)

    # add bot message to conversation's context
    conversation.add_bot_message(answer)
    return answer, source_docs
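
# ---------------------------------------------------------------------------
# For reference, a sketch of what the imported `constants` module is expected
# to provide, based on the names used above. Every value shown is an
# assumption chosen for illustration; the project's real constants.py may use
# different templates, paths, and model names.
#
#   DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant. Answer using the provided context."
#   DEFAULT_MESSAGE_TEMPLATE = "<{role}>\n{content}\n</{role}>\n"  # must accept role/content keys
#   DEFAULT_RESPONSE_TEMPLATE = "<bot>\n"
#   LAST_MESSAGES = 6                 # how many recent messages to keep in the prompt
#   DOCS_PATH = "docs"                # directory scanned for *.txt files
#   CHUNK_SIZE = 1024
#   CHUNK_OVERLAP = 128
#   EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
#   DEVICE = "cpu"
#   SEARCH_TYPE = "similarity"
#   SEARCH_KWARGS = {"k": 4}
# ---------------------------------------------------------------------------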
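
# ---------------------------------------------------------------------------
# Minimal usage sketch tying the pieces together. The model backend is an
# assumption: any LangChain LLM whose .invoke() accepts a prompt string will
# work; Ollama is used here purely for illustration and the model name
# "llama3" is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain_community.llms import Ollama  # assumed backend, swap in the project's actual model

    retriever = build_retriever()
    conversation = Conversation()
    model = Ollama(model="llama3")  # placeholder model name

    answer, sources = ask_question("What is this document collection about?", conversation, model, retriever)
    print(answer)
    print("Sources:", sources)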