Kang Suhyun commited on
Commit
fd9a72d
·
unverified ·
1 Parent(s): 8d7d881

[#130] Update non-JSON response handling (#131)

Browse files

* [#130] Update non-JSON response handling

* rename

Files changed (3) hide show
  1. README.md +0 -1
  2. model.py +45 -36
  3. response.py +9 -1
README.md CHANGED
@@ -56,7 +56,6 @@ Get Involved: [Discuss and contribute on GitHub](https://github.com/yanolja/aren
56
  ANTHROPIC_API_KEY=<your key> \
57
  MISTRAL_API_KEY=<your key> \
58
  GEMINI_API_KEY=<your key> \
59
- GROQ_API_KEY=<your key> \
60
  DEEPINFRA_API_KEY=<your key> \
61
  python3 app.py
62
  ```
 
56
  ANTHROPIC_API_KEY=<your key> \
57
  MISTRAL_API_KEY=<your key> \
58
  GEMINI_API_KEY=<your key> \
 
59
  DEEPINFRA_API_KEY=<your key> \
60
  python3 app.py
61
  ```
model.py CHANGED
@@ -4,7 +4,7 @@ This module contains functions to interact with the models.
4
 
5
  import json
6
  import os
7
- from typing import List
8
 
9
  import litellm
10
 
@@ -34,10 +34,12 @@ class Model:
34
  self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
35
  self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
36
 
 
37
  def completion(self,
38
  instruction: str,
39
  prompt: str,
40
- max_tokens: float = None) -> str:
 
41
  messages = [{
42
  "role":
43
  "system",
@@ -50,23 +52,25 @@ Output following this JSON format without using code blocks:
50
  "content": prompt
51
  }]
52
 
53
- try:
54
- response = litellm.completion(model=self.provider + "/" +
55
- self.name if self.provider else self.name,
56
- api_key=self.api_key,
57
- api_base=self.api_base,
58
- messages=messages,
59
- max_tokens=max_tokens,
60
- **self._get_completion_kwargs())
61
-
62
- json_response = response.choices[0].message.content
63
- parsed_json = json.loads(json_response)
64
- return parsed_json["result"]
65
-
66
- except litellm.ContextWindowExceededError as e:
67
- raise ContextWindowExceededError() from e
68
- except json.JSONDecodeError as e:
69
- raise RuntimeError(f"Failed to get JSON response: {e}") from e
 
 
70
 
71
  def _get_completion_kwargs(self):
72
  return {
@@ -82,7 +86,8 @@ class AnthropicModel(Model):
82
  def completion(self,
83
  instruction: str,
84
  prompt: str,
85
- max_tokens: float = None) -> str:
 
86
  # Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
87
  prefix = "<result>"
88
  suffix = "</result>"
@@ -99,23 +104,27 @@ Text:
99
  "role": "assistant",
100
  "content": prefix
101
  }]
102
- try:
103
- response = litellm.completion(
104
- model=self.provider + "/" + self.name if self.provider else self.name,
105
- api_key=self.api_key,
106
- api_base=self.api_base,
107
- messages=messages,
108
- max_tokens=max_tokens,
109
- )
110
 
111
- except litellm.ContextWindowExceededError as e:
112
- raise ContextWindowExceededError() from e
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- result = response.choices[0].message.content
115
- if not result.endswith(suffix):
116
- raise RuntimeError(f"Failed to get the formatted response: {result}")
117
 
118
- return result.removesuffix(suffix).strip()
 
119
 
120
 
121
  class VertexModel(Model):
@@ -164,8 +173,8 @@ supported_models: List[Model] = [
164
  vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
165
  Model("mistral-small-2402", provider="mistral"),
166
  Model("mistral-large-2402", provider="mistral"),
167
- Model("llama3-8b-8192", provider="groq"),
168
- Model("llama3-70b-8192", provider="groq"),
169
  Model("google/gemma-2-9b-it", provider="deepinfra"),
170
  Model("google/gemma-2-27b-it", provider="deepinfra"),
171
  EeveModel("yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
 
4
 
5
  import json
6
  import os
7
+ from typing import List, Optional, Tuple
8
 
9
  import litellm
10
 
 
34
  self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
35
  self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
36
 
37
+ # Returns the parsed result or raw response, and whether parsing succeeded.
38
  def completion(self,
39
  instruction: str,
40
  prompt: str,
41
+ max_tokens: Optional[float] = None,
42
+ max_retries: int = 2) -> Tuple[str, bool]:
43
  messages = [{
44
  "role":
45
  "system",
 
52
  "content": prompt
53
  }]
54
 
55
+ for attempt in range(max_retries + 1):
56
+ try:
57
+ response = litellm.completion(model=self.provider + "/" +
58
+ self.name if self.provider else self.name,
59
+ api_key=self.api_key,
60
+ api_base=self.api_base,
61
+ messages=messages,
62
+ max_tokens=max_tokens,
63
+ **self._get_completion_kwargs())
64
+
65
+ json_response = response.choices[0].message.content
66
+ parsed_json = json.loads(json_response)
67
+ return parsed_json["result"], True
68
+
69
+ except litellm.ContextWindowExceededError as e:
70
+ raise ContextWindowExceededError() from e
71
+ except json.JSONDecodeError:
72
+ if attempt == max_retries:
73
+ return json_response, False
74
 
75
  def _get_completion_kwargs(self):
76
  return {
 
86
  def completion(self,
87
  instruction: str,
88
  prompt: str,
89
+ max_tokens: Optional[float] = None,
90
+ max_retries: int = 2) -> Tuple[str, bool]:
91
  # Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
92
  prefix = "<result>"
93
  suffix = "</result>"
 
104
  "role": "assistant",
105
  "content": prefix
106
  }]
 
 
 
 
 
 
 
 
107
 
108
+ for attempt in range(max_retries + 1):
109
+ try:
110
+ response = litellm.completion(
111
+ model=self.provider + "/" +
112
+ self.name if self.provider else self.name,
113
+ api_key=self.api_key,
114
+ api_base=self.api_base,
115
+ messages=messages,
116
+ max_tokens=max_tokens,
117
+ )
118
+
119
+ except litellm.ContextWindowExceededError as e:
120
+ raise ContextWindowExceededError() from e
121
 
122
+ result = response.choices[0].message.content
123
+ if result.endswith(suffix):
124
+ return result.removesuffix(suffix).strip(), True
125
 
126
+ if attempt == max_retries:
127
+ return result, False
128
 
129
 
130
  class VertexModel(Model):
 
173
  vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
174
  Model("mistral-small-2402", provider="mistral"),
175
  Model("mistral-large-2402", provider="mistral"),
176
+ Model("meta-llama/Meta-Llama-3-8B-Instruct", provider="deepinfra"),
177
+ Model("meta-llama/Meta-Llama-3-70B-Instruct", provider="deepinfra"),
178
  Model("google/gemma-2-9b-it", provider="deepinfra"),
179
  Model("google/gemma-2-27b-it", provider="deepinfra"),
180
  EeveModel("yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
response.py CHANGED
@@ -22,6 +22,7 @@ logging.basicConfig()
22
  logger = logging.getLogger(__name__)
23
  logger.setLevel(logging.INFO)
24
 
 
25
  # TODO(#37): Move DB operations to db.py.
26
  def get_history_collection(category: str):
27
  if category == Category.SUMMARIZE.value:
@@ -89,14 +90,18 @@ def get_responses(prompt: str, category: str, source_lang: str,
89
 
90
  models: List[Model] = sample(list(supported_models), 2)
91
  responses = []
 
92
  for model in models:
93
  instruction = get_instruction(category, model, source_lang, target_lang)
94
  try:
95
  # TODO(#1): Allow user to set configuration.
96
- response = model.completion(instruction, prompt)
97
  create_history(category, model.name, instruction, prompt, response)
98
  responses.append(response)
99
 
 
 
 
100
  except ContextWindowExceededError as e:
101
  logger.exception("Context window exceeded for model %s.", model.name)
102
  raise gr.Error(
@@ -106,6 +111,9 @@ def get_responses(prompt: str, category: str, source_lang: str,
106
  logger.exception("Failed to get response from model %s.", model.name)
107
  raise gr.Error("Failed to get response. Please try again.") from e
108
 
 
 
 
109
  model_names = [model.name for model in models]
110
 
111
  # It simulates concurrent stream response generation.
 
22
  logger = logging.getLogger(__name__)
23
  logger.setLevel(logging.INFO)
24
 
25
+
26
  # TODO(#37): Move DB operations to db.py.
27
  def get_history_collection(category: str):
28
  if category == Category.SUMMARIZE.value:
 
90
 
91
  models: List[Model] = sample(list(supported_models), 2)
92
  responses = []
93
+ got_invalid_response = False
94
  for model in models:
95
  instruction = get_instruction(category, model, source_lang, target_lang)
96
  try:
97
  # TODO(#1): Allow user to set configuration.
98
+ response, is_valid_response = model.completion(instruction, prompt)
99
  create_history(category, model.name, instruction, prompt, response)
100
  responses.append(response)
101
 
102
+ if not is_valid_response:
103
+ got_invalid_response = True
104
+
105
  except ContextWindowExceededError as e:
106
  logger.exception("Context window exceeded for model %s.", model.name)
107
  raise gr.Error(
 
111
  logger.exception("Failed to get response from model %s.", model.name)
112
  raise gr.Error("Failed to get response. Please try again.") from e
113
 
114
+ if got_invalid_response:
115
+ gr.Warning("An invalid response was received.")
116
+
117
  model_names = [model.name for model in models]
118
 
119
  # It simulates concurrent stream response generation.