Wonderplex committed
Commit • fe95067
Parent(s): c3a4051

Feature/hf pipeline format (#47)

* removed older version utils
* changed requirements
* added changes for inference

Files changed:
- .gitignore +1 -0
- app.py +2 -2
- langchain_callback_handler.py +60 -0
- message_classes.py +343 -0
- requirements.txt +148 -6
- sotopia_pi_generate.py +57 -50
- utils.py +7 -99
.gitignore
CHANGED
@@ -1,4 +1,5 @@
 __pycache__/
 .cache/
 openai_api.key
+hf_token.key
 core
app.py
CHANGED
@@ -12,7 +12,7 @@ with open("openai_api.key", "r") as f:
     os.environ["OPENAI_API_KEY"] = f.read().strip()

 DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
-DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
+DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo"
 TEMPERATURE = 0.7
 TOP_P = 1
 MAX_TOKENS = 1024
@@ -147,7 +147,7 @@ def sotopia_info_accordion(accordion_visible=True):
         interactive=True,
     )
     model_name_dropdown = gr.Dropdown(
-        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
+        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
         value=DEFAULT_MODEL_SELECTION,
         interactive=True,
         label="Model Selection"
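The only functional change here is the extra 4-bit checkpoint in the dropdown. As a minimal illustrative sketch (not from this commit; a reduced stand-in assuming Gradio 4.x as pinned in requirements.txt), this is roughly how such a selector feeds a choice back into the app — in the real app.py the selected name is later routed through obtain_chain in sotopia_pi_generate.py:

import gradio as gr

def echo_choice(model_name: str) -> str:
    # hypothetical stand-in: the real app routes this name to either the
    # HuggingFace pipeline path or the OpenAI/LiteLLM path
    return f"Selected model: {model_name}"

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(
        choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "gpt-3.5-turbo"],
        value="gpt-3.5-turbo",
        interactive=True,
        label="Model Selection",
    )
    out = gr.Textbox(label="Choice")
    dropdown.change(echo_choice, inputs=dropdown, outputs=out)

demo.launch()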
langchain_callback_handler.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
from langchain.callbacks import StdOutCallbackHandler
|
5 |
+
|
6 |
+
logging.addLevelName(15, "LangChain")
|
7 |
+
|
8 |
+
|
9 |
+
class LoggingCallbackHandler(StdOutCallbackHandler):
|
10 |
+
"""Callback Handler that prints to std out."""
|
11 |
+
|
12 |
+
always_verbose = True
|
13 |
+
|
14 |
+
def __init__(self, name: str) -> None:
|
15 |
+
"""Initialize callback handler."""
|
16 |
+
super().__init__()
|
17 |
+
self.logger = logging.getLogger(name)
|
18 |
+
self.prompt = ""
|
19 |
+
|
20 |
+
def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
|
21 |
+
pass
|
22 |
+
|
23 |
+
def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
|
24 |
+
pass
|
25 |
+
|
26 |
+
def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
|
27 |
+
pass
|
28 |
+
|
29 |
+
def on_tool_end(
|
30 |
+
self,
|
31 |
+
*args: Any,
|
32 |
+
**kwargs: Any,
|
33 |
+
) -> None:
|
34 |
+
pass
|
35 |
+
|
36 |
+
def on_tool_error(
|
37 |
+
self, error: BaseException | KeyboardInterrupt, **kwargs: Any
|
38 |
+
) -> None:
|
39 |
+
"""Do nothing."""
|
40 |
+
pass
|
41 |
+
|
42 |
+
def on_text(
|
43 |
+
self,
|
44 |
+
text: str,
|
45 |
+
color: str | None = None,
|
46 |
+
end: str = "",
|
47 |
+
**kwargs: Any,
|
48 |
+
) -> None:
|
49 |
+
"""Run when agent ends."""
|
50 |
+
# leave only prompt for environment
|
51 |
+
text = text.replace("\x1b[32;1m\x1b[1;3mHuman: ", "")
|
52 |
+
logging.log(15, f"LLM Call: {text}")
|
53 |
+
self.prompt = text
|
54 |
+
|
55 |
+
def retrive_prompt(self) -> str:
|
56 |
+
return self.prompt
|
57 |
+
|
58 |
+
def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
|
59 |
+
"""Run on agent end."""
|
60 |
+
pass
|
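A quick usage sketch (illustrative, not part of the commit): the handler is instantiated once in sotopia_pi_generate.py and passed into chain.predict, so on_text captures the exact prompt sent to the LLM; the sample text below is a hypothetical value:

import logging
from langchain_callback_handler import LoggingCallbackHandler

logging.basicConfig(level=15)  # level 15 sits between DEBUG (10) and INFO (20)

handler = LoggingCallbackHandler("langchain")
# on_text strips the ANSI color prefix that LangChain prepends to prompts
handler.on_text("\x1b[32;1m\x1b[1;3mHuman: Hello there")
print(handler.retrive_prompt())  # -> 'Hello there'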
message_classes.py
ADDED
@@ -0,0 +1,343 @@
import re
from typing import Literal, cast

from pydantic import BaseModel, Field

from utils import format_docstring

ActionType = Literal["none", "speak", "non-verbal communication", "action", "leave"]


class Message(BaseModel):
    """
    An interface for messages.
    There is only one required method: to_natural_language
    """

    def to_natural_language(self) -> str:
        raise NotImplementedError


class SimpleMessage(Message):
    """
    A simple message with a single string field.
    """

    message: str = Field(description="the message")

    def to_natural_language(self) -> str:
        return self.message


class Observation(Message):
    last_turn: str = Field(description="the last turn of the conversation")
    turn_number: int = Field(description="the turn number of the conversation")
    available_actions: list[ActionType] = Field(description="the available actions")

    def to_natural_language(self) -> str:
        if self.turn_number == 0:
            return f"\n{self.last_turn}\nConversation Starts:\n"
        else:
            return f"Turn #{self.turn_number-1}: {self.last_turn}\n"


class ScriptBackground(Message):
    scenario: str = Field(description="scenario of the episode")
    p1_name: str = Field(description="name of participant 1")
    p2_name: str = Field(description="name of participant 2")
    p1_background: str = Field(description="background of participant 1")
    p2_background: str = Field(description="background of participant 2")
    p1_goal: str = Field(description="goal of participant 1")
    p2_goal: str = Field(description="goal of participant 2")

    def to_natural_language(self) -> str:
        if self.p1_background or self.p2_background:
            p1_background = self.p1_background if self.p1_background else "Unknown"
            p2_background = self.p2_background if self.p2_background else "Unknown"
            # Not using AND, since in stranger relation the background is not visible
            return format_docstring(
                f"""Here is the context of this interaction:
Scenario: {self.scenario}
Participants: {self.p1_name} and {self.p2_name}
{self.p1_name}'s background: {p1_background}
{self.p2_name}'s background: {p2_background}
{self.p1_name}'s goal: {self.p1_goal}
{self.p2_name}'s goal: {self.p2_goal}
"""
            )
        else:
            return format_docstring(
                f"""Here is the context of this interaction:
Scenario: {self.scenario}
Participants: {self.p1_name} and {self.p2_name}
{self.p1_name}'s goal: {self.p1_goal}
{self.p2_name}'s goal: {self.p2_goal}
"""
            )


class ScriptEnvironmentResponse(Message):
    terminated: bool = Field(
        description="whether the conversation is terminated",
        default_factory=lambda: False,
    )
    p1_rate: float | tuple[float, dict[str, float]] | None = Field(
        description="rating of participant 1, on the scale of 1 to 10"
    )
    p2_rate: float | tuple[float, dict[str, float]] | None = Field(
        description="rating of participant 2, on the scale of 1 to 10"
    )
    comments: str | None = Field(
        description="All of the comments supporting the termination and rating"
    )

    def to_natural_language(self) -> str:
        reason_to_stop = format_docstring(
            f"""Environment response:
{"The conversation is terminated." if self.terminated else ""}
{"Rating of participant 1" + str(self.p1_rate) if self.p1_rate is not None else ""}
{"Rating of participant 2" + str(self.p2_rate) if self.p2_rate is not None else ""}
{self.comments if self.comments is not None else ""}
"""
        )
        clean_text = ""
        for line in reason_to_stop.split("\n"):
            if line.strip():
                clean_text += line + "\n"
        return clean_text


class AgentAction(Message):
    action_type: ActionType = Field(
        description="whether to speak at this turn or choose to not do anything"
    )
    argument: str = Field(
        description="the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action"
    )

    def to_natural_language(self) -> str:
        match self.action_type:
            case "none":
                return "did nothing"
            case "speak":
                return f'said: "{self.argument}"'
            case "non-verbal communication":
                return f"[{self.action_type}] {self.argument}"
            case "action":
                return f"[{self.action_type}] {self.argument}"
            case "leave":
                return "left the conversation"


ScriptInteractionReturnType = tuple[
    list[list[tuple[str, str, Message]]], list[tuple[str, Message]]
]


class ScriptInteraction(Message):
    interactions: str = Field(
        description="""The interaction between the two participants in maximum 20 turns. Each turn is separated by a newline, and should only describe one agent. Following the structure:
Turn #x
[participant's name] [action] {argument for some actions}

You can use different types of actions, but only use one in each turn. You should move other information into argument part. Below shows a python code snippet of the format for each action type:
match self.action_type:
    case "none":
        return "did nothing"
    case "speak":
        return f'said: "{self.argument}"'
    case "non-verbal communication":
        return f"[{self.action_type}] {self.argument}"
    case "action":
        return f"[{self.action_type}] {self.argument}"
    case "leave":
        return "left the conversation"

For example, the following is acceptable:
Turn #x
Oliver Thompson said: "Hey Esmeralda, what's wrong? You seem upset."
Turn #x
Esmeralda Solis [action] moved closer
Turn #x
Oliver Thompson [non-verbal communication] smiled
Turn #x
Esmeralda Solis did nothing
Turn #x
Oliver Thompson left the conversation
Turn #x
Esmeralda Solis [action] leaned in and lowered her voice: "Sorry"

And the following is not acceptable:
Turn #1
Oliver Thompson [speak] said: "Hey Esmeralda, what's wrong? You seem upset."
Turn #1
Esmeralda Solis non-verbal communication moved closer
"""
    )

    def to_natural_language(self) -> str:
        return self.interactions

    def parse(
        self, agent_names: list[str], background: str
    ) -> tuple[list[list[tuple[str, str, Message]]], list[tuple[str, Message]]]:
        interaction = self.interactions
        # print("Interaction: ", interaction)
        lines = self.split_by_turn(interaction)

        agent_results = []
        results: list[list[tuple[str, str, Message]]] = [
            [
                (
                    "Environment",
                    name,
                    Observation(
                        last_turn=background,
                        turn_number=0,
                        available_actions=["none"],
                    ),
                )
                for name in agent_names
            ]
        ]

        for line_idx, line in enumerate(lines):
            try:
                res = self.parse_single_dialogue(line)
                action: AgentAction = cast(AgentAction, res["action"])
                argument: str = cast(str, res["argument"])
                cast(int, res["turn"])
                name: str = cast(str, res["name"])

                parsed_action = AgentAction(action_type=action, argument=argument)
                if name not in agent_names:
                    print(
                        f"The name of the agent, {name}, is not in the list of agent names, {agent_names}"
                    )
                    name = agent_names[
                        line_idx % 2
                    ]  # TODO Not sure what name to be set here
            except Exception as e:
                print(
                    f"Error when parsing the dialogue: {line}",
                    f"The error is: {e}",
                )
                raise e
                parsed_action = AgentAction(action_type="none", argument="")
                name = agent_names[line_idx % 2]  # TODO same question as above
            inactive_agent_name = (
                agent_names[0] if name == agent_names[1] else agent_names[1]
            )
            results.append(
                [
                    (
                        "Environment",
                        name,
                        Observation(
                            last_turn="environment is the agent",
                            turn_number=line_idx + 1,
                            available_actions=["none"],
                        ),
                    )
                    for name in agent_names
                ]
                + [
                    (name, "Environment", parsed_action),
                    (
                        inactive_agent_name,
                        "Environment",
                        AgentAction(action_type="none", argument="did nothing"),
                    ),
                ]
            )

            agent_results.append((name, parsed_action))
        # print("Parsed agent results: ", agent_results)
        return (results, agent_results)  # type: ignore

    def parse_single_dialogue(
        self, dialogue: str
    ) -> dict[str, str | int | AgentAction | None]:
        """Parse a single dialogue string and return a dictionary with turn, name, action, and argument."""

        # Match the turn number and name. Assume all agent name starts with a capital letter and is followed by lowercase letters
        match_turn_name = re.match(
            r"Turn #?(\d+):?\s*\n((?:[A-Z]['a-z]* ?)+)", dialogue
        )

        if not match_turn_name:
            raise ValueError(
                f"The dialogue does not match the expected format: {dialogue}"
            )
            return None  # TODO Which should we use, return None or raise error?

        turn, name = match_turn_name.groups()
        action_content = dialogue[
            len(match_turn_name.group(0)) :
        ].strip()  # Extract the action content

        # Check for different action types
        if "did nothing" in action_content:
            action, argument = "none", ""
        elif match := re.match(r'said: "(.*?)"', action_content):
            action, argument = "speak", match.group(1)
            action, argument = action.strip(), argument.strip()
        elif match := re.match(r'\[speak\] said: "(.*?)"', action_content):
            action, argument = "speak", match.group(1)
            action, argument = action.strip(), argument.strip()
        elif match := re.match(
            r"\[(non-verbal communication|action)\] (.*)", action_content
        ):
            action, argument = match.groups()
        elif "left the conversation" in action_content:
            # TODO Make it more elegant to handle the situation of `left the conversation.`
            action, argument = "leave", ""
        else:
            action, argument = None, None

        parsed_item = {
            "turn": int(turn),
            "name": name.strip(),
            "action": action,
            "argument": argument,
        }
        return parsed_item

    def split_by_turn(self, input_string: str) -> list[str]:
        """Split the input dialogue string by turn and return a list of dialogues."""
        # Split using 'Turn #' as delimiter, but keep the delimiter in the results
        dialogues = re.split(r"(?=Turn #?\d+)", input_string)
        # Remove any empty strings and strip whitespace
        dialogues = [dialogue.strip() for dialogue in dialogues if dialogue.strip()]
        dialogues = [dialogue for dialogue in dialogues if dialogue.startswith("Turn")]
        # Change from Turn #x to Turn (#)x (# is optional)
        dialogues[-1] = "\n".join(
            dialogues[-1].split("\n")[:2]
        )  # Discard further input in the last turn

        for dialogue in dialogues:
            # TODO this is current workaround for the issue of multiple agents in one turn
            if len(dialogue.split("\n")) >= 3:
                raise ValueError("Only one agent can act per turn.")
        return dialogues

    @staticmethod
    def default_value_for_return_type() -> ScriptInteractionReturnType:
        results_1: list[list[tuple[str, str, Message]]] = [
            [
                (
                    "Environment",
                    name,
                    Observation(
                        last_turn="Environment is the agent",
                        turn_number=0,
                        available_actions=["none"],
                    ),
                )
                for name in ["none", "none"]
            ]
        ]
        results_2: list[tuple[str, Message]] = [
            ("", AgentAction(action_type="none", argument=""))
        ]
        return (results_1, results_2)
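For illustration (not part of the commit; the dialogue string is a hypothetical example in the exact format the parser expects), the message classes round-trip like this:

from message_classes import AgentAction, ScriptInteraction

action = AgentAction(action_type="speak", argument="Hi, how are you?")
print(action.to_natural_language())  # said: "Hi, how are you?"

script = ScriptInteraction(interactions="")  # empty script, just to call the parser
parsed = script.parse_single_dialogue(
    'Turn #1\nOliver Thompson said: "Hey Esmeralda, what is wrong?"'
)
# {'turn': 1, 'name': 'Oliver Thompson', 'action': 'speak',
#  'argument': 'Hey Esmeralda, what is wrong?'}
print(parsed)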
requirements.txt
CHANGED
@@ -1,6 +1,148 @@
-
-
-
-
-
-
+absl-py==2.1.0
+accelerate==0.29.3
+aiofiles==23.2.1
+aiohttp==3.9.5
+aiosignal==1.3.1
+altair==5.3.0
+annotated-types==0.6.0
+anyio==3.7.1
+attrs==23.2.0
+beartype==0.14.1
+bitsandbytes==0.43.1
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==3.0.0
+contourpy==1.2.1
+cryptography==42.0.5
+cycler==0.12.1
+dataclasses-json==0.6.4
+datasets==2.18.0
+dill==0.3.8
+distro==1.9.0
+Farama-Notifications==0.0.4
+ffmpy==0.3.2
+filelock==3.13.4
+fonttools==4.51.0
+frozenlist==1.4.1
+fsspec==2024.2.0
+gin-config==0.5.0
+gradio==4.27.0
+gradio_client==0.15.1
+greenlet==3.0.3
+gymnasium==0.29.1
+h11==0.14.0
+hiredis==2.3.2
+httpcore==1.0.5
+httpx==0.27.0
+huggingface-hub==0.22.2
+idna==3.7
+importlib_metadata==7.1.0
+importlib_resources==6.4.0
+Jinja2==3.1.3
+jsonpatch==1.33
+jsonpointer==2.4
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+kiwisolver==1.4.5
+langchain==0.1.16
+langchain-community==0.0.33
+langchain-core==0.1.44
+langchain-openai==0.0.5
+langchain-text-splitters==0.0.1
+langsmith==0.1.48
+litellm==1.35.12
+lxml==4.9.4
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+marshmallow==3.21.1
+matplotlib==3.8.4
+mdurl==0.1.2
+more-itertools==10.2.0
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+mypy==1.9.0
+mypy-extensions==1.0.0
+names==0.3.0
+networkx==3.3
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.1.105
+openai==1.22.0
+orjson==3.10.1
+packaging==23.2
+pandas==2.2.2
+pandas-stubs==2.2.1.240316
+peft==0.10.0
+pettingzoo==1.24.0
+pillow==10.3.0
+psutil==5.9.8
+pyarrow==15.0.2
+pyarrow-hotfix==0.6
+pycparser==2.22
+pydantic==2.7.0
+pydantic_core==2.18.1
+pydub==0.25.1
+Pygments==2.17.2
+pyparsing==3.1.2
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-multipart==0.0.9
+python-ulid==1.1.0
+pytz==2024.1
+PyYAML==6.0.1
+redis==5.0.3
+referencing==0.34.0
+regex==2024.4.16
+requests==2.31.0
+rich==13.7.1
+rpds-py==0.18.0
+ruff==0.3.7
+safetensors==0.4.3
+scipy==1.13.0
+semantic-version==2.10.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+SQLAlchemy==2.0.29
+sseclient-py==1.8.0
+starlette==0.27.0
+sympy==1.12
+tabulate==0.9.0
+tenacity==8.2.3
+tiktoken==0.5.2
+tokenizers==0.19.1
+tomlkit==0.12.0
+toolz==0.12.1
+torch==2.2.2
+tqdm==4.66.2
+transformers==4.40.0
+triton==2.2.0
+typer==0.12.3
+types-cffi==1.16.0.20240331
+types-pyOpenSSL==24.0.0.20240417
+types-pytz==2024.1.0.20240417
+types-redis==4.6.0.20240417
+types-setuptools==69.5.0.20240415
+types-tqdm==4.66.0.20240417
+typing-inspect==0.9.0
+typing_extensions==4.11.0
+tzdata==2024.1
+urllib3==2.2.1
+uvicorn==0.23.2
+websockets==11.0.3
+xxhash==3.4.1
+yarl==1.9.4
+zipp==3.18.1
sotopia_pi_generate.py
CHANGED
@@ -1,17 +1,18 @@
 import re
+from typing import TypeVar
+from functools import cache
+import logging

 import torch
-from
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
 )
-
+from peft import PeftModel
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_community.chat_models import ChatLiteLLM
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-
 from langchain.chains import LLMChain
 from langchain.output_parsers import PydanticOutputParser
 from langchain.prompts import (
@@ -20,17 +21,16 @@ from langchain.prompts import (
     PromptTemplate,
 )
 from langchain.schema import BaseOutputParser, OutputParserException
-from
+from message_classes import ActionType, AgentAction
+from utils import format_docstring

-from
-from sotopia.utils import format_docstring
-from functools import cache
-import logging
+from langchain_callback_handler import LoggingCallbackHandler

-
+HF_TOKEN_KEY_FILE="./hf_token.key"

+OutputType = TypeVar("OutputType", bound=object)
 log = logging.getLogger("generate")
-
+logging_handler = LoggingCallbackHandler("langchain")

 def generate_action(
     model_name: str,
@@ -39,7 +39,7 @@ def generate_action(
     action_types: list[ActionType],
     agent: str,
     temperature: float = 0.7,
-) ->
+) -> AgentAction:
     """
     Using langchain to generate an example episode
     """
@@ -73,14 +73,26 @@ def generate_action(
             temperature=temperature,
         )
     except Exception:
-        return AgentAction(action_type="none", argument="")
+        return AgentAction(action_type="none", argument="")

 @cache
-def prepare_model(model_name):
+def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
     compute_type = torch.float16
+    with open (hf_token_key_file, 'r') as f:
+        hf_token = f.read().strip()

-    if 'cmu-lti/sotopia-pi-mistral-7b-BC_SR'
-    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=
+    if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            token=hf_token
+        )
+        model = PeftModel.from_pretrained(model, model_name).to("cuda")
+
+    elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
         model = AutoModelForCausalLM.from_pretrained(
             "mistralai/Mistral-7B-Instruct-v0.1",
             cache_dir="./.cache",
@@ -91,11 +103,22 @@ def prepare_model(model_name):
             bnb_4bit_quant_type="nf4",
             bnb_4bit_compute_dtype=compute_type,
         ),
-        token=
+        token=hf_token
     )
     model = PeftModel.from_pretrained(model, model_name).to("cuda")
+
+    elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
+        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
+        model = AutoModelForCausalLM.from_pretrained(
+            "mistralai/Mistral-7B-Instruct-v0.1",
+            cache_dir="./.cache",
+            device_map='cuda',
+            token=hf_token
+        )
+
     else:
         raise RuntimeError(f"Model {model_name} not supported")
+
     return model, tokenizer

 def obtain_chain_hf(
@@ -111,9 +134,17 @@ def obtain_chain_hf(
     )
     chat_prompt_template = ChatPromptTemplate.from_messages([human_message_prompt])
     model, tokenizer = prepare_model(model_name)
-    pipe = pipeline("text-generation",
+    pipe = pipeline("text-generation",
+                    model=model,
+                    tokenizer=tokenizer,
+                    max_new_tokens=100,
+                    temperature=temperature,
+                    return_full_text=False,
+                    do_sample=True,
+                    num_beams=3,
+                    length_penalty=-1.0,
+                    )
     hf = HuggingFacePipeline(pipeline=pipe)
-    # import pdb; pdb.set_trace()
     chain = LLMChain(llm=hf, prompt=chat_prompt_template)
     return chain

@@ -123,7 +154,7 @@ def generate(
     input_values: dict[str, str],
     output_parser: BaseOutputParser[OutputType],
     temperature: float = 0.7,
-) ->
+) -> OutputType:
     # import pdb; pdb.set_trace()
     input_variables = re.findall(r"{(.*?)}", template)
     assert (
@@ -135,8 +166,9 @@ def generate(
     chain = obtain_chain(model_name, template, input_variables, temperature)
     if "format_instructions" not in input_values:
         input_values["format_instructions"] = output_parser.get_format_instructions()
-    result = chain.predict([], **input_values)
-
+    result = chain.predict([logging_handler], **input_values)
+    prompt = logging_handler.retrive_prompt()
+    import pdb; pdb.set_trace()
     try:
         parsed_result = output_parser.parse(result)
     except KeyboardInterrupt:
@@ -146,6 +178,7 @@ def generate(
         f"[red] Failed to parse result: {result}\nEncounter Exception {e}\nstart to reparse",
         extra={"markup": True},
     )
+    import pdb; pdb.set_trace()
     reformat_parsed_result = format_bad_output(
         result, format_instructions=output_parser.get_format_instructions()
     )
@@ -175,7 +208,7 @@ def format_bad_output(
         "ill_formed_output": ill_formed_output,
         "format_instructions": format_instructions,
     }
-    reformat = chain.predict([], **input_values)
+    reformat = chain.predict([logging_handler], **input_values)
     log.info(f"Reformated output: {reformat}")
     return reformat

@@ -189,7 +222,7 @@ def obtain_chain(
     """
     Using langchain to sample profiles for participants
     """
-    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR"]:
+    if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"]:
         return obtain_chain_hf(
             model_name=model_name,
             template=template,
@@ -212,32 +245,6 @@ def obtain_chain(
     chain = LLMChain(llm=chat, prompt=chat_prompt_template)
     return chain

-def format_bad_output(
-    ill_formed_output: str,
-    format_instructions: str,
-    model_name: str = "gpt-3.5-turbo",
-) -> str:
-    template = """
-    Given the string that can not be parsed by json parser, reformat it to a string that can be parsed by json parser.
-    Original string: {ill_formed_output}
-
-    Format instructions: {format_instructions}
-
-    Please only generate the JSON:
-    """
-    chain = obtain_chain(
-        model_name=model_name,
-        template=template,
-        input_variables=re.findall(r"{(.*?)}", template),
-    )
-    input_values = {
-        "ill_formed_output": ill_formed_output,
-        "format_instructions": format_instructions,
-    }
-    reformat = chain.predict([], **input_values)
-    log.info(f"Reformated output: {reformat}")
-    return reformat
-
 def _return_fixed_model_version(model_name: str) -> str:
     return {
         "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
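For orientation, a minimal sketch (illustrative, not from this commit; assumes a CUDA GPU, a valid hf_token.key in the working directory, and the pinned transformers/langchain-community versions) of how the new pieces compose: prepare_model loads the Mistral base model plus the PEFT adapter, and the transformers pipeline is wrapped for LangChain:

from transformers import pipeline
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from sotopia_pi_generate import prepare_model

# loads the base model 4-bit quantized, then attaches the PEFT adapter
model, tokenizer = prepare_model("cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=100,
    return_full_text=False,  # return only the completion, not the prompt
)
hf = HuggingFacePipeline(pipeline=pipe)
print(hf.invoke('Turn #1: Oliver Thompson said: "Hi!"\nYour turn:'))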
utils.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import List, Tuple
 import ast
+import re

 class Agent:
     def __init__(self, agent_profile):
@@ -31,80 +32,10 @@ class Environment:
         self.agent_goals = env_profile["agent_goals"]
         self.relationship = env_profile["relationship"]

-
-    return """ Your available action types are
-    "none action speak non-verbal communication leave".
-    Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
-
-    Please only generate a JSON string including the action type and the argument.
-    Your action should follow the given format:
-    \nAs an example, for the schema {\"properties\": {\"foo\": {\"title\": \"Foo\", \"description\": \"a list of strings\", \"type\": \"array\", \"items\": {\"type\": \"string\"}}}, \"required\": [\"foo\"]}
-    the object {\"foo\": [\"bar\", \"baz\"]} is a well-formatted instance of the schema. The object {\"properties\": {\"foo\": [\"bar\", \"baz\"]}} is not well-formatted.
-    \nHere is the output schema:\n```\n{\"description\": \"An interface for messages.\\nThere is only one required method: to_natural_language\", \"properties\": {\"action_type\": {\"title\": \"Action Type\", \"description\": \"whether to speak at this turn or choose to not do anything\", \"enum\": [\"none\", \"speak\", \"non-verbal communication\", \"action\", \"leave\"], \"type\": \"string\"}, \"argument\": {\"title\": \"Argument\", \"description\": \"the utterance if choose to speak, the expression or gesture if choose non-verbal communication, or the physical action if choose action\", \"type\": \"string\"}}, \"required\": [\"action_type\", \"argument\"]}\n```\u001b[0m
-    """
-
-def get_starter_prompt(machine_agent, human_agent, environment):
-    return f"Imagine you are {machine_agent.name}, your task is to act/speak as {machine_agent.name} would, keeping in mind {machine_agent.name}'s social goal.\nYou can find {machine_agent.name}'s background and goal in the 'Here is the context of the interaction' field.\nNote that {machine_agent.name}'s secret and goal is only visible to you.\nYou should try your best to achieve {machine_agent.name}'s goal in a way that align with their character traits.\nAdditionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n\nHere is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
-
 def get_context_prompt(machine_agent, human_agent, environment):
     return f"Here is the context of this interaction:\n Scenario: {environment.scenario}\nParticipants: {human_agent.name} and {machine_agent.name}\n{human_agent.name}'s background: {human_agent.background} Personality and values description: {human_agent.personality} \n{machine_agent.name}'s background: {machine_agent.background} Personality and values description: {machine_agent.personality} {machine_agent.name}'s secrets: {machine_agent.secret}\n{human_agent.name}'s goal: Unknown\n{machine_agent.name}'s goal: {environment.agent_goals[1]}\nConversation Starts:"
-
-
-# we define history as
-# [(user_message, bot_message), (user_message, bot_message)]
-
-# we define dialogue history as
-# user_name: user_message\nbot_name: bot_message\nuser_name: user_message\nbot_name: bot_message\n
-
-
-def dialogue_history_length_check(string, max_token, tokenizer):
-    prompt_tokens = len(tokenizer(string)["input_ids"])
-    return max(prompt_tokens - max_token, 0)
-
-
-def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
-    dia_sen = dia_his.split("\n")
-    remove_len = 0
-    i = 0
-    while remove_len < surpass_num:
-        remove_len += len(tokenizer(dia_sen[i])["input_ids"])
-        i += 1
-    trunc_dia = "\n".join(p for p in dia_sen[i:])
-    return trunc_dia
-
-
-def format_bot_message(bot_message) -> str:
-    # # import pdb; pdb.set_trace()
-    start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
-    if end_idx == -1:
-        bot_message += "'}"
-        end_idx = len(bot_message)
-    json_response = ast.literal_eval(bot_message[start_idx:end_idx+1])
-    match json_response["action_type"]:
-        case "none":
-            return 'did nothing'
-        case "speak":
-            return json_response["argument"]
-        case "non-verbal communication":
-            return f'[{json_response["action_type"]}] {json_response["argument"]}'
-        case "action":
-            return f'[{json_response["action_type"]}] {json_response["argument"]}'
-        case "leave":
-            return 'left the conversation'
-
-def dialogue_history_creation(history, user_name, bot_name):
-    dialogue_history = ""
-    for idx, turn in enumerate(history):
-        user_message, bot_message = turn
-        # TODOTODO (haofeiyu): we first assume that human talks first
-        user_turn_idx = idx * 2
-        bot_turn_idx = idx * 2 + 1
-        if not bot_message.startswith("["):  # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
-            bot_message = "said :" + bot_message
-        dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_name}: {bot_message}"
-    last_turn_idx = len(history) * 2
-    return dialogue_history, last_turn_idx
-
 def dialogue_history_prompt(message, history, user_agent, bot_agent):
     dialogue_history = ""
     for idx, turn in enumerate(history):
@@ -117,31 +48,8 @@ def dialogue_history_prompt(message, history, user_agent, bot_agent):
     dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_agent.name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"
     last_turn_idx = len(history) * 2
     dialogue_history = f"{dialogue_history}\n\nTurn #{last_turn_idx+1}: {user_agent.name}: {message}\n."
-    return dialogue_history, last_turn_idx+2
-
-
-def dialogue_history_truncation(dialogue_history, max_token_num, tokenizer):
-    surpass_num = dialogue_history_length_check(
-        dialogue_history, max_token_num, tokenizer
-    )
-    if surpass_num > 0:
-        dialogue_history = truncate_dialogue_history_to_length(
-            dialogue_history, surpass_num, tokenizer
-        )
-    return dialogue_history
-
+    return dialogue_history, last_turn_idx + 2

-def
-
-    instructions: str,
-    user_name: str,
-    bot_name: str,
-) -> str:
-    prompt = instructions.strip()
-    dialogue_history, last_turn_idx = dialogue_history_creation(
-        history, user_name, bot_name
-    )
-    prompt = f"{prompt}\n{dialogue_history}"
-    prompt = f"{prompt}\n\nTurn #{last_turn_idx+1}: {user_name}: {message}\n.\nYou are at Turn #{last_turn_idx+2}."
-    return prompt
+
+def format_docstring(docstring: str) -> str:
+    """Format a docstring for use in a prompt template."""
+    return re.sub("\n +", "\n", docstring).strip()
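As a small illustration (not part of the commit; the sample string is hypothetical), the new format_docstring helper collapses the indentation that the f-string templates in message_classes.py pick up:

from utils import format_docstring

raw = """Here is the context of this interaction:
    Scenario: Two friends meet at a park.
    Participants: A and B"""
print(format_docstring(raw))
# Here is the context of this interaction:
# Scenario: Two friends meet at a park.
# Participants: A and B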