# RAQA-from-Scratch / llmops/retrieval_pipeline.py
# Uploaded by Megatron17 (commit 5623f53).
# Retrieval-augmented QA pipelines, with optional Weights & Biases tracing.
from llmops.openai_utils.chatmodel import ChatOpenAI
from llmops.vectordatabase import VectorDatabase
from llmops.openai_utils.prompts import (
UserRolePrompt,
SystemRolePrompt,
AssistantRolePrompt,
)
import datetime
from wandb.sdk.data_types.trace_tree import Trace
class RetrievalAugmentedQAPipeline:
    """Answer a user query with an LLM, grounded in vector-database search hits.

    ``run_pipeline`` retrieves the top-``k`` matches for the query, joins their
    text into a context string, renders the system and user prompts, and
    returns whatever ``llm.run`` produces for those two messages.
    """

    # Annotations are quoted forward references: they document the expected
    # types without being evaluated (and never instantiated) at class creation.
    def __init__(self, llm: "ChatOpenAI", vector_db_retriever: "VectorDatabase") -> None:
        """Store the chat model and the retriever used by ``run_pipeline``.

        Args:
            llm: Chat model; ``run(messages)`` is called with the two prompts.
            vector_db_retriever: Store exposing ``search_by_text(query, k=...)``.
        """
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    def run_pipeline(
        self,
        user_query: str,
        raqa_prompt: "SystemRolePrompt",
        user_prompt: "UserRolePrompt",
        k: int = 4,
    ) -> str:
        """Retrieve context for ``user_query`` and ask the LLM to answer it.

        Args:
            user_query: The question; used for retrieval and passed to
                ``user_prompt.create_message`` as ``user_query``.
            raqa_prompt: System prompt, rendered with ``context=...``.
            user_prompt: User prompt, rendered with ``user_query=...``.
            k: Number of search hits to retrieve (defaults to 4).

        Returns:
            The value returned by ``self.llm.run`` for the two messages.
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)
        # Item [0] of each hit is taken as the chunk text (presumably a
        # (text, score) pair -- confirm against VectorDatabase). Every chunk
        # is newline-terminated, including the last one.
        context_prompt = "".join(context[0] + "\n" for context in context_list)
        formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
        formatted_user_prompt = user_prompt.create_message(user_query=user_query)
        return self.llm.run([formatted_system_prompt, formatted_user_prompt])
class WandB_RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA pipeline that can log each LLM call to W&B.

    The LLM is called with ``text_only=False`` so the raw response is
    available. Its text, token usage and model name are extracted, and when
    ``wandb_project`` is truthy a ``Trace`` span describing the call is logged.
    Failures of the LLM call are not raised; they are reported in the
    returned string instead.
    """

    # Annotations are quoted forward references: they document the expected
    # types without being evaluated (and never instantiated) at class creation.
    def __init__(
        self,
        llm: "ChatOpenAI",
        vector_db_retriever: "VectorDatabase",
        wandb_project: "str | None" = None,
    ) -> None:
        """Store the chat model, retriever and tracing switch.

        Args:
            llm: Chat model; ``run(messages, text_only=False)`` must return a
                response with ``choices``, ``usage`` and ``model``.
            vector_db_retriever: Store exposing ``search_by_text(query, k=...)``.
            wandb_project: When truthy, each call is logged as a W&B trace.
                This class only uses it as an on/off switch.
        """
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever
        self.wandb_project = wandb_project

    @staticmethod
    def _now_ms() -> int:
        """Return the current wall-clock time in integer milliseconds since the epoch."""
        return int(datetime.datetime.now().timestamp() * 1000)

    def run_pipeline(
        self,
        user_query: str,
        raqa_prompt: "SystemRolePrompt",
        user_prompt: "UserRolePrompt",
        k: int = 4,
    ) -> str:
        """Retrieve context, query the LLM, optionally trace, and return the answer.

        Args:
            user_query: The question; used for retrieval and passed to
                ``user_prompt.create_message`` as ``user_query``.
            raqa_prompt: System prompt, rendered with ``context=...``.
            user_prompt: User prompt, rendered with ``user_query=...``.
            k: Number of search hits to retrieve (defaults to 4).

        Returns:
            The model's reply text (``""`` if the model returned no content),
            or an error description if calling or parsing the LLM failed.
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)
        # Item [0] of each hit is taken as the chunk text (presumably a
        # (text, score) pair -- confirm against VectorDatabase).
        context_prompt = "".join(context[0] + "\n" for context in context_list)
        formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
        formatted_user_prompt = user_prompt.create_message(user_query=user_query)

        start_time = self._now_ms()
        try:
            openai_response = self.llm.run(
                [formatted_system_prompt, formatted_user_prompt], text_only=False
            )
            end_time = self._now_ms()
            # Attribute access throughout; message content may be None
            # (e.g. no text returned), which is normalised to "".
            response_text = openai_response.choices[0].message.content or ""
            token_usage = openai_response.usage.to_dict()
            model = openai_response.model
        except Exception as e:  # deliberate best-effort: surface the error in the reply
            end_time = self._now_ms()
            status = "error"
            status_message = str(e)
            response_text = ""
            token_usage = {}
            model = ""
        else:
            status = "success"
            status_message = None

        if self.wandb_project:
            root_span = Trace(
                name="root_span",
                kind="llm",
                status_code=status,
                status_message=status_message,
                start_time_ms=start_time,
                end_time_ms=end_time,
                metadata={
                    "token_usage": token_usage,
                    "model_name": model,
                },
                inputs={"system_prompt": formatted_system_prompt, "user_prompt": formatted_user_prompt},
                outputs={"response": response_text},
            )
            root_span.log(name="openai_trace")

        # Branch on the recorded status, not on the reply text: an empty but
        # successful reply must not be turned into an error message.
        if status == "error":
            return "We ran into an error. Please try again later. Full Error Message: " + status_message
        return response_text