import os
import time
import traceback

from together import Together


class Agent:
    def __init__(self, model_name: str, name: str, temperature: float, sleep_time: float = 0) -> None:
        """Create an agent.

        Args:
            model_name (str): name of the model served by the Together API
            name (str): name of this agent
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            sleep_time (float): seconds to sleep before each query, to stay under rate limits
        """
        self.model_name = model_name
        self.name = name
        self.temperature = temperature
        self.memory_lst = []
        self.sleep_time = sleep_time
        self.client = Together(api_key=os.environ.get('TOGETHER_API_KEY'))

    def query(self, messages: "list[dict]", max_tokens: int, temperature: float) -> str:
        """Make a query against the Together chat-completions API.

        Args:
            messages (list[dict]): chat history in OpenAI chat format
            max_tokens (int): maximum number of tokens to generate
            temperature (float): sampling temperature

        Returns:
            str: the generated message, with any ``` fences stripped
        """
        time.sleep(self.sleep_time)
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=0.7,
                top_k=50,
                repetition_penalty=1,
                stop=["<|im_start|>", "<|im_end|>"],
                stream=False,
            )
            gen = response.choices[0].message.content
            # Strip all ``` fences from the string
            cleaned_text = gen.replace('```', '')
            return cleaned_text

        except Exception as e:
            print(f"An error occurred: {e}")
            traceback.print_exc()  # print the full error stack trace
            raise
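
    # A possible hardening step (a sketch, not part of the original code):
    # the `backoff` package could retry `query` with exponential backoff on
    # transient API errors, instead of relying only on the fixed `sleep_time`:
    #
    #     import backoff
    #
    #     @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    #     def query(self, messages, max_tokens, temperature) -> str:
    #         ...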

    def set_meta_prompt(self, meta_prompt: str):
        """Set the meta prompt as the system message.

        Args:
            meta_prompt (str): the meta prompt
        """
        self.memory_lst.append({"role": "system", "content": meta_prompt})

    def add_event(self, event: str):
        """Add a new event to the memory.

        Args:
            event (str): string that describes the event
        """
        self.memory_lst.append({"role": "user", "content": event})

    def add_memory(self, memory: str):
        """Record the agent's own reply in the memory.

        Args:
            memory (str): string generated by the model in the last round
        """
        self.memory_lst.append({"role": "assistant", "content": memory})
        print(f"----- {self.name} -----\n{memory}\n")

    def ask(self, temperature: float = None):
        """Query the model for an answer, using the full memory as context.

        Args:
            temperature (float): sampling temperature; the agent's default
                temperature is used when this is None
        """
        return self.query(self.memory_lst, 100,
                          temperature=temperature if temperature is not None else self.temperature)
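

# A minimal usage sketch (an illustration, not part of the original file):
# drives one Agent through a system prompt, a user event, and a reply.
# The model name below is an assumed example; substitute any chat model
# available on your Together account, and export TOGETHER_API_KEY first.
if __name__ == "__main__":
    agent = Agent(
        model_name="meta-llama/Llama-3-8b-chat-hf",  # assumed example model id
        name="Assistant",
        temperature=0.7,
        sleep_time=1.0,
    )
    agent.set_meta_prompt("You are a concise, factual assistant.")
    agent.add_event("Explain in two sentences why rate limiting matters.")
    reply = agent.ask()
    agent.add_memory(reply)  # store the reply so the next turn has context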