import os
import pickle
from operator import itemgetter

import nltk
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize

from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from langchain.schema import format_document
from langchain_core.messages import get_buffer_string
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

from generator import load_llm
from retrieverV2 import process_pdf_document, create_vectorstore, rag_retriever

# Both corpora are needed at runtime: 'punkt' for word_tokenize and
# 'stopwords' for stopwords.words in process_message below.
nltk.download('punkt')
nltk.download('stopwords')
|
|
class VectorStoreSingleton:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = create_vectorstore()
        return cls._instance


class LanguageModelSingleton:
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = load_llm()
        return cls._instance
|
|
class ModelPipeLine:

    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self):
        self.curr_dir = os.path.dirname(__file__)
        # Resolve data directories relative to this file so the pipeline
        # works regardless of the current working directory.
        self.knowledge_dir = os.path.join(self.curr_dir, 'knowledge')
        self.prompt_dir = os.path.join(self.curr_dir, 'prompts')
        self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)
        self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)
        self._documents = None
        self.vectorstore, self.store = VectorStoreSingleton.get_instance()
        self._retriever = None
        self.llm = LanguageModelSingleton.get_instance()
        self.memory = ConversationBufferMemory(return_messages=True, output_key="answer", input_key="question")
|
    @property
    def documents(self):
        # Lazily parse the source PDFs on first access only.
        if self._documents is None:
            self._documents = process_pdf_document([
                os.path.join(self.knowledge_dir, 'depression_1.pdf'),
                os.path.join(self.knowledge_dir, 'depression_2.pdf')
            ])
        return self._documents

    @property
    def retriever(self):
        # Build the parent/child document retriever on first access only.
        if self._retriever is None:
            self._retriever = rag_retriever(self.vectorstore, self.store, self.documents,
                                            self.parent_splitter, self.child_splitter)
        return self._retriever
|
    def get_prompts(self, system_file_path='system_prompt_template.txt',
                    condense_file_path='condense_question_prompt_template.txt'):
        with open(os.path.join(self.prompt_dir, system_file_path), 'r') as f:
            system_prompt_template = f.read()

        with open(os.path.join(self.prompt_dir, condense_file_path), 'r') as f:
            condense_question_prompt_template = f.read()

        answer_prompt = ChatPromptTemplate.from_template(system_prompt_template)
        condense_question_prompt = PromptTemplate.from_template(condense_question_prompt_template)

        return answer_prompt, condense_question_prompt
|
    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"):
        doc_strings = [format_document(doc, document_prompt) for doc in docs]
        return document_separator.join(doc_strings)
|
    def create_final_chain(self):
        answer_prompt, condense_question_prompt = self.get_prompts()

        # Load the conversation history from memory into the chain input.
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )

        # Rewrite the follow-up question as a standalone question, using the chat history.
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm,
        }

        # Retrieve documents relevant to the standalone question.
        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": lambda x: x["standalone_question"],
        }

        # Combine the retrieved documents into a single context string.
        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }

        # Generate the answer and pass the source documents through.
        answer = {
            "answer": final_inputs | answer_prompt | self.llm,
            "docs": itemgetter("docs"),
        }

        final_chain = loaded_memory | standalone_question | retrieved_documents | answer
        return final_chain
|
    def call_conversational_rag(self, question, chain):
        """
        Call the conversational RAG (Retrieval-Augmented Generation) chain to answer a question.

        The question is sent through the chain, and the resulting question-answer pair is
        saved to this pipeline's memory so it can provide context for future turns.

        Parameters:
            question (str): The question to be answered by the RAG chain.
            chain: The chain built by create_final_chain().

        Returns:
            dict: A dictionary with the generated answer and the retrieved source documents.
        """
        inputs = {"question": question}
        result = chain.invoke(inputs)

        # Persist this turn so subsequent questions can reference it.
        self.memory.save_context(inputs, {"answer": result["answer"]})
        return result
|
    def process_message(self, message, lower_case=True, stem=True, stop_words=True):
        """Normalise a message for the sentiment classifier: lowercase, drop stopwords, stem."""
        if lower_case:
            message = message.lower()

        words = word_tokenize(message)

        if stop_words:
            sw = set(stopwords.words('english'))
            words = [word for word in words if word not in sw]

        if stem:
            stemmer = PorterStemmer()
            words = [stemmer.stem(word) for word in words]

        return ' '.join(words)
|
    def load_model(self):
        """Load the pickled sentiment classifier, caching it after the first load."""
        if not hasattr(self, '_sentiment_model'):
            model_path = os.path.join(self.curr_dir, 'sentiment_classifier.pkl')
            with open(model_path, 'rb') as file:
                self._sentiment_model = pickle.load(file)
        return self._sentiment_model
|
    def predict_classification(self, message):
        s_model = self.load_model()
        processed_msg = self.process_message(message)
        pred_label = s_model.predict([processed_msg])
        return pred_label[0]
|
|
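# A minimal usage sketch, not part of the pipeline itself. It assumes the
# prompt templates, the 'knowledge' PDFs, and sentiment_classifier.pkl exist
# alongside this file, and that load_llm() returns a string-output LLM as the
# chain expects; the sample question is purely illustrative.
if __name__ == "__main__":
    pipeline = ModelPipeLine()
    chain = pipeline.create_final_chain()

    question = "What are common symptoms of depression?"
    result = pipeline.call_conversational_rag(question, chain)
    print("Answer:", result["answer"])

    # The sentiment classifier runs independently of the RAG chain.
    print("Sentiment:", pipeline.predict_classification(question))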