import os
from io import BytesIO

import openai
import pandas as pd
import streamlit as st
from dotenv import load_dotenv

load_dotenv()

# set openai.api_key from .streamlit/secrets.toml (used on Streamlit / HF Spaces)
openai.api_key = st.secrets["OPENAI_API_KEY"]

# # alternatively, set openai.api_key from the OPENAI_API_KEY variable in a local .env file
# openai.api_key = os.getenv("OPENAI_API_KEY")

# # read in llm-data-cleaner/prompts/gpt4-system-message.txt file into variable system_message
# system_message = open('../prompts/gpt4-system-message.txt', 'r').read()

class OpenAIChatCompletions:
    def __init__(self, model="gpt-4", system_message=None):
        self.model = model
        self.system_message = system_message


    # build the messages list (system message, optional few-shot samples, user prompt)
    # and return the chat completion response
    def openai_chat_completion(self, prompt, n_shot=None):
        messages = [{"role": "system", "content": self.system_message}] if self.system_message else []
        
        # add n_shot number of samples to messages list ... if n_shot is None, then only system_message and prompt will be added to messages list
        if n_shot is not None:
            messages = self._add_samples(messages, n_samples=n_shot)

        messages.append({"role": "user", "content": prompt})

        # set up the API request parameters for OpenAI
        chat_request_kwargs = dict(
            model=self.model,
            messages=messages,
        )

        # make the API request to OpenAI
        response = openai.ChatCompletion.create(**chat_request_kwargs)

        # return the full API response object; the completion text itself is at
        # response['choices'][0]['message']['content']
        return response


    # function to use test data to predict completions
    def predict_jsonl(
        self,
        path_or_buf='../data/cookies_train.jsonl',
        # path_or_buf='~/data/cookies_train.jsonl',
        n_samples=None,
        n_shot=None
        ):
        
        jsonObj = pd.read_json(path_or_buf=path_or_buf, lines=True)
        if n_samples is not None:
            jsonObj = jsonObj.sample(n_samples, random_state=42)

        iter_range = range(len(jsonObj))
        prompts = [jsonObj.iloc[i]['prompt'] for i in iter_range]
        completions = [jsonObj.iloc[i]['completion'] for i in iter_range]
        predictions = [self.openai_chat_completion(prompt, n_shot=n_shot) for prompt in prompts]

        return prompts, completions, predictions


    # a method that adds prompt and completion samples to messages
    @staticmethod
    def _add_samples(messages, n_samples=None):
        if n_samples is None:
            return messages

        samples = OpenAIChatCompletions._sample_jsonl(n_samples=n_samples)
        for i in range(n_samples):
            messages.append({"role": "user", "content": samples.iloc[i]['prompt']})
            messages.append({"role": "assistant", "content": samples.iloc[i]['completion']})

        return messages


    # a method that samples n rows from a jsonl file, returning a pandas dataframe
    @staticmethod
    def _sample_jsonl(
        path_or_buf='data/cookies_train.jsonl',
        # path_or_buf='~/data/cookies_train.jsonl',
        n_samples=5
        ):
        
        # jsonObj = pd.read_json(path_or_buf=path_or_buf, lines=True)
        
        # resolve the data path: running locally the repo sits inside a
        # "Kaleidoscope Data" directory, so the file lives one level above the cwd;
        # on HF Spaces it lives directly under the cwd
        if "Kaleidoscope Data" in os.getcwd():
            # file_path = os.path.join(os.getcwd(), "..", path_or_buf)
            file_path = os.path.join(os.path.dirname(os.getcwd()), path_or_buf)
        else:
            file_path = os.path.join(os.getcwd(), path_or_buf)

        try:
            with open(file_path, "r") as file:
                jsonl_str = file.read()

            jsonObj = pd.read_json(BytesIO(jsonl_str.encode()), lines=True, engine="pyarrow")
        except FileNotFoundError:
            # surface the missing file in the Streamlit UI, then re-raise so the
            # caller fails loudly instead of hitting a NameError on jsonObj below
            st.write(f"File not found: {file_path}")
            raise

        return jsonObj.sample(n_samples, random_state=42)
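

# Usage sketch (not part of the original module). Assumes a valid OPENAI_API_KEY in
# .streamlit/secrets.toml and a cookies_train.jsonl file at the default paths used
# above; the system message below is a hypothetical example.
if __name__ == "__main__":
    chat = OpenAIChatCompletions(
        model="gpt-4",
        system_message="You are a data-cleaning assistant for product listings.",
    )

    # run a small evaluation: 2 test rows, with 3 few-shot samples prepended to each prompt
    prompts, completions, predictions = chat.predict_jsonl(n_samples=2, n_shot=3)

    for prompt, completion, prediction in zip(prompts, completions, predictions):
        predicted_text = prediction['choices'][0]['message']['content']
        print(f"PROMPT:    {prompt}")
        print(f"EXPECTED:  {completion}")
        print(f"PREDICTED: {predicted_text}")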