Spaces:
Build error
Build error
import os | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import DirectoryLoader | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import Chroma | |
import constants | |
class Conversation:
    """Chat history that renders the system prompt plus recent turns as prompt text.

    Each message is stored as a ``{"role": ..., "content": ...}`` dict. The first
    entry is always the ``"system"`` message; ``"user"`` and ``"bot"`` turns follow
    in the order they were added.
    """

    def __init__(
        self,
        message_template=constants.DEFAULT_MESSAGE_TEMPLATE,
        system_prompt=constants.DEFAULT_SYSTEM_PROMPT,
        response_template=constants.DEFAULT_RESPONSE_TEMPLATE,
    ):
        """Create a conversation seeded with ``system_prompt``.

        Args:
            message_template: ``str.format`` template applied to every message
                dict; it is called with ``role`` and ``content`` keyword arguments.
            system_prompt: Content of the leading ``"system"`` message.
            response_template: Stored on the instance as ``self.response_template``.
        """
        self.message_template = message_template
        self.response_template = response_template
        self.messages = [{"role": "system", "content": system_prompt}]

    def add_user_message(self, message):
        """Append a ``"user"`` turn whose content is ``message``."""
        self.messages.append({"role": "user", "content": message})

    def add_bot_message(self, message):
        """Append a ``"bot"`` turn whose content is ``message``."""
        self.messages.append({"role": "bot", "content": message})

    def get_conversation_history(self, last_messages=None):
        """Render the system message plus the newest turns as a single string.

        Args:
            last_messages: Number of most recent non-system messages to keep.
                Defaults to ``constants.LAST_MESSAGES``. ``0`` or a negative
                value keeps only the system message.

        Returns:
            Every kept message formatted with ``self.message_template``,
            concatenated in order and stripped of surrounding whitespace.
        """
        if last_messages is None:
            last_messages = constants.LAST_MESSAGES
        turns = self.messages[1:]
        # ``turns[-0:]`` is the WHOLE list, not an empty one, so a limit of 0
        # would silently include the entire (unbounded) history.
        recent = turns[-last_messages:] if last_messages > 0 else []
        window = [self.messages[0], *recent]
        return "".join(self.message_template.format(**message) for message in window).strip()
def source_documents(relevant_docs):
    """Return the unique source file names of ``relevant_docs``.

    Each document's ``metadata["source"]`` path is reduced to its base name
    without directory or extension (``"docs/guide.txt"`` -> ``"guide"``).

    Args:
        relevant_docs: Iterable of documents exposing a ``metadata`` mapping
            with a ``"source"`` key.

    Returns:
        A sorted list of unique base names. The list is sorted so the output is
        stable across runs; set iteration order for strings is not.

    Raises:
        KeyError: If a document's metadata has no ``"source"`` entry.
    """
    names = {
        os.path.splitext(os.path.basename(doc.metadata["source"]))[0]
        for doc in relevant_docs
    }
    return sorted(names)
def load_raw_documents():
    """Load the documents in ``constants.DOCS_PATH`` that match the ``*.txt`` glob."""
    loader = DirectoryLoader(constants.DOCS_PATH, glob="*.txt")
    return loader.load()
def build_nodes(raw_documents):
    """Split ``raw_documents`` into chunks for indexing.

    Chunk size and overlap come from ``constants.CHUNK_SIZE`` and
    ``constants.CHUNK_OVERLAP``. Lengths are measured with ``len`` (characters),
    and separators are treated as literal strings rather than regexes.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=constants.CHUNK_SIZE,
        chunk_overlap=constants.CHUNK_OVERLAP,
        length_function=len,
        is_separator_regex=False,
    )
    chunks = splitter.split_documents(raw_documents)
    return chunks
def build_embeddings():
    """Create the HuggingFace embedding model configured in ``constants``.

    The model name is ``constants.EMBED_MODEL_NAME`` and it is placed on
    ``constants.DEVICE``.
    """
    model_name = constants.EMBED_MODEL_NAME
    device_kwargs = {"device": constants.DEVICE}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=device_kwargs)
def build_db(nodes, embeddings):
    """Index ``nodes`` into a new Chroma vector store using ``embeddings``."""
    store = Chroma.from_documents(documents=nodes, embedding=embeddings)
    return store
def build_retriever():
    """Build the retrieval pipeline end to end and return a retriever.

    The pipeline loads the raw documents, chunks them, embeds and indexes the
    chunks, and exposes the store as a retriever. The search behaviour comes
    from ``constants.SEARCH_TYPE`` and ``constants.SEARCH_KWARGS``.
    """
    chunks = build_nodes(load_raw_documents())
    embedder = build_embeddings()
    vector_store = build_db(chunks, embedder)
    return vector_store.as_retriever(
        search_kwargs=constants.SEARCH_KWARGS,
        search_type=constants.SEARCH_TYPE,
    )
def fetch_relevant_nodes(question, retriever):
    """Retrieve context chunks and their source names for ``question``.

    Args:
        question: Query text passed to the retriever.
        retriever: Object exposing ``get_relevant_documents(question)``.

    Returns:
        A ``(context, source_docs)`` tuple:

        - ``context`` holds the distinct ``page_content`` strings, kept in the
          order the retriever returned them.
        - ``source_docs`` is the result of ``source_documents``.
    """
    relevant_docs = retriever.get_relevant_documents(question)
    # dict.fromkeys drops duplicates while keeping first-seen (ranking) order;
    # the previous list(set(...)) discarded the retriever's relevance order.
    context = list(dict.fromkeys(doc.page_content for doc in relevant_docs))
    source_docs = source_documents(relevant_docs)
    return context, source_docs
def ask_question(question, conversation, model, retriever):
    """Answer ``question`` using retrieved context plus the conversation history.

    Side effects: appends ``question`` as a user turn and the model's answer as
    a bot turn to ``conversation``.

    Args:
        question: The user's question.
        conversation: A ``Conversation`` holding the chat history.
        model: Object exposing ``invoke(prompt)``.
        retriever: Retriever passed through to ``fetch_relevant_nodes``.

    Returns:
        ``(answer, source_docs)``, where ``answer`` is whatever ``model.invoke``
        returned.
    """
    context, source_docs = fetch_relevant_nodes(question, retriever)
    # Record the question first so it appears in the rendered history.
    conversation.add_user_message(question)
    conversation_history = conversation.get_conversation_history()
    # Join chunks as plain text. Interpolating the list itself put its Python
    # repr (brackets, quotes, escaped newlines) into the prompt.
    context_text = "\n\n".join(context)
    # Honour the conversation's own template; fall back to the global default.
    response_template = getattr(
        conversation, "response_template", constants.DEFAULT_RESPONSE_TEMPLATE
    )
    prompt = f"{conversation_history}\n{context_text}\n{response_template}"
    answer = model.invoke(prompt)
    conversation.add_bot_message(answer)
    return answer, source_docs