import os
import json

from dotenv import load_dotenv, find_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema.messages import SystemMessage

from rag import fetch_top_k

# Load OPENAI_API_KEY and ORGANIZATION_KEY from the .env file.
load_dotenv(find_dotenv())


class InValidOutput(ValueError):
    """Raised when the model returns output that cannot be parsed."""

    def __init__(self, message='Invalid result from the model'):
        super().__init__(message)


class PineConeException(Exception):
    """Raised when fetching documents from Pinecone fails."""

    def __init__(self, message='Pinecone error'):
        super().__init__(message)


class ApiException(Exception):
    """Raised when the call to the LLM API fails."""

    def __init__(self, message='API error'):
        super().__init__(message)


def create_openai_prompt(input_data, top_k, prompt_file):
    # First run retrieval (RAG) to find relevant context for the question.
    try:
        top_k_documents = fetch_top_k(input_data=input_data, top_k=top_k)
        context = '\n\n'.join(top_k_documents)
    except PineConeException as e:
        # Without retrieved context the prompt cannot be built, so re-raise.
        print(f"Caught Pinecone Exception: {e}")
        raise

    with open(prompt_file) as f:
        prompt = f.read()

    # Build the chat template: the system prompt, then the user question
    # prefixed with the retrieved context.
    chat_template = ChatPromptTemplate.from_messages(
        [
            SystemMessage(content=prompt),
            HumanMessagePromptTemplate.from_template(
                "Context: {context}\nAnswer the following question.\n{input_data}"
            ),
        ]
    )

    complete_prompt = chat_template.format_messages(input_data=input_data, context=context)

    return complete_prompt, top_k_documents
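
# For reference, a minimal sketch of what rag.fetch_top_k might look like.
# This is an assumption, not the repo's actual implementation: it presumes a
# pinecone-client v2 index whose vectors carry the source text in a "text"
# metadata field; the PINECONE_API_KEY / PINECONE_ENV variables and the index
# name are illustrative.
#
#     import pinecone
#     from langchain.embeddings import OpenAIEmbeddings
#
#     def fetch_top_k(input_data, top_k):
#         pinecone.init(api_key=os.getenv("PINECONE_API_KEY"),
#                       environment=os.getenv("PINECONE_ENV"))
#         index = pinecone.Index("docs-index")  # hypothetical index name
#         embedding = OpenAIEmbeddings().embed_query(input_data)
#         result = index.query(vector=embedding, top_k=top_k, include_metadata=True)
#         return [match["metadata"]["text"] for match in result["matches"]]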

def api_call(input_data, history=None):
    # history is currently unused but kept for interface compatibility.
    top_k = 5
    service = "openai"
    prompt_file = "prompts/version_2.txt"

    if service == "openai":
        complete_prompt, top_k_documents = create_openai_prompt(
            input_data=input_data,
            top_k=top_k,
            prompt_file=prompt_file,
        )

        llm = ChatOpenAI(
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_organization=os.getenv("ORGANIZATION_KEY"),
            model="gpt-4",
            temperature=0.2,
        )

        try:
            output = llm(complete_prompt).content
        except ApiException as e:
            # Without a model response there is nothing to return, so re-raise.
            print(f"API Error: {e}")
            raise

        try:
            # The prompt asks the model for JSON; parse it and return it along
            # with the retrieved documents.
            output_dict = json.loads(output)
            return output_dict, top_k_documents
        except json.JSONDecodeError:
            # Fallback: the model did not return valid JSON, so return the raw
            # text together with the fetched documents.
            docs = "\n\n".join(top_k_documents)
            output_text = f"{output}\n\nDocument fetched: {docs}"
            return output_text

    else:
        print("Service not supported")
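
# Hypothetical usage sketch (not in the original module): run one question
# through the pipeline. Assumes a populated .env file, a reachable Pinecone
# index behind rag.fetch_top_k, and that prompts/version_2.txt exists; the
# sample question is made up.
if __name__ == "__main__":
    result = api_call("Which option best describes photosynthesis?")
    if isinstance(result, tuple):
        answer, sources = result
        print(json.dumps(answer, indent=2))
        print(f"\nRetrieved {len(sources)} supporting documents.")
    else:
        # Fallback path: the model did not return valid JSON.
        print(result)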