# Agent.py -- uploaded by dj86 (commit 5827a20, verified; 3.93 kB)
#import openai
import os
import backoff
import time
import random
import traceback
#from openai.error import RateLimitError, APIError, ServiceUnavailableError, APIConnectionError
#from .openai_utils import OutOfQuotaException, AccessTerminatedException
#from .openai_utils import num_tokens_from_string, model2max_context
from together import Together
class Agent:
    """Conversational agent backed by a Together chat-completion model.

    The agent accumulates a chat history in ``memory_lst`` (a list of
    ``{"role": ..., "content": ...}`` dicts) and sends that history to the
    model whenever ``ask`` is called.
    """

    # Completion budget used by ``ask``. Every request was previously capped
    # at 512 tokens, so this keeps the effective limit unchanged.
    DEFAULT_MAX_TOKENS = 512

    def __init__(self, model_name: str, name: str, temperature: float, sleep_time: float=0) -> None:
        """Create an agent.

        Args:
            model_name (str): model name passed to the chat-completion API
            name (str): name of this agent (used when printing its memory)
            temperature (float): higher values make the output more random,
                while lower values make it more focused and deterministic
            sleep_time (float): seconds to sleep before each query, to stay
                under rate limits
        """
        self.model_name = model_name
        self.name = name
        self.temperature = temperature
        self.memory_lst = []
        self.sleep_time = sleep_time
        # The API key is read from the TOGETHER_API_KEY environment variable.
        self.client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))

    def query(self, messages: "list[dict]", max_tokens: int, temperature: float) -> "str | None":
        """Send a chat history to the model and return its reply.

        Args:
            messages (list[dict]): chat history as role/content dicts
            max_tokens (int): maximum number of tokens to generate
            temperature (float): sampling temperature

        Returns:
            str | None: the reply with every triple-backtick fence marker
            removed, or ``None`` if the request failed (the error and its
            traceback are printed rather than raised).
        """
        time.sleep(self.sleep_time)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                # Honor the caller's arguments instead of fixed values.
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.7,
                top_k=50,
                repetition_penalty=1,
                stop=["<|im_start|>", "<|im_end|>"],
                stream=False,
            )
            gen = response.choices[0].message.content
            # Strip every ``` fence marker from the reply.
            return gen.replace('```', '')
        except Exception as e:
            # Best-effort: report the failure and let the caller see None.
            print(f"An error occurred: {e}")
            traceback.print_exc()  # print the full error stack trace
            return None

    def set_meta_prompt(self, meta_prompt: str):
        """Append the meta prompt to the memory as a system message.

        Args:
            meta_prompt (str): the meta prompt
        """
        self.memory_lst.append({"role": "system", "content": f"{meta_prompt}"})

    def add_event(self, event: str):
        """Append a new event to the memory as a user message.

        Args:
            event (str): string that describes the event.
        """
        self.memory_lst.append({"role": "user", "content": f"{event}"})

    def add_memory(self, memory: str):
        """Append a monologue to the memory as an assistant message and print it.

        Args:
            memory (str): string generated by the model in the last round.
        """
        self.memory_lst.append({"role": "assistant", "content": f"{memory}"})
        print(f"----- {self.name} -----\n{memory}\n")

    def ask(self, temperature: "float | None" = None) -> "str | None":
        """Query the model with the whole memory and return its answer.

        Args:
            temperature (float | None): sampling temperature for this call;
                when ``None`` the agent's own temperature is used. An explicit
                ``0.0`` is respected.

        Returns:
            str | None: the model's reply, or ``None`` if the query failed.
        """
        # Compare against None so that a deliberate 0.0 is not discarded.
        effective_temperature = self.temperature if temperature is None else temperature
        return self.query(self.memory_lst, self.DEFAULT_MAX_TOKENS, temperature=effective_temperature)