File size: 5,402 Bytes
b0f4c90
 
 
 
 
 
 
 
 
96c260d
 
 
 
b0f4c90
 
 
 
10ec55c
b0f4c90
 
 
10ec55c
 
 
 
 
 
b0f4c90
 
10ec55c
 
b0f4c90
10ec55c
b0f4c90
 
10ec55c
 
 
 
b0f4c90
10ec55c
b0f4c90
7f85cc2
b0f4c90
10ec55c
b0f4c90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10ec55c
 
b0f4c90
 
 
 
 
 
 
 
 
 
 
 
96c260d
7294bbc
ea7e28a
 
7294bbc
 
ea7e28a
7294bbc
 
 
 
 
 
 
 
 
 
 
96c260d
b0f4c90
 
 
ea7e28a
b0f4c90
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from agent_build_sdk.builder import AgentBuilder
from agent_build_sdk.model.model import AgentResp, AgentReq, STATUS_DISTRIBUTION, STATUS_ROUND, STATUS_VOTE, \
    STATUS_START, STATUS_VOTE_RESULT, STATUS_RESULT
from agent_build_sdk.sdk.agent import BasicAgent
from agent_build_sdk.sdk.agent import format_prompt

from prompts import DESC_PROMPT, VOTE_PROMPT
from agent_build_sdk.utils.logger import logger

from openai import OpenAI
import os


