Upload Agent.py
Agent.py
ADDED
import os
import time
import traceback

from together import Together


class Agent:
    def __init__(self, model_name: str, name: str, temperature: float, sleep_time: float = 0) -> None:
        """Create an agent.

        Args:
            model_name (str): model name
            name (str): name of this agent
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            sleep_time (float): seconds to sleep between calls, to stay under rate limits
        """
        self.model_name = model_name
        self.name = name
        self.temperature = temperature
        self.memory_lst = []
        self.sleep_time = sleep_time
        self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    def query(self, messages: "list[dict]", max_tokens: int, temperature: float) -> str:
        """Make a query.

        Args:
            messages (list[dict]): chat history in OpenAI chat format
            max_tokens (int): maximum number of tokens to generate
            temperature (float): sampling temperature

        Returns:
            str: the generated message
        """
        time.sleep(self.sleep_time)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.7,
                top_k=50,
                repetition_penalty=1,
                stop=["<|im_start|>", "<|im_end|>"],
                stream=False,
            )
            gen = response.choices[0].message.content
            # Strip any ``` code fences from the reply
            cleaned_text = gen.replace("```", "")
            return cleaned_text
        except Exception as e:
            print(f"An error occurred: {e}")
            traceback.print_exc()  # print the full error traceback
            return ""

    def set_meta_prompt(self, meta_prompt: str):
        """Set the meta prompt.

        Args:
            meta_prompt (str): the meta prompt
        """
        self.memory_lst.append({"role": "system", "content": f"{meta_prompt}"})

    def add_event(self, event: str):
        """Add a new event to the memory.

        Args:
            event (str): string that describes the event.
        """
        self.memory_lst.append({"role": "user", "content": f"{event}"})

    def add_memory(self, memory: str):
        """Record the model's reply from the last round in the memory.

        Args:
            memory (str): string generated by the model in the last round.
        """
        self.memory_lst.append({"role": "assistant", "content": f"{memory}"})
        print(f"----- {self.name} -----\n{memory}\n")

    def ask(self, temperature: float = None):
        """Query for an answer using the accumulated memory."""
        # Cap generation at 512 tokens, matching the limit used by the API call above.
        return self.query(self.memory_lst, 512, temperature=temperature if temperature else self.temperature)
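For reference, a minimal usage sketch of this class (not part of the uploaded file). It assumes TOGETHER_API_KEY is exported in the environment; the model name, prompts, and file name below are placeholders, not necessarily what this Space actually uses.

# usage_example.py: hypothetical driver script, not part of Agent.py
from Agent import Agent

# Placeholder model name; substitute the Together-hosted model this Space is configured for.
agent = Agent(model_name="meta-llama/Llama-3-8b-chat-hf", name="assistant", temperature=0.7, sleep_time=1.0)
agent.set_meta_prompt("You are a helpful assistant.")
agent.add_event("Summarize the rules of tic-tac-toe in two sentences.")
reply = agent.ask()       # calls the Together API with the accumulated chat history
agent.add_memory(reply)   # record the reply so the next turn can see it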