import os
import json

from dotenv import load_dotenv, find_dotenv
from langchain.prompts import HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema.messages import SystemMessage

from rag import fetch_top_k

# Load OPENAI_API_KEY and ORGANIZATION_KEY from a local .env file, if present.
# (The dotenv imports were previously unused, so os.getenv relied on the
# environment being set externally.)
load_dotenv(find_dotenv())


class InValidOutput(ValueError):
    """Raised when the model returns output that cannot be parsed."""

    def __init__(self, message='Invalid result from the model'):
        super().__init__(message)


class PineConeException(Exception):
    """Raised by the retrieval layer on Pinecone errors."""

    def __init__(self, message='Pinecone error'):
        super().__init__(message)


class ApiException(Exception):
    """Raised when the OpenAI API call fails."""

    def __init__(self, message='API error'):
        super().__init__(message)
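
# A minimal sketch (an assumption, not the real implementation, which lives in
# rag.py) of the contract the except clause in create_openai_prompt relies on:
# fetch_top_k should translate low-level Pinecone errors into
# PineConeException. `index` and `embed` are hypothetical placeholders; the
# sketch is kept commented out so it does not shadow the imported function.
#
# def fetch_top_k(input_data, top_k):
#     try:
#         matches = index.query(vector=embed(input_data), top_k=top_k)
#         return [m["metadata"]["text"] for m in matches["matches"]]
#     except Exception as exc:
#         raise PineConeException(str(exc)) from exc
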
def create_openai_prompt(input_data, top_k, prompt_file):
    """Run RAG retrieval, then build the chat prompt from the template file."""
    # First run retrieval to find relevant context.
    try:
        top_k_documents = fetch_top_k(input_data=input_data, top_k=top_k)
        context = '\n\n'.join(top_k_documents)
    except PineConeException as e:
        print(f"Caught Pinecone exception: {e}")
        raise  # Without retrieved context there is nothing to prompt with.

    with open(prompt_file) as f:
        prompt = f.read()

    # Create the chat template: the file's contents become the system prompt,
    # and the human turn carries the retrieved context plus the question.
    chat_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=prompt),
            HumanMessagePromptTemplate.from_template(
                "Context: {context}\nAnswer the following question.\n{input_data}"
            ),
        ]
    )
    complete_prompt = chat_template.format_messages(input_data=input_data, context=context)
    return complete_prompt, top_k_documents

def api_call(input_data, history=None):
    # `history` is accepted for interface compatibility but is currently unused.
    top_k = 5
    service = "openai"
    prompt_file = "prompts/version_2.txt"

    if service == "openai":
        complete_prompt, top_k_documents = create_openai_prompt(
            input_data=input_data,
            top_k=top_k,
            prompt_file=prompt_file)

        llm = ChatOpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_organization=os.getenv("ORGANIZATION_KEY"),
            model="gpt-4",
            temperature=0.2)

        try:
            output = llm(complete_prompt).content
        except ApiException as e:
            print(f"API Error: {e}")
            raise  # No model output to work with; let the caller handle it.

        try:
            # The system prompt asks the model for JSON; parse it if possible.
            output_dict = json.loads(output)
            # document_index = output_dict["doc_no"].replace(" ", "").split(",")
            # document_index = [int(i) - 1 for i in document_index]
            # documents = [f"{top_k_documents[i]}" for i in document_index]
            # docs = "\n\n".join(documents)
            # output_text = f"Answer for the question is {output_dict['option']}.\n\nExplanation: {output_dict['explanation']}\n\nDocuments used for reference are\n\n{docs}"
            return output_dict, top_k_documents
        except Exception:
            # The model did not return valid JSON: fall back to the raw output
            # plus the documents that were fetched.
            docs = "\n\n".join(top_k_documents)
            output_text = f"{output}\n\nDocument fetched: {docs}"
            return output_text
    else:
        raise ValueError(f"Service not supported: {service}")
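
# Hedged usage sketch (not in the original file). Assumes a .env file supplies
# OPENAI_API_KEY and ORGANIZATION_KEY, that prompts/version_2.txt exists, and
# that the Pinecone index behind rag.fetch_top_k is populated; the question is
# a placeholder.
if __name__ == "__main__":
    result = api_call("What does the retrieved context say about the topic?")
    if isinstance(result, tuple):
        answer, documents = result
        print("Parsed answer:", answer)
        print(f"Retrieved {len(documents)} documents")
    else:
        # JSON parsing failed upstream; `result` is raw text plus documents.
        print(result)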