Amirkakeh committed (verified)
Commit 247e7e4 · Parent(s): 1156c85

Upload octotools_engine_openai.py

octotools/engine/octotools_engine_openai.py ADDED
@@ -0,0 +1,269 @@
+ try:
+     from openai import OpenAI
+ except ImportError:
+     raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")
+
+ import os
+ import json
+ import base64
+ import platformdirs
+ from tenacity import (
+     retry,
+     stop_after_attempt,
+     wait_random_exponential,
+ )
+ from typing import List, Union
+
+ from .base import EngineLM, CachedEngine
+
+ import openai
+
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ # Define global constant for structured models
+ # https://platform.openai.com/docs/guides/structured-outputs
+ # https://cookbook.openai.com/examples/structured_outputs_intro
+ from pydantic import BaseModel
+
+ class DefaultFormat(BaseModel):
+     response: str
+
+ # Define global constant for structured models
+ OPENAI_STRUCTURED_MODELS = ['gpt-4o', 'gpt-4o-2024-08-06', 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18', 'deepseek']
+
+
+ class ChatOpenAI(EngineLM, CachedEngine):
+     DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+
+     def __init__(
+         self,
+         model_string="gpt-4o-mini-2024-07-18",
+         system_prompt=DEFAULT_SYSTEM_PROMPT,
+         is_multimodal: bool=False,
+         # enable_cache: bool=True,
+         enable_cache: bool=False,  # NOTE: disable cache for now
+         api_key: str=None,
+         **kwargs):
+         """
+         :param model_string: name of the OpenAI model to use (e.g. "gpt-4o-mini-2024-07-18")
+         :param system_prompt: default system prompt prepended to every request
+         :param is_multimodal: whether image (bytes) inputs are accepted alongside text
+         """
+         if enable_cache:
+             root = platformdirs.user_cache_dir("octotools")
+             cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
+             # For example, cache_path = /root/.cache/octotools/cache_openai_gpt-4o-mini.db
+             # print(f"Cache path: {cache_path}")
+
+             self.image_cache_dir = os.path.join(root, "image_cache")
+             os.makedirs(self.image_cache_dir, exist_ok=True)
+
+             super().__init__(cache_path=cache_path)
+
+         self.system_prompt = system_prompt
+         if api_key is None:
+             api_key = os.getenv("OPENAI_API_KEY")  # fall back to the environment (.env loaded above)
+         if api_key is None:
+             raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
+
+         self.client = OpenAI(
+             api_key=api_key,
+         )
+         self.model_string = model_string
+         self.is_multimodal = is_multimodal
+         self.enable_cache = enable_cache
+
+         if enable_cache:
+             print(f"!! Cache enabled for model: {self.model_string}")
+         else:
+             print(f"!! Cache disabled for model: {self.model_string}")
+
+     @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
+     def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
+         try:
+             # Print retry attempt information
+             attempt_number = self.generate.retry.statistics.get('attempt_number', 1)
+             if attempt_number > 1:
+                 print(f"Attempt {attempt_number} of 5")
+
+             if isinstance(content, str):
+                 return self._generate_text(content, system_prompt=system_prompt, **kwargs)
+
+             elif isinstance(content, list):
+                 if (not self.is_multimodal):
+                     raise NotImplementedError("Multimodal generation is only supported for GPT-4 models.")
+
+                 return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
+
+         except openai.LengthFinishReasonError as e:
+             print(f"Token limit exceeded: {str(e)}")
+             print(f"Tokens used - Completion: {e.completion.usage.completion_tokens}, Prompt: {e.completion.usage.prompt_tokens}, Total: {e.completion.usage.total_tokens}")
+             return {
+                 "error": "token_limit_exceeded",
+                 "message": str(e),
+                 "details": {
+                     "completion_tokens": e.completion.usage.completion_tokens,
+                     "prompt_tokens": e.completion.usage.prompt_tokens,
+                     "total_tokens": e.completion.usage.total_tokens
+                 }
+             }
+         except openai.RateLimitError as e:
+             print(f"Rate limit error encountered: {str(e)}")
+             return {
+                 "error": "rate_limit",
+                 "message": str(e),
+                 "details": getattr(e, 'args', None)
+             }
+         except Exception as e:
+             print(f"Error in generate method: {str(e)}")
+             print(f"Error type: {type(e).__name__}")
+             print(f"Error details: {e.args}")
+             return {
+                 "error": type(e).__name__,
+                 "message": str(e),
+                 "details": getattr(e, 'args', None)
+             }
+
+     def _generate_text(
+         self, prompt, system_prompt=None, temperature=0.5, max_tokens=4000, top_p=0.99, response_format=None
+     ):
+
+         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+
+         if self.enable_cache:
+             cache_key = sys_prompt_arg + prompt
+             cache_or_none = self._check_cache(cache_key)
+             if cache_or_none is not None:
+                 return cache_or_none
+
+         if self.model_string in ['o1', 'o1-mini']:  # only supports base response currently
+             # print(f"Using structured model: {self.model_string}")
+             response = self.client.beta.chat.completions.parse(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "user", "content": prompt},
+                 ],
+                 max_completion_tokens=max_tokens
+             )
+             if response.choices[0].finish_reason == "length":
+                 response = "Token limit exceeded"
+             else:
+                 # no response_format is passed here, so return the plain message text
+                 response = response.choices[0].message.content
+         elif self.model_string in OPENAI_STRUCTURED_MODELS and response_format is not None:
+             # print(f"Using structured model: {self.model_string}")
+             response = self.client.beta.chat.completions.parse(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "system", "content": sys_prompt_arg},
+                     {"role": "user", "content": prompt},
+                 ],
+                 frequency_penalty=0,
+                 presence_penalty=0,
+                 stop=None,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 response_format=response_format
+             )
+             response = response.choices[0].message.parsed
+         else:
+             # print(f"Using non-structured model: {self.model_string}")
+             response = self.client.chat.completions.create(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "system", "content": sys_prompt_arg},
+                     {"role": "user", "content": prompt},
+                 ],
+                 frequency_penalty=0,
+                 presence_penalty=0,
+                 stop=None,
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+             )
+             response = response.choices[0].message.content
+
+         if self.enable_cache:
+             self._save_cache(cache_key, response)
+         return response
+
+     def __call__(self, prompt, **kwargs):
+         return self.generate(prompt, **kwargs)
+
+     def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
+         formatted_content = []
+         for item in content:
+             if isinstance(item, bytes):
+                 base64_image = base64.b64encode(item).decode('utf-8')
+                 formatted_content.append({
+                     "type": "image_url",
+                     "image_url": {
+                         "url": f"data:image/jpeg;base64,{base64_image}"
+                     }
+                 })
+             elif isinstance(item, str):
+                 formatted_content.append({
+                     "type": "text",
+                     "text": item
+                 })
+             else:
+                 raise ValueError(f"Unsupported input type: {type(item)}")
+         return formatted_content
+
+     def _generate_multimodal(
+         self, content: List[Union[str, bytes]], system_prompt=None, temperature=0.5, max_tokens=4000, top_p=0.99, response_format=None
+     ):
+         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+         formatted_content = self._format_content(content)
+
+         if self.enable_cache:
+             cache_key = sys_prompt_arg + json.dumps(formatted_content)
+             cache_or_none = self._check_cache(cache_key)
+             if cache_or_none is not None:
+                 # print(f"Cache hit for prompt: {cache_key[:200]}")
+                 return cache_or_none
+
+         if self.model_string in ['o1', 'o1-mini']:  # only supports base response currently
+             # print(f"Using structured model: {self.model_string}")
+             print(f'Max tokens: {max_tokens}')
+             response = self.client.chat.completions.create(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "user", "content": formatted_content},
+                 ],
+                 max_completion_tokens=max_tokens
+             )
+             if response.choices[0].finish_reason == "length":
+                 response_text = "Token limit exceeded"
+             else:
+                 response_text = response.choices[0].message.content
+         elif self.model_string in OPENAI_STRUCTURED_MODELS and response_format is not None:
+             # print(f"Using structured model: {self.model_string}")
+             response = self.client.beta.chat.completions.parse(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "system", "content": sys_prompt_arg},
+                     {"role": "user", "content": formatted_content},
+                 ],
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 response_format=response_format
+             )
+             response_text = response.choices[0].message.parsed
+         else:
+             # print(f"Using non-structured model: {self.model_string}")
+             response = self.client.chat.completions.create(
+                 model=self.model_string,
+                 messages=[
+                     {"role": "system", "content": sys_prompt_arg},
+                     {"role": "user", "content": formatted_content},
+                 ],
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+             )
+             response_text = response.choices[0].message.content
+
+         if self.enable_cache:
+             self._save_cache(cache_key, response_text)
+         return response_text
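
For reference, a minimal usage sketch of the engine added in this commit. It assumes the module is importable as octotools.engine.octotools_engine_openai and that OPENAI_API_KEY is available in the environment or a .env file; the Answer model and prompts below are illustrative only.

    import os
    from pydantic import BaseModel
    from octotools.engine.octotools_engine_openai import ChatOpenAI

    # Hypothetical response schema for structured output (any Pydantic model works)
    class Answer(BaseModel):
        response: str

    # Plain-text generation: __call__ delegates to generate(), which routes to _generate_text()
    engine = ChatOpenAI(model_string="gpt-4o-mini-2024-07-18", api_key=os.getenv("OPENAI_API_KEY"))
    print(engine("Summarize structured outputs in one sentence."))

    # Structured generation: used only when the model is in OPENAI_STRUCTURED_MODELS
    # and a response_format Pydantic model is supplied; returns a parsed Answer instance.
    result = engine.generate("Reply with a short greeting.", response_format=Answer)
    print(result.response)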