Spaces:
Running
Running
Kang Suhyun
committed on
[#130] Update non-JSON response handling (#131)
Browse files
* [#130] Update non-JSON response handling
* rename
- README.md +0 -1
- model.py +45 -36
- response.py +9 -1
README.md
CHANGED
@@ -56,7 +56,6 @@ Get Involved: [Discuss and contribute on GitHub](https://github.com/yanolja/aren
|
|
56 |
ANTHROPIC_API_KEY=<your key> \
|
57 |
MISTRAL_API_KEY=<your key> \
|
58 |
GEMINI_API_KEY=<your key> \
|
59 |
-
GROQ_API_KEY=<your key> \
|
60 |
DEEPINFRA_API_KEY=<your key> \
|
61 |
python3 app.py
|
62 |
```
|
|
|
56 |
ANTHROPIC_API_KEY=<your key> \
|
57 |
MISTRAL_API_KEY=<your key> \
|
58 |
GEMINI_API_KEY=<your key> \
|
|
|
59 |
DEEPINFRA_API_KEY=<your key> \
|
60 |
python3 app.py
|
61 |
```
|
model.py
CHANGED
@@ -4,7 +4,7 @@ This module contains functions to interact with the models.
|
|
4 |
|
5 |
import json
|
6 |
import os
|
7 |
-
from typing import List
|
8 |
|
9 |
import litellm
|
10 |
|
@@ -34,10 +34,12 @@ class Model:
|
|
34 |
self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
|
35 |
self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
|
36 |
|
|
|
37 |
def completion(self,
|
38 |
instruction: str,
|
39 |
prompt: str,
|
40 |
-
max_tokens: float = None
|
|
|
41 |
messages = [{
|
42 |
"role":
|
43 |
"system",
|
@@ -50,23 +52,25 @@ Output following this JSON format without using code blocks:
|
|
50 |
"content": prompt
|
51 |
}]
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
70 |
|
71 |
def _get_completion_kwargs(self):
|
72 |
return {
|
@@ -82,7 +86,8 @@ class AnthropicModel(Model):
|
|
82 |
def completion(self,
|
83 |
instruction: str,
|
84 |
prompt: str,
|
85 |
-
max_tokens: float = None
|
|
|
86 |
# Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
|
87 |
prefix = "<result>"
|
88 |
suffix = "</result>"
|
@@ -99,23 +104,27 @@ Text:
|
|
99 |
"role": "assistant",
|
100 |
"content": prefix
|
101 |
}]
|
102 |
-
try:
|
103 |
-
response = litellm.completion(
|
104 |
-
model=self.provider + "/" + self.name if self.provider else self.name,
|
105 |
-
api_key=self.api_key,
|
106 |
-
api_base=self.api_base,
|
107 |
-
messages=messages,
|
108 |
-
max_tokens=max_tokens,
|
109 |
-
)
|
110 |
|
111 |
-
|
112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
|
118 |
-
|
|
|
119 |
|
120 |
|
121 |
class VertexModel(Model):
|
@@ -164,8 +173,8 @@ supported_models: List[Model] = [
|
|
164 |
vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
|
165 |
Model("mistral-small-2402", provider="mistral"),
|
166 |
Model("mistral-large-2402", provider="mistral"),
|
167 |
-
Model("
|
168 |
-
Model("
|
169 |
Model("google/gemma-2-9b-it", provider="deepinfra"),
|
170 |
Model("google/gemma-2-27b-it", provider="deepinfra"),
|
171 |
EeveModel("yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
|
|
|
4 |
|
5 |
import json
|
6 |
import os
|
7 |
+
from typing import List, Optional, Tuple
|
8 |
|
9 |
import litellm
|
10 |
|
|
|
34 |
self.summarize_instruction = summarize_instruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
|
35 |
self.translate_instruction = translate_instruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
|
36 |
|
37 |
+
# Returns the parsed result or raw response, and whether parsing succeeded.
|
38 |
def completion(self,
|
39 |
instruction: str,
|
40 |
prompt: str,
|
41 |
+
max_tokens: Optional[float] = None,
|
42 |
+
max_retries: int = 2) -> Tuple[str, bool]:
|
43 |
messages = [{
|
44 |
"role":
|
45 |
"system",
|
|
|
52 |
"content": prompt
|
53 |
}]
|
54 |
|
55 |
+
for attempt in range(max_retries + 1):
|
56 |
+
try:
|
57 |
+
response = litellm.completion(model=self.provider + "/" +
|
58 |
+
self.name if self.provider else self.name,
|
59 |
+
api_key=self.api_key,
|
60 |
+
api_base=self.api_base,
|
61 |
+
messages=messages,
|
62 |
+
max_tokens=max_tokens,
|
63 |
+
**self._get_completion_kwargs())
|
64 |
+
|
65 |
+
json_response = response.choices[0].message.content
|
66 |
+
parsed_json = json.loads(json_response)
|
67 |
+
return parsed_json["result"], True
|
68 |
+
|
69 |
+
except litellm.ContextWindowExceededError as e:
|
70 |
+
raise ContextWindowExceededError() from e
|
71 |
+
except json.JSONDecodeError:
|
72 |
+
if attempt == max_retries:
|
73 |
+
return json_response, False
|
74 |
|
75 |
def _get_completion_kwargs(self):
|
76 |
return {
|
|
|
86 |
def completion(self,
|
87 |
instruction: str,
|
88 |
prompt: str,
|
89 |
+
max_tokens: Optional[float] = None,
|
90 |
+
max_retries: int = 2) -> Tuple[str, bool]:
|
91 |
# Ref: https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#prefill-claudes-response # pylint: disable=line-too-long
|
92 |
prefix = "<result>"
|
93 |
suffix = "</result>"
|
|
|
104 |
"role": "assistant",
|
105 |
"content": prefix
|
106 |
}]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
|
108 |
+
for attempt in range(max_retries + 1):
|
109 |
+
try:
|
110 |
+
response = litellm.completion(
|
111 |
+
model=self.provider + "/" +
|
112 |
+
self.name if self.provider else self.name,
|
113 |
+
api_key=self.api_key,
|
114 |
+
api_base=self.api_base,
|
115 |
+
messages=messages,
|
116 |
+
max_tokens=max_tokens,
|
117 |
+
)
|
118 |
+
|
119 |
+
except litellm.ContextWindowExceededError as e:
|
120 |
+
raise ContextWindowExceededError() from e
|
121 |
|
122 |
+
result = response.choices[0].message.content
|
123 |
+
if result.endswith(suffix):
|
124 |
+
return result.removesuffix(suffix).strip(), True
|
125 |
|
126 |
+
if attempt == max_retries:
|
127 |
+
return result, False
|
128 |
|
129 |
|
130 |
class VertexModel(Model):
|
|
|
173 |
vertex_credentials=os.getenv("VERTEX_CREDENTIALS")),
|
174 |
Model("mistral-small-2402", provider="mistral"),
|
175 |
Model("mistral-large-2402", provider="mistral"),
|
176 |
+
Model("meta-llama/Meta-Llama-3-8B-Instruct", provider="deepinfra"),
|
177 |
+
Model("meta-llama/Meta-Llama-3-70B-Instruct", provider="deepinfra"),
|
178 |
Model("google/gemma-2-9b-it", provider="deepinfra"),
|
179 |
Model("google/gemma-2-27b-it", provider="deepinfra"),
|
180 |
EeveModel("yanolja/EEVE-Korean-Instruct-10.8B-v1.0",
|
response.py
CHANGED
@@ -22,6 +22,7 @@ logging.basicConfig()
|
|
22 |
logger = logging.getLogger(__name__)
|
23 |
logger.setLevel(logging.INFO)
|
24 |
|
|
|
25 |
# TODO(#37): Move DB operations to db.py.
|
26 |
def get_history_collection(category: str):
|
27 |
if category == Category.SUMMARIZE.value:
|
@@ -89,14 +90,18 @@ def get_responses(prompt: str, category: str, source_lang: str,
|
|
89 |
|
90 |
models: List[Model] = sample(list(supported_models), 2)
|
91 |
responses = []
|
|
|
92 |
for model in models:
|
93 |
instruction = get_instruction(category, model, source_lang, target_lang)
|
94 |
try:
|
95 |
# TODO(#1): Allow user to set configuration.
|
96 |
-
response = model.completion(instruction, prompt)
|
97 |
create_history(category, model.name, instruction, prompt, response)
|
98 |
responses.append(response)
|
99 |
|
|
|
|
|
|
|
100 |
except ContextWindowExceededError as e:
|
101 |
logger.exception("Context window exceeded for model %s.", model.name)
|
102 |
raise gr.Error(
|
@@ -106,6 +111,9 @@ def get_responses(prompt: str, category: str, source_lang: str,
|
|
106 |
logger.exception("Failed to get response from model %s.", model.name)
|
107 |
raise gr.Error("Failed to get response. Please try again.") from e
|
108 |
|
|
|
|
|
|
|
109 |
model_names = [model.name for model in models]
|
110 |
|
111 |
# It simulates concurrent stream response generation.
|
|
|
22 |
logger = logging.getLogger(__name__)
|
23 |
logger.setLevel(logging.INFO)
|
24 |
|
25 |
+
|
26 |
# TODO(#37): Move DB operations to db.py.
|
27 |
def get_history_collection(category: str):
|
28 |
if category == Category.SUMMARIZE.value:
|
|
|
90 |
|
91 |
models: List[Model] = sample(list(supported_models), 2)
|
92 |
responses = []
|
93 |
+
got_invalid_response = False
|
94 |
for model in models:
|
95 |
instruction = get_instruction(category, model, source_lang, target_lang)
|
96 |
try:
|
97 |
# TODO(#1): Allow user to set configuration.
|
98 |
+
response, is_valid_response = model.completion(instruction, prompt)
|
99 |
create_history(category, model.name, instruction, prompt, response)
|
100 |
responses.append(response)
|
101 |
|
102 |
+
if not is_valid_response:
|
103 |
+
got_invalid_response = True
|
104 |
+
|
105 |
except ContextWindowExceededError as e:
|
106 |
logger.exception("Context window exceeded for model %s.", model.name)
|
107 |
raise gr.Error(
|
|
|
111 |
logger.exception("Failed to get response from model %s.", model.name)
|
112 |
raise gr.Error("Failed to get response. Please try again.") from e
|
113 |
|
114 |
+
if got_invalid_response:
|
115 |
+
gr.Warning("An invalid response was received.")
|
116 |
+
|
117 |
model_names = [model.name for model in models]
|
118 |
|
119 |
# It simulates concurrent stream response generation.
|