import os
import time
import traceback

import backoff
from together import Together
class Agent:
    def __init__(self, model_name: str, name: str, temperature: float, sleep_time: float = 0) -> None:
        """Create an agent.

        Args:
            model_name (str): name of the model to query
            name (str): name of this agent
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            sleep_time (float): seconds to sleep before each API call, to stay under rate limits
        """
        self.model_name = model_name
        self.name = name
        self.temperature = temperature
        self.memory_lst = []
        self.sleep_time = sleep_time
        self.client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
    def query(self, messages: "list[dict]", max_tokens: int, temperature: float) -> str:
        """Make a query against the Together chat completions API.

        Args:
            messages (list[dict]): chat history in OpenAI chat format
            max_tokens (int): maximum number of tokens to generate
            temperature (float): sampling temperature

        Returns:
            str: the generated message, or None if the call failed
        """
        time.sleep(self.sleep_time)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,    # use the caller-supplied limit instead of a hard-coded 512
                temperature=temperature,  # use the caller-supplied temperature instead of a hard-coded 0.7
                top_p=0.7,
                top_k=50,
                repetition_penalty=1,
                stop=["<|im_start|>", "<|im_end|>"],
                stream=False,
            )
            gen = response.choices[0].message.content
            # Strip all ``` fences from the generated string
            cleaned_text = gen.replace("```", "")
            return cleaned_text
        except Exception as e:
            print(f"An error occurred: {e}")
            traceback.print_exc()  # print the full error stack trace
            return None
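    # The module imports `backoff` but never uses it. Below is a minimal sketch of
    # how the call above could be retried with exponential backoff; it is an
    # illustration, not part of the original logic. Catching the broad `Exception`
    # is an assumption -- narrowing it to the Together SDK's rate-limit error would
    # be tighter, but the exact exception class is not confirmed by this file.
    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def _query_with_backoff(self, messages: "list[dict]", max_tokens: int, temperature: float) -> str:
        """Hypothetical retrying variant of query(); retries up to 5 times on any error."""
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return response.choices[0].message.content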
    def set_meta_prompt(self, meta_prompt: str):
        """Set the meta prompt.

        Args:
            meta_prompt (str): the meta prompt
        """
        self.memory_lst.append({"role": "system", "content": f"{meta_prompt}"})
    def add_event(self, event: str):
        """Add a new event to the memory.

        Args:
            event (str): string that describes the event
        """
        self.memory_lst.append({"role": "user", "content": f"{event}"})
    def add_memory(self, memory: str):
        """Append the model's own utterance to the memory.

        Args:
            memory (str): string generated by the model in the last round
        """
        self.memory_lst.append({"role": "assistant", "content": f"{memory}"})
        print(f"----- {self.name} -----\n{memory}\n")
    def ask(self, temperature: float = None):
        """Query the model for an answer based on the accumulated memory.

        Args:
            temperature (float, optional): overrides the agent's default temperature when given
        """
        # `is not None` so an explicit temperature of 0.0 is not silently replaced by the default
        return self.query(
            self.memory_lst,
            100,
            temperature=temperature if temperature is not None else self.temperature,
        )
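

# A minimal usage sketch, assuming TOGETHER_API_KEY is set in the environment and
# that the model named below is available on Together; both are illustrative
# assumptions, not part of the original file.
if __name__ == "__main__":
    agent = Agent(
        model_name="meta-llama/Llama-3-8b-chat-hf",  # hypothetical model choice
        name="debater",
        temperature=0.7,
        sleep_time=1.0,
    )
    agent.set_meta_prompt("You are a concise assistant.")
    agent.add_event("Summarize the rules of tic-tac-toe in one sentence.")
    reply = agent.ask()
    if reply is not None:
        agent.add_memory(reply)  # store and print the model's answer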