talk2docs / utils.py
olegperegudov's picture
wip
11f324c
raw
history blame contribute delete
No virus
3.4 kB
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
import constants
class Conversation:
    """Chat transcript that can be rendered into a text prompt.

    Messages are stored in ``self.messages`` as ``{"role": ..., "content": ...}``
    dicts; the first entry is always the system prompt.
    """

    def __init__(
        self,
        message_template=constants.DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=constants.DEFAULT_SYSTEM_PROMPT,
        response_template=constants.DEFAULT_RESPONSE_TEMPLATE,
    ):
        """Start a conversation containing only the system prompt.

        Args:
            message_template: ``str.format`` template applied to each message
                dict; it may reference the ``role`` and ``content`` fields.
            system_prompt: content of the initial ``"system"`` message.
            response_template: stored on the instance for callers; this class
                does not use it itself.
        """
        self.message_template = message_template
        self.response_template = response_template
        self.messages = [{"role": "system", "content": system_prompt}]

    def add_user_message(self, message):
        """Append *message* with role ``"user"``."""
        self.messages.append({"role": "user", "content": message})

    def add_bot_message(self, message):
        """Append *message* with role ``"bot"``."""
        self.messages.append({"role": "bot", "content": message})

    def get_conversation_history(self):
        """Render the system prompt plus the most recent messages as one string.

        At most ``constants.LAST_MESSAGES`` non-system messages are included.
        A value of 0 (or less) keeps only the system prompt.

        Returns:
            The concatenation of ``message_template`` applied to each selected
            message, with surrounding whitespace stripped.
        """
        system_message = self.messages[0]
        dialogue = self.messages[1:]
        last_n = constants.LAST_MESSAGES
        # Guard the slice: dialogue[-0:] would return the *whole* list rather
        # than an empty one.
        recent = dialogue[-last_n:] if last_n > 0 else []
        rendered = "".join(
            self.message_template.format(**message)
            for message in [system_message, *recent]
        )
        return rendered.strip()
def source_documents(relevant_docs):
    """Return the unique source-file base names referenced by *relevant_docs*.

    Each document's ``metadata["source"]`` path is reduced to its file name
    without directory or final extension (``"docs/a.txt"`` -> ``"a"``).
    Duplicates are removed. Names appear in the order they are first seen,
    which keeps the retriever's ranking instead of the arbitrary order a set
    would give.

    Args:
        relevant_docs: iterable of documents exposing a ``metadata`` mapping
            with a ``"source"`` key.

    Returns:
        List of base names in first-appearance order.

    Raises:
        KeyError: if a document's metadata has no ``"source"`` entry.
    """
    base_names = (
        os.path.splitext(os.path.basename(doc.metadata["source"]))[0]
        for doc in relevant_docs
    )
    # dict preserves insertion order, so this dedupes deterministically.
    return list(dict.fromkeys(base_names))
def load_raw_documents():
    """Load every ``*.txt`` file matched in ``constants.DOCS_PATH``.

    Returns:
        Whatever ``DirectoryLoader.load()`` returns for that directory and glob.
    """
    loader = DirectoryLoader(constants.DOCS_PATH, glob="*.txt")
    return loader.load()
def build_nodes(raw_documents):
    """Split *raw_documents* into overlapping chunks for indexing.

    Chunk size and overlap come from ``constants.CHUNK_SIZE`` and
    ``constants.CHUNK_OVERLAP``. Length is measured with ``len``, and
    separators are treated as literal strings rather than regexes.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=constants.CHUNK_SIZE,
        chunk_overlap=constants.CHUNK_OVERLAP,
        is_separator_regex=False,
        length_function=len,
    )
    chunks = splitter.split_documents(raw_documents)
    return chunks
def build_embeddings():
    """Create the HuggingFace embedding model.

    The model name comes from ``constants.EMBED_MODEL_NAME`` and the device
    from ``constants.DEVICE``.
    """
    model_name = constants.EMBED_MODEL_NAME
    model_kwargs = {"device": constants.DEVICE}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
def build_db(nodes, embeddings):
    """Build a Chroma vector store from *nodes*, embedded with *embeddings*."""
    vector_store = Chroma.from_documents(nodes, embeddings)
    return vector_store
def build_retriever():
    """Build a retriever over the documents in ``constants.DOCS_PATH``.

    The pipeline is: load files, split them into chunks, embed the chunks
    into a Chroma store, then wrap the store as a retriever configured by
    ``constants.SEARCH_TYPE`` and ``constants.SEARCH_KWARGS``.
    """
    chunks = build_nodes(load_raw_documents())
    store = build_db(chunks, build_embeddings())
    return store.as_retriever(
        search_type=constants.SEARCH_TYPE,
        search_kwargs=constants.SEARCH_KWARGS,
    )
def fetch_relevant_nodes(question, retriever):
    """Retrieve passages and source names relevant to *question*.

    Args:
        question: query text passed to the retriever.
        retriever: object exposing ``get_relevant_documents(question)``.

    Returns:
        A ``(context, source_docs)`` tuple:
        - ``context`` holds the unique ``page_content`` strings, in the order
          the retriever returned them.
        - ``source_docs`` is the result of ``source_documents`` on the same
          documents.
    """
    relevant_docs = retriever.get_relevant_documents(question)
    # Drop duplicate chunks but keep the retriever's ranking. list(set(...))
    # would shuffle the passages into an arbitrary, run-dependent order.
    context = list(dict.fromkeys(doc.page_content for doc in relevant_docs))
    source_docs = source_documents(relevant_docs)
    return context, source_docs
def ask_question(question, conversation, model, retriever):
    """Answer *question* using retrieved context and the conversation so far.

    The user question and the model's answer are both appended to
    *conversation*.

    Args:
        question: the user's question text.
        conversation: conversation object providing ``add_user_message``,
            ``get_conversation_history`` and ``add_bot_message``; its
            ``response_template`` is used when present.
        model: object exposing ``invoke(prompt)``; the return value is used as
            the answer as-is.
        retriever: passed through to ``fetch_relevant_nodes``.

    Returns:
        ``(answer, source_docs)``.
    """
    context, source_docs = fetch_relevant_nodes(question, retriever)
    # Record the user turn before rendering history so it appears in the prompt.
    conversation.add_user_message(question)
    conversation_history = conversation.get_conversation_history()
    # Honour the conversation's own response template. Fall back to the module
    # default for conversation-like objects that do not define one.
    response_template = getattr(
        conversation, "response_template", constants.DEFAULT_RESPONSE_TEMPLATE
    )
    # Join passages as plain text. Interpolating the list itself would put
    # Python list syntax (brackets, quotes, escaped newlines) into the prompt,
    # and backslash-newline continuations would leak source indentation.
    context_text = "\n\n".join(context)
    prompt = "\n".join([conversation_history, context_text, response_template])
    answer = model.invoke(prompt)
    # Record the bot turn so follow-up questions see it in the history.
    conversation.add_bot_message(answer)
    return answer, source_docs