class SpyAgent(BasicAgent):
    """Player agent for the 'Who is the Spy?' game.

    ``perceive`` appends every incoming game event to ``self.memory``;
    ``interact`` renders a prompt from that memory and asks the LLM for either
    a description (speech round) or a vote.
    """

    def perceive(self, req=AgentReq):
        """Record one game event in the agent's memory.

        Args:
            req: The incoming request; ``req.status`` selects how the event is
                stored (game start, word distribution, speech, vote, vote
                result, final result).

        Raises:
            NotImplementedError: If ``req.status`` is not one of the handled
                statuses.
        """
        logger.info("spy perceive: {}".format(req))
        if req.status == STATUS_START:  # Start a new game
            self.memory.clear()
            self.memory.set_variable("name", req.message)
            # Adjacent literals are concatenated into a single history entry.
            self.memory.append_history(
                "Host: Ladies and gentlemen, welcome to the game of 'Who is the Spy?' We have a group of 6 players, among whom one is a spy. Each person will receive a card. 5 of these cards will have the same word, while the spy will receive a different word."
                "Once you have your word, take some time to think about how to cleverly describe it without revealing it. Each person will use one sentence to describe their word in each round, and no one can repeat descriptions. The word itself cannot be mentioned."
                "After each round of descriptions, everyone present votes to identify the person they suspect is the spy. The person with the most votes is eliminated. If the spy is eliminated, the game ends; if not, the game continues."
                "You need to judge whether you are the spy based on the context. If you are the spy, you should try to confuse others and avoid being voted out. If you are not the spy, you should ensure that the spy remains undetected while providing hints to teammates."
            )
        elif req.status == STATUS_DISTRIBUTION:  # Assign words
            self.memory.set_variable("word", req.word)
            self.memory.append_history(
                'Host: Hello,{}. The word assigned to you is {}'.format(self.memory.load_variable("name"), req.word))
        elif req.status == STATUS_ROUND:  # Speech
            if req.name:
                # from other players
                self.memory.append_history(req.name + ': ' + req.message)
            else:
                # from the host
                self.memory.append_history('Host: Now entering round {}.'.format(str(req.round)))
                self.memory.append_history('Host: Each player describes the word they have been assigned.')
        elif req.status == STATUS_VOTE:  # Vote
            self.memory.append_history(req.name + ': ' + req.message)
        elif req.status == STATUS_VOTE_RESULT:
            if req.name:
                self.memory.append_history('Host: The voting results are: {}.'.format(req.message))
            else:
                self.memory.append_history('Host: No one is out.')
        elif req.status == STATUS_RESULT:
            self.memory.append_history(req.message)
        else:
            raise NotImplementedError("unsupported perceive status: {}".format(req.status))

    def interact(self, req=AgentReq) -> AgentResp:
        """Produce this agent's action for the current game phase.

        Args:
            req: The incoming request. For ``STATUS_ROUND`` a description is
                generated; for ``STATUS_VOTE`` ``req.message`` is read as a
                comma-separated list of candidate names and a vote is
                generated.

        Returns:
            An ``AgentResp`` carrying the LLM reply, or ``success=False`` with
            an ``errMsg`` when the LLM produced no usable reply.

        Raises:
            NotImplementedError: If ``req.status`` is neither ``STATUS_ROUND``
                nor ``STATUS_VOTE``.
        """
        logger.info("spy interact: {}".format(req))
        if req.status == STATUS_ROUND:
            prompt = format_prompt(DESC_PROMPT,
                                   {"name": self.memory.load_variable("name"),
                                    "word": self.memory.load_variable("word"),
                                    "history": "\n".join(self.memory.load_history())
                                    })
            return self._respond(prompt)

        if req.status == STATUS_VOTE:
            self.memory.append_history("Host: It's time for voting.")
            own_name = self.memory.load_variable("name")
            # Strip each entry so "a, b" style lists still match the stored
            # name, and drop empty fragments; never offer ourselves as a target.
            candidates = (part.strip() for part in (req.message or "").split(","))
            choices = [name for name in candidates if name and name != own_name]
            self.memory.set_variable("choices", choices)
            prompt = format_prompt(VOTE_PROMPT, {"name": own_name,
                                                 "choices": choices,
                                                 "history": "\n".join(self.memory.load_history())
                                                 })
            return self._respond(prompt, choices=choices)

        raise NotImplementedError("unsupported interact status: {}".format(req.status))

    def _respond(self, prompt, choices=None):
        """Send ``prompt`` to the LLM and wrap the reply in an ``AgentResp``.

        When ``choices`` is given, the reply is reconciled with that list via
        ``_match_choice`` before being returned.
        """
        logger.info("prompt:" + prompt)
        result = self.llm_caller(prompt)
        if result is None:
            # Do not report success for a reply that does not exist.
            logger.error("spy interact failed: no reply from LLM")
            return AgentResp(success=False, result=None,
                             errMsg="LLM call failed or returned no content")
        if choices:
            result = self._match_choice(result, choices)
        logger.info("spy interact result: {}".format(result))
        return AgentResp(success=True, result=result, errMsg=None)

    @staticmethod
    def _match_choice(reply, choices):
        """Map a free-form LLM reply onto one of ``choices`` when unambiguous.

        Returns the matching choice if the trimmed reply equals one, or if
        exactly one choice occurs inside the reply; otherwise returns ``reply``
        unchanged so no vote is ever invented.
        """
        cleaned = reply.strip()
        if cleaned in choices:
            return cleaned
        mentioned = [choice for choice in choices if choice in cleaned]
        if len(mentioned) == 1:
            return mentioned[0]
        return reply

    def llm_caller(self, prompt):
        """Ask the chat-completion endpoint for a reply to ``prompt``.

        The client is configured from the ``API_KEY`` and ``BASE_URL``
        environment variables and uses ``self.model_name``.

        Returns:
            The content of the first completion choice, or ``None`` if the
            request or the response handling raised.
        """
        try:
            # The request itself is the step most likely to fail (network,
            # auth, rate limit), so it must sit inside the guarded region.
            client = OpenAI(
                api_key=os.getenv('API_KEY'),
                base_url=os.getenv('BASE_URL')
            )
            completion = client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {'role': 'system', 'content': 'You are a helpful assistant.'},
                    {'role': 'user', 'content': prompt}
                ],
                temperature=0
            )
            return completion.choices[0].message.content
        except Exception:
            # Best-effort boundary: log with traceback and signal failure to
            # the caller with None instead of printing to stdout.
            logger.exception("spy llm call failed")
            return None


if __name__ == '__main__':
    # Script entry point: build the spy agent with the model named by the
    # MODEL_NAME environment variable and start it through AgentBuilder.
    agent_name = 'spy'
    model = os.getenv('MODEL_NAME')
    spy_agent = SpyAgent(agent_name, model_name=model)
    AgentBuilder(agent_name, agent=spy_agent).start()