from llmops.openai_utils.chatmodel import ChatOpenAI
from llmops.vectordatabase import VectorDatabase
from llmops.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
import datetime
from wandb.sdk.data_types.trace_tree import Trace


class RetrievalAugmentedQAPipeline:
    """Retrieval-augmented QA pipeline.

    Retrieves the top-k passages most similar to the user's query from a
    vector database, injects them into the system prompt, and asks the LLM
    to answer the query against that context.
    """

    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        # BUG FIX: annotation was `ChatOpenAI()` (a constructor *call*, evaluated
        # at definition time); the intended annotation is the class itself.
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    def run_pipeline(
        self,
        user_query: str,
        raqa_prompt: SystemRolePrompt,
        user_prompt: UserRolePrompt,
    ) -> str:
        """Answer `user_query` using the top-4 retrieved contexts.

        Args:
            user_query: The question to answer.
            raqa_prompt: System prompt template with a `{context}` slot.
            user_prompt: User prompt template with a `{user_query}` slot.

        Returns:
            The LLM's text response.
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)

        # search_by_text appears to return (text, score) pairs; only the text
        # (element 0) is used here — TODO confirm against VectorDatabase.
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
        formatted_user_prompt = user_prompt.create_message(user_query=user_query)

        return self.llm.run([formatted_system_prompt, formatted_user_prompt])


class WandB_RetrievalAugmentedQAPipeline:
    """RAQA pipeline with optional Weights & Biases trace logging.

    Behaves like `RetrievalAugmentedQAPipeline`, but wraps the LLM call in
    timing/error capture and, when `wandb_project` is truthy, logs a W&B
    `Trace` span per call.
    """

    def __init__(
        self,
        llm: ChatOpenAI,
        vector_db_retriever: VectorDatabase,
        wandb_project=None,
    ) -> None:
        # BUG FIX: annotation was `ChatOpenAI()` — see sibling class.
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever
        # Truthy -> emit a Trace span on every run_pipeline call.
        self.wandb_project = wandb_project

    def run_pipeline(
        self,
        user_query: str,
        raqa_prompt: SystemRolePrompt,
        user_prompt: UserRolePrompt,
    ) -> str:
        """Answer `user_query` with retrieved context; optionally log to W&B.

        Returns the LLM response text on success, or a canned error message
        (including the exception text) if the LLM call failed.
        """
        context_list = self.vector_db_retriever.search_by_text(user_query, k=4)
        context_prompt = "".join(context[0] + "\n" for context in context_list)

        formatted_system_prompt = raqa_prompt.create_message(context=context_prompt)
        formatted_user_prompt = user_prompt.create_message(user_query=user_query)

        # Millisecond epoch timestamps, as expected by Trace's *_time_ms args.
        start_time = datetime.datetime.now().timestamp() * 1000
        try:
            openai_response = self.llm.run(
                [formatted_system_prompt, formatted_user_prompt], text_only=False
            )
            end_time = datetime.datetime.now().timestamp() * 1000
            status = "success"
            # BUG FIX: was `(None, )` — a tuple. Trace's status_message expects
            # str/None, and the fallback return below concatenates it to a str;
            # a tuple would raise TypeError there.
            status_message = None
            response_text = openai_response.choices[0].message.content
            token_usage = openai_response["usage"].to_dict()
            model = openai_response["model"]
        except Exception as e:
            end_time = datetime.datetime.now().timestamp() * 1000
            status = "error"
            status_message = str(e)
            response_text = ""
            token_usage = {}
            model = ""

        if self.wandb_project:
            root_span = Trace(
                name="root_span",
                kind="llm",
                status_code=status,
                status_message=status_message,
                start_time_ms=start_time,
                end_time_ms=end_time,
                metadata={
                    "token_usage": token_usage,
                    "model_name": model,
                },
                inputs={
                    "system_prompt": formatted_system_prompt,
                    "user_prompt": formatted_user_prompt,
                },
                outputs={"response": response_text},
            )
            root_span.log(name="openai_trace")

        if response_text:
            return response_text
        # Guard against a non-str status_message (success path with an empty
        # model response) so string concatenation cannot raise.
        return (
            "We ran into an error. Please try again later. Full Error Message: "
            + (status_message or "")
        )