dj86 commited on
Commit
5827a20
1 Parent(s): 9b2ec19

Upload Agent.py

Browse files
Files changed (1) hide show
  1. Agent.py +110 -0
Agent.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import openai
2
+ import os
3
+ import backoff
4
+ import time
5
+ import random
6
+ import traceback
7
+ #from openai.error import RateLimitError, APIError, ServiceUnavailableError, APIConnectionError
8
+ #from .openai_utils import OutOfQuotaException, AccessTerminatedException
9
+ #from .openai_utils import num_tokens_from_string, model2max_context
10
+ from together import Together
11
+
12
+
13
+ class Agent:
14
+ def __init__(self, model_name: str, name: str, temperature: float, sleep_time: float=0) -> None:
15
+ """Create an agent
16
+
17
+ Args:
18
+ model_name(str): model name
19
+ name (str): name of this agent
20
+ temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
21
+ sleep_time (float): sleep because of rate limits
22
+ """
23
+ self.model_name = model_name
24
+ self.name = name
25
+ self.temperature = temperature
26
+ self.memory_lst = []
27
+ self.sleep_time = sleep_time
28
+ self.client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))
29
+
30
+ def query(self, messages: "list[dict]", max_tokens: int, temperature: float) -> str:
31
+ """make a query
32
+
33
+ Args:
34
+ messages (list[dict]): chat history in turbo format
35
+ max_tokens (int): max token in api call
36
+ api_key (str): openai api key
37
+ temperature (float): sampling temperature
38
+
39
+ Raises:
40
+ OutOfQuotaException: the apikey has out of quota
41
+ AccessTerminatedException: the apikey has been ban
42
+
43
+ Returns:
44
+ str: the return msg
45
+ """
46
+ time.sleep(self.sleep_time)
47
+ try:
48
+ #response = openai.ChatCompletion.create(
49
+ # model=self.model_name,
50
+ # messages=messages,
51
+ # temperature=temperature,
52
+ # max_tokens=max_tokens,
53
+ # api_key=api_key,
54
+ #)
55
+ #gen = response['choices'][0]['message']['content']
56
+ response = self.client.chat.completions.create(
57
+ model=self.model_name,
58
+ messages=messages,
59
+ max_tokens=512,
60
+ temperature=0.7,
61
+ top_p=0.7,
62
+ top_k=50,
63
+ repetition_penalty=1,
64
+ stop=["<|im_start|>","<|im_end|>"],
65
+ stream=False
66
+ )
67
+ #print(response.choices[0].message.content)
68
+ gen = response.choices[0].message.content
69
+ # 去除字符串中的所有 ```
70
+ cleaned_text = gen.replace('```', '')
71
+ return cleaned_text
72
+
73
+ except Exception as e:
74
+ print(f"An error occurred: {e}")
75
+ traceback.print_exc() # 打印详细的错误堆栈信息
76
+
77
+ def set_meta_prompt(self, meta_prompt: str):
78
+ """Set the meta_prompt
79
+
80
+ Args:
81
+ meta_prompt (str): the meta prompt
82
+ """
83
+ self.memory_lst.append({"role": "system", "content": f"{meta_prompt}"})
84
+
85
+ def add_event(self, event: str):
86
+ """Add an new event in the memory
87
+
88
+ Args:
89
+ event (str): string that describe the event.
90
+ """
91
+ self.memory_lst.append({"role": "user", "content": f"{event}"})
92
+
93
+ def add_memory(self, memory: str):
94
+ """Monologue in the memory
95
+
96
+ Args:
97
+ memory (str): string that generated by the model in the last round.
98
+ """
99
+ self.memory_lst.append({"role": "assistant", "content": f"{memory}"})
100
+ print(f"----- {self.name} -----\n{memory}\n")
101
+
102
+ def ask(self, temperature: float=None):
103
+ """Query for answer
104
+
105
+ Args:
106
+ """
107
+ # query
108
+ #num_context_token = sum([num_tokens_from_string(m["content"], self.model_name) for m in self.memory_lst])
109
+ #max_token = model2max_context - num_context_token
110
+ return self.query(self.memory_lst, 100, temperature=temperature if temperature else self.temperature)