Spaces:
Runtime error
Runtime error
from dotenv import load_dotenv, find_dotenv | |
import os | |
import openai | |
from langchain.prompts import PromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate | |
from rag import fetch_top_k | |
import json | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema.messages import SystemMessage | |
import time | |
# langchain.chat_models.init(openai_api_key=os.getenv("OPENAI_API_KEY")) | |
class InValidOutput(ValueError):
    """Error signalling that the model returned an unusable result.

    Subclasses ``ValueError``, so callers catching ``ValueError`` also catch it.

    Args:
        message: Human-readable description of the problem.
    """

    def __init__(self, message='Invalid result from the model'):
        # Fixed typo in the default message ("In valid" -> "Invalid").
        super().__init__(message)
class PineConeException(Exception):
    """Error type for failures in the Pinecone-backed retrieval step.

    ``create_openai_prompt`` catches this around its ``fetch_top_k`` call.

    Args:
        message: Human-readable description; defaults to ``'Pinecone error'``.
    """

    def __init__(self, message: str = 'Pinecone error') -> None:
        super().__init__(message)
class ApiException(Exception):
    """Project-level error type for failures of the model API call.

    ``api_call`` catches this around its chat-model invocation.

    Args:
        message: Human-readable description; defaults to ``'API error'``.
    """

    def __init__(self, message: str = 'API error') -> None:
        super().__init__(message)
def create_openai_prompt(input_data, top_k, prompt_file):
    """Build the chat prompt for the OpenAI model, grounded with retrieved context.

    First retrieves documents relevant to ``input_data`` via ``fetch_top_k``
    (RAG) and joins them into one context string. It then fills a chat
    template with two messages:

    - a system message containing the contents of ``prompt_file``;
    - a human message carrying the context and the question.

    Args:
        input_data: The user's question, inserted into the human message.
        top_k: Number of documents requested from ``fetch_top_k``.
        prompt_file: Path to a UTF-8 text file used verbatim as the system message.

    Returns:
        A ``(messages, documents)`` tuple:

        - ``messages``: the formatted chat messages from
          ``ChatPromptTemplate.format_messages``.
        - ``documents``: whatever ``fetch_top_k`` returned. These are joined
          with blank lines, so presumably they are strings. If retrieval
          raised ``PineConeException``, this is an empty list.
    """
    try:
        top_k_documents = fetch_top_k(input_data=input_data, top_k=top_k)
    except PineConeException as err:
        print(f"Caught Pinecone Exception: {err}")
        # Answer without context rather than failing below with an
        # UnboundLocalError on the never-assigned documents/context.
        top_k_documents = []
    context = '\n\n'.join(top_k_documents)

    # Read the system prompt with an explicit encoding and close the file promptly.
    with open(prompt_file, encoding="utf-8") as prompt_fh:
        prompt = prompt_fh.read()

    # The system prompt is a plain SystemMessage, so braces in the file are not
    # treated as template variables; only the human message is templated.
    chat_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=prompt),
            HumanMessagePromptTemplate.from_template(
                "Context: {context}\nAnswer the following question.\n{input_data}"
            ),
        ]
    )
    complete_prompt = chat_template.format_messages(input_data=input_data, context=context)
    return complete_prompt, top_k_documents
def api_call(input_data, history=None, *, top_k=5, service="openai",
             prompt_file="prompts/version_2.txt"):
    """Answer ``input_data`` with GPT-4, grounded on retrieved documents.

    Builds the prompt with ``create_openai_prompt``, then queries
    ``ChatOpenAI`` (model ``gpt-4``, temperature 0.2). Credentials come from
    the ``OPENAI_API_KEY`` and ``ORGANIZATION_KEY`` environment variables.
    Finally, it tries to decode the reply as JSON.

    Args:
        input_data: The user's question.
        history: Accepted for interface compatibility; currently unused.
        top_k: Number of documents to retrieve (default 5).
        service: Backend name; only ``"openai"`` is supported.
        prompt_file: Path to the system-prompt text file.

    Returns:
        - ``(parsed_output, documents)`` when the reply is valid JSON.
        - A single string (raw reply followed by the fetched documents) when
          the reply is not valid JSON.
        - ``None`` when ``service`` is unsupported.

        NOTE(review): the return shape differs per path; it is kept as-is for
        existing callers.

    Raises:
        ApiException: re-raised after being printed, if the model call raises it.
    """
    if service != "openai":
        print("Service not Supported")
        return None

    complete_prompt, top_k_documents = create_openai_prompt(
        input_data=input_data,
        top_k=top_k,
        prompt_file=prompt_file,
    )
    llm = ChatOpenAI(
        openai_api_key=os.getenv("OPENAI_API_KEY"),
        openai_organization=os.getenv("ORGANIZATION_KEY"),
        model="gpt-4",
        temperature=0.2,
    )
    try:
        output = llm(complete_prompt).content
    except ApiException as err:
        print(f"API Error: {err}")
        # Previously execution fell through and crashed with a NameError on
        # `output`; propagate the real error instead.
        # NOTE(review): ApiException is defined in this module, so the chat
        # client presumably never raises it; its own errors propagate unchanged.
        raise

    try:
        output_dict = json.loads(output)
    except (TypeError, ValueError):
        # Not valid JSON: hand back the raw reply plus the retrieved context.
        docs = "\n\n".join(top_k_documents)
        return f"{output}\n\nDocument fetched: {docs}"
    return output_dict, top_k_documents