dj86 committed on
Commit
62a001a
1 Parent(s): 6570a5e

Upload 3 files

Browse files
Files changed (3) hide show
  1. utils.py +54 -0
  2. vlog4chat.py +327 -0
  3. vlog4debate.py +235 -0
utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pdb
4
+ import logging
5
+ import subprocess
6
+
7
def format_time(seconds):
    """Format a duration given in seconds as ``H:MM:SS``.

    Args:
        seconds: Duration in seconds. Fractional values are truncated to
            whole seconds first, because the ``02d`` format spec used below
            rejects floats with a ``ValueError``.

    Returns:
        str: Hours (unpadded), then minutes and seconds zero-padded to two
        digits, separated by colons.
    """
    whole_seconds = int(seconds)
    total_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"
12
+
13
def extract_info_from_url(url):
    """Derive a short identifier from a video or page URL.

    Bilibili video links yield their ``BV...`` id; any other address yields
    its last path segment (an optional trailing slash is ignored).

    Args:
        url: The address to inspect.

    Returns:
        str: The extracted identifier, or a fixed error message string when
        the expected pattern is not present in ``url``.
    """
    is_bilibili = 'bilibili.com/video/' in url
    if is_bilibili:
        # Bilibili video link: capture the BV id that follows /video/.
        pattern, not_found = r'/video/(BV\w+)', "BV ID not found!"
    else:
        # Blog post or similar: capture the final path segment.
        pattern, not_found = r'/([^/]+)/?$', "URL address is wired!"

    found = re.search(pattern, url)
    return found.group(1) if found else not_found
28
+
29
def download_video(url, save_dir='./examples', size=768):
    """Download a video with ``you-get`` unless it is already on disk.

    Args:
        url: Address of the video page; handed to ``you-get`` unchanged.
        save_dir: Directory ``you-get`` is told to write into.
        size: Not used by the ``you-get`` command; kept so existing callers
            that pass it keep working.

    Returns:
        str | None: ``{save_dir}/{filename}.mp4``, where ``filename`` comes
        from ``extract_info_from_url``. ``None`` when ``you-get`` could not
        be started. The exit status of ``you-get`` is not inspected, so the
        returned path is not a guarantee that the file exists.
    """
    filename = extract_info_from_url(url)
    save_path = f'{save_dir}/{filename}.mp4'
    if os.path.exists(save_path):
        return save_path

    # Equivalent to: you-get -o <save_dir> -O <filename> <url>
    # An argument list (no shell) keeps '&', '?', spaces and quotes in the
    # URL or filename from being interpreted by a shell.
    cmd = ['you-get', '-o', save_dir, '-O', filename, url]
    try:
        subprocess.call(cmd)
    except OSError:
        # Raised when the you-get executable is missing or not runnable.
        return None
    return save_path
41
+
42
def logger_creator(video_id):
    """Return the module logger, writing to ``./examples/{video_id}.log``.

    ``logging.getLogger(__name__)`` hands back the same logger object on every
    call, so file handlers installed by earlier calls are removed and closed
    first. Otherwise each new video would also be written into the log files
    of all previous videos, and their file handles would stay open.

    Args:
        video_id: Base name (without extension) of the log file.

    Returns:
        logging.Logger: Logger at INFO level with exactly one file handler
        that truncates the target file and writes bare messages.
    """
    # FileHandler fails if the target directory does not exist.
    os.makedirs('./examples', exist_ok=True)

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    for stale in list(logger.handlers):
        if isinstance(stale, logging.FileHandler):
            logger.removeHandler(stale)
            stale.close()

    handler = logging.FileHandler(f'./examples/{video_id}.log', mode='w')
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter('%(message)s'))
    logger.addHandler(handler)
    return logger
52
+
53
if __name__ == "__main__":
    # Manual smoke test. download_video() forwards its argument to you-get
    # as-is, so it has to be a complete URL; the short-link form makes the
    # last path segment (the video id) the output file name.
    download_video('https://youtu.be/outcGtbnMuQ', save_dir='./examples', size=768)
vlog4chat.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import pdb
4
+ import sys
5
+ import time
6
+ import numpy as np
7
+ from PIL import Image
8
+ from transformers import logging
9
+ logging.set_verbosity_error()
10
+
11
+ from models.kts_model import VideoSegmentor
12
+ from models.clip_model import FeatureExtractor
13
+ from models.blip2_model import ImageCaptioner
14
+ from models.grit_model import DenseCaptioner
15
+ from models.whisper_model import AudioTranslator
16
+ from models.gpt_model import LlmReasoner
17
+ from utils.utils import logger_creator, format_time
18
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
19
+
20
+ import together
21
+ import warnings
22
+ from together import Together
23
+ from langchain.prompts import PromptTemplate
24
+ from langchain.llms.base import LLM
25
+ from langchain.chains import ConversationalRetrievalChain
26
+ from langchain_core.output_parsers import StrOutputParser
27
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
28
+ from langchain.llms import HuggingFacePipeline
29
+ from langchain.embeddings import HuggingFaceEmbeddings
30
+ from langchain_community.document_loaders import TextLoader
31
+ from langchain.document_loaders import UnstructuredFileLoader
32
+ from langchain_community.vectorstores import FAISS
33
+ from langchain.utils import get_from_dict_or_env
34
+ from langchain_core.runnables import RunnablePassthrough, RunnableParallel
35
+ from pydantic.v1 import Extra, Field, root_validator
36
+ from typing import Any, Dict, List, Mapping, Optional
37
+ from langchain.memory import ConversationBufferMemory
38
+ from langchain import LLMChain, PromptTemplate
39
+ from paddleocr import PaddleOCR, draw_ocr
40
+
41
+ sys.path.append('/root/autodl-tmp/recognize-anything')
42
+
43
+ from ram.models import ram
44
+ from ram.models import tag2text
45
+ from ram import inference_ram as inference
46
+ #from ram import inference_tag2text as inference
47
+ from ram import get_transform
48
+
49
+ warnings.filterwarnings("ignore", category=UserWarning)
50
# Delimiters that get_prompt() wraps around the instruction and the system
# prompt when it assembles the final template.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = ""

# Answering prompt; expects the {question} and {context} placeholders.
instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question, the document records what people see and hear from this video.
Try to connet these information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""

# Question-rewriting prompt; expects {chat_history} and {question}.
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

# NOTE(review): this runs after the transformers import above; libraries that
# read HF_HOME at import time may not pick it up - confirm.
os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'

# SECURITY(review): an API key is committed to source control here. It should
# be rotated and supplied only through the environment. setdefault() at least
# stops this module from overwriting a key the deployment already provides.
os.environ.setdefault("TOGETHER_API_KEY", "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af")
73
+
74
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Assemble the full prompt template from its parts.

    Args:
        instruction: Instruction text placed after the system section.
        new_system_prompt: System prompt text placed between ``B_SYS`` and
            ``E_SYS``.

    Returns:
        str: ``B_INST``, the wrapped system prompt, the instruction and
        ``E_INST`` concatenated in that order.
    """
    pieces = (B_INST, B_SYS, new_system_prompt, E_SYS, instruction, E_INST)
    return "".join(pieces)
78
+
79
template = get_prompt(instruction, system_prompt)

# The declared variables must match the placeholders that actually occur in
# the template: {chat_history} and {question} come from system_prompt,
# {question} and {context} from instruction.
prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
84
+
85
class TogetherLLM(LLM):
    """LangChain ``LLM`` wrapper that sends prompts to the Together completion API."""

    # Model endpoint to use.
    model: str = "togethercomputer/llama-2-70b-chat"

    # Together API key. Reading it with .get() avoids a KeyError at class
    # definition time when the variable is unset; validate_environment()
    # then resolves the key (presumably raising if none is found - confirm
    # against get_from_dict_or_env).
    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")

    # Sampling temperature.
    temperature: float = 0.7

    # Maximum number of tokens to generate in the completion.
    max_tokens: int = 512

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key from the field values or the environment."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Together endpoint and return the completion text.

        Args:
            prompt: The fully rendered prompt.
            **kwargs: Extra arguments from the caller. A non-empty ``stop``
                entry replaces the default stop sequence; everything else is
                ignored.

        Returns:
            str: The ``text`` of the first choice in the response.
        """
        together.api_key = self.together_api_key
        stop_sequences = kwargs.get("stop") or ["</s>"]
        # NOTE(review): the response layout (top-level 'choices') depends on
        # the installed together SDK version - confirm.
        output = together.Complete.create(
            prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=stop_sequences,
        )
        return output['choices'][0]['text']
135
+
136
class Vlogger4chat:
    """Turn a video into a text log and answer questions about that log.

    ``video2log`` segments the video and writes a per-segment log of captions,
    tags, OCR text and the audio transcript. The log is then indexed in a
    FAISS store and queried through a ``ConversationalRetrievalChain`` by
    ``chat2video``.
    """

    def __init__(self, args):
        """Store the settings, load the models and create the temp directory.

        Args:
            args: Settings object. This class reads ``alpha``, ``beta``,
                ``data_dir``, ``tmp_dir``, ``captioner_base_model``,
                ``audio_translator`` and the ``*_device`` attributes from it.
        """
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        # No trailing comma here: it would silently wrap the path in a tuple.
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        self.init_llm()
        self.init_tag2txt()
        self.history = []
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def init_models(self):
        """Instantiate the OCR, vision, audio, translation and embedding models."""
        print('\033[1;34m' + "Welcome to the our Vlog toolbox...".center(50, '-') + '\033[0m')
        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This may time-consuming, please wait...".center(50, '-') + '\033[0m')
        self.ocr = PaddleOCR(lang='ch')  # need to run only once to download and load model into memory
        self.feature_extractor = FeatureExtractor(self.args)
        self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
        self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
        self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
        self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
        # Translation model (English -> Chinese): tokenizer and weights.
        model_name = 'Helsinki-NLP/opus-mt-en-zh'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'}, encode_kwargs={'normalize_embeddings': True})

    def init_llm(self):
        """Create the Together LLM and make sure the other models are loaded."""
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model="microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')

        if not self.models_flag:
            self.init_models()
            self.models_flag = True

    def init_tag2txt(self):
        """Load the RAM tagging model and its image transform."""
        self.transform = get_transform(image_size=384)
        self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
                             image_size=384,
                             vit='swin_l')
        self.ram_model.eval()
        self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)

    def _index_documents(self, raw_documents):
        """Split the documents, embed them into FAISS and build the chat chain."""
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)

    def exist_videolog(self, video_id):
        """Load and index an existing log for ``video_id`` if there is one.

        Returns:
            bool: True when the log file exists, yielded documents and was
            indexed; False otherwise.
        """
        # Defensive: tolerate a one-element tuple (e.g. from a stray trailing
        # comma at the call site) by taking its first item.
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]
        if isinstance(video_id, tuple):
            video_id = video_id[0]
        log_path = os.path.join(self.data_dir, f"{video_id}.log")

        if not os.path.exists(log_path):
            return False

        raw_documents = UnstructuredFileLoader(log_path).load()
        if not raw_documents:
            print("The log file is empty or could not be loaded.")
            return False
        self._index_documents(raw_documents)
        return True

    def create_videolog(self, video_id):
        """Index the freshly written log of ``video_id`` and build the chat chain."""
        log_path = os.path.join(self.data_dir, video_id + '.log')
        raw_documents = UnstructuredFileLoader(log_path).load()
        self._index_documents(raw_documents)

    def _ocr_text(self, frame):
        """Run OCR on ``frame`` and join every recognised line with spaces."""
        texts = []
        for block in self.ocr.ocr(frame) or []:
            if block is None:
                continue
            for line in block:
                if line is not None:
                    texts.append(line[1][0])  # the recognised text of this line
        return ' '.join(texts)

    def video2log(self, video_path):
        """Produce (or reuse) the text log for a video and index it.

        Args:
            video_path: Path of the video file. Its base name up to the first
                dot is used as the video id.

        Returns:
            list[str]: The stripped lines of the log file.
        """
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]

        if self.exist_videolog(video_id):
            return self.printlog(video_id)

        if not self.models_flag:
            self.init_models()
            self.models_flag = True

        # NOTE(review): logger_creator is defined elsewhere; the log is read
        # back from self.data_dir below - confirm both use the same directory.
        logger = logger_creator(video_id)
        clip_features, video_length = self.feature_extractor(video_path, video_id)
        seg_windows = self.video_segmenter(clip_features, video_length)

        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            audio_results = self.audio_translator(video_path)

            for start_sec, end_sec in seg_windows:
                # Describe each segment by the frame at its midpoint.
                middle_sec = (start_sec + end_sec) // 2
                middle_frame_idx = int(middle_sec * fps)
                cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
                ret, frame = cap.read()
                if not ret:
                    continue

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_caption = self.image_captioner.image_caption(frame)
                dense_caption = self.dense_captioner.image_dense_caption(frame)
                image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
                ram_tags = inference(image, self.ram_model)
                audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
                OCR_result_str = self._ocr_text(frame)

                logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}")
                chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model)
                chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model)
                logger.info(f"我看到这些画面:\"{chinese_image_caption}\"")
                logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"")
                logger.info(f"我检测到这些标签:\"{ram_tags[1]}.\"")
                logger.info(f"我识别到这些文字:\"{OCR_result_str}\"")

                if len(audio_transcript) > 0:
                    logger.info(f"我听到有人说:\"{audio_transcript}\"")
                logger.info("\n")
        finally:
            # Release the capture handle even when a model call fails.
            cap.release()

        self.create_videolog(video_id)
        return self.printlog(video_id)

    def printlog(self, video_id):
        """Return the log of ``video_id`` as a list of whitespace-stripped lines."""
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def translate_text(self, text, tokenizer, model):
        """Translate ``text`` with the given seq2seq tokenizer/model pair.

        Args:
            text: Source sentence.
            tokenizer: Tokenizer; called directly to encode and used to decode.
            model: Seq2seq model exposing ``generate``.

        Returns:
            str: The first decoded translation.
        """
        # Calling the tokenizer replaces the deprecated prepare_seq2seq_batch;
        # the padding/truncation arguments mirror that method's defaults.
        encoded_text = tokenizer([text], return_tensors='pt', padding='longest', truncation=True)
        translated = model.generate(**encoded_text)
        return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

    def chat2video(self, question):
        """Ask the retrieval chain ``question`` and record the exchange.

        The chain is created by ``video2log``, so that must have run first.

        Returns:
            The ``answer`` entry of the chain's response.
        """
        print(f"Question: {question}")

        response = self.chain({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']
        self.history.append((question, response))
        return response

    def clean_history(self):
        """Forget all previous question/answer pairs."""
        self.history = []
        return
vlog4debate.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MAD: Multi-Agent Debate with Large Language Models
3
+ Copyright (C) 2023 The MAD Team
4
+
5
+ This program is free software: you can redistribute it and/or modify
6
+ it under the terms of the GNU General Public License as published by
7
+ the Free Software Foundation, either version 3 of the License, or
8
+ (at your option) any later version.
9
+
10
+ This program is distributed in the hope that it will be useful,
11
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
12
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
+ GNU General Public License for more details.
14
+
15
+ You should have received a copy of the GNU General Public License
16
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
17
+ """
18
+
19
+ import os
20
+ import json
21
+ import random
22
+ # random.seed(0)
23
+ from Agent import Agent
24
+
25
# SECURITY(review): an API key is committed to source control here. It should
# be rotated and supplied only through the environment. setdefault() at least
# stops this module from overwriting a key the deployment already provides.
os.environ.setdefault("TOGETHER_API_KEY", "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af")

# Player names, in the order Debate.creat_agents() unpacks them
# (affirmative, negative, moderator).
NAME_LIST = [
    "Affirmative side",
    "Negative side",
    "Moderator",
]
32
+
33
class DebatePlayer(Agent):
    """An ``Agent`` taking part in a debate that also keeps its API key."""

    def __init__(self, model_name: str, name: str, temperature: float, openai_api_key: str, sleep_time: float) -> None:
        """Set up one debate participant.

        Args:
            model_name (str): name of the model to query
            name (str): display name of this participant
            temperature (float): sampling temperature; higher is more random,
                lower is more focused and deterministic
            openai_api_key (str): API key stored on the instance
            sleep_time (float): pause used to stay within rate limits
        """
        super().__init__(model_name, name, temperature, sleep_time)
        self.openai_api_key = openai_api_key
46
+
47
+
48
class Debate:
    """Run a multi-round debate between two players, judged by a moderator."""

    def __init__(self,
            model_name: str='Qwen/Qwen1.5-72B-Chat',
            temperature: float=0,
            num_players: int=3,
            openai_api_key: str=os.environ["TOGETHER_API_KEY"],
            config: dict=None,
            max_round: int=3,
            sleep_time: float=0
        ) -> None:
        """Create a debate and immediately play its opening round.

        Args:
            model_name (str): model name handed to every player
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            num_players (int): num of players
            openai_api_key (str): API key handed to every player
            config (dict): prompt templates plus ``debate_topic``; required in
                practice, because it is indexed and updated in place during
                construction
            max_round (int): maximum Rounds of Debate
            sleep_time (float): sleep because of rate limits
        """
        self.model_name = model_name
        self.temperature = temperature
        self.num_players = num_players
        self.openai_api_key = openai_api_key
        self.config = config
        self.max_round = max_round
        self.sleep_time = sleep_time
        self.initial_debate = ''
        self.init_prompt()

        # create & init agents
        self.creat_agents()
        self.init_agents()

    @staticmethod
    def _parse_reply(reply: str):
        """Parse a moderator/judge reply into a Python object without executing it.

        The reply is model output, so it is parsed as data: JSON first, then
        a Python literal (for single-quoted dicts). It is never passed to
        ``eval``.
        """
        import ast  # local import: this module does not import ast at file level

        try:
            return json.loads(reply)
        except ValueError:
            return ast.literal_eval(reply.strip())

    def init_prompt(self):
        """Substitute the debate topic into the prompt templates, in place."""
        topic = self.config["debate_topic"]
        for key in ("player_meta_prompt", "moderator_meta_prompt", "affirmative_prompt", "judge_prompt_last2"):
            self.config[key] = self.config[key].replace("##debate_topic##", topic)

    def creat_agents(self):
        """Create one player per entry of ``NAME_LIST`` and name the three roles."""
        self.players = [
            DebatePlayer(model_name=self.model_name, name=name, temperature=self.temperature, openai_api_key=self.openai_api_key, sleep_time=self.sleep_time) for name in NAME_LIST
        ]
        self.affirmative = self.players[0]
        self.negative = self.players[1]
        self.moderator = self.players[2]

    def init_agents(self):
        """Set the meta prompts and play the first round of the debate."""
        # start: set meta prompt
        self.affirmative.set_meta_prompt(self.config['player_meta_prompt'])
        self.negative.set_meta_prompt(self.config['player_meta_prompt'])
        self.moderator.set_meta_prompt(self.config['moderator_meta_prompt'])

        # start: first round debate, state opinions
        print("===== Start Debate Round =====\n")
        self.affirmative.add_event(self.config['affirmative_prompt'])
        self.aff_ans = self.affirmative.ask()
        self.affirmative.add_memory(self.aff_ans)
        self.config['base_answer'] = self.aff_ans
        self.initial_debate += "\n\n正方观点:" + self.aff_ans

        self.negative.add_event(self.config['negative_prompt'].replace('##aff_ans##', self.aff_ans))
        self.neg_ans = self.negative.ask()
        self.negative.add_memory(self.neg_ans)
        self.initial_debate += "\n\n反方观点:" + self.neg_ans

        self.moderator.add_event(self.config['moderator_prompt'].replace('##aff_ans##', self.aff_ans).replace('##neg_ans##', self.neg_ans).replace('##round##', 'first'))
        self.mod_ans = self.moderator.ask()
        self.moderator.add_memory(self.mod_ans)
        self.mod_ans = self._parse_reply(self.mod_ans)

    def round_dct(self, num: int):
        """Return the English ordinal word for round ``num``.

        Rounds 1-10 are spelled out; larger numbers fall back to ``"<num>th"``
        so a long debate does not fail with a ``KeyError``.
        """
        dct = {
            1: 'first', 2: 'second', 3: 'third', 4: 'fourth', 5: 'fifth', 6: 'sixth', 7: 'seventh', 8: 'eighth', 9: 'ninth', 10: 'tenth'
        }
        return dct.get(num, f"{num}th")

    def print_answer(self):
        """Print the topic, the base answer and the final verdict with its reason."""
        print("\n\n===== Debate Done! =====")
        print("\n----- Debate Topic -----")
        print(self.config["debate_topic"])
        print("\n----- Base Answer -----")
        print(self.config["base_answer"])
        print("\n----- Debate Answer -----")
        print(self.config["debate_answer"])
        print("\n----- Debate Reason -----")
        print(self.config["Reason"])

    def broadcast(self, msg: str):
        """Broadcast a message to all players.
        Typical use is for the host to announce public information

        Args:
            msg (str): the message
        """
        for player in self.players:
            player.add_event(msg)

    def speak(self, speaker: str, msg: str):
        """The speaker broadcast a message to all other players.

        Args:
            speaker (str): name of the speaker
            msg (str): the message
        """
        if not msg.startswith(f"{speaker}: "):
            msg = f"{speaker}: {msg}"
        for player in self.players:
            if player.name != speaker:
                player.add_event(msg)

    def ask_and_speak(self, player: DebatePlayer):
        """Let ``player`` answer, remember the answer and relay it to the others."""
        ans = player.ask()
        player.add_memory(ans)
        self.speak(player.name, ans)

    def run(self):
        """Play the remaining rounds and return a textual summary.

        Rounds continue until the moderator reports a non-empty
        ``debate_answer`` or ``max_round`` is reached. If the moderator never
        settles, a separate judge player picks the answer.

        Returns:
            str: Topic, both opening statements, the verdict and its reason.
        """
        for round_idx in range(self.max_round - 1):
            if self.mod_ans["debate_answer"] != '':
                break

            print(f"===== Debate Round-{round_idx+1} =====\n")
            self.affirmative.add_event(self.config['debate_prompt'].replace('##oppo_ans##', self.neg_ans))
            self.aff_ans = self.affirmative.ask()
            self.affirmative.add_memory(self.aff_ans)

            self.negative.add_event(self.config['debate_prompt'].replace('##oppo_ans##', self.aff_ans))
            self.neg_ans = self.negative.ask()
            self.negative.add_memory(self.neg_ans)

            self.moderator.add_event(self.config['moderator_prompt'].replace('##aff_ans##', self.aff_ans).replace('##neg_ans##', self.neg_ans).replace('##round##', self.round_dct(round_idx + 2)))
            self.mod_ans = self.moderator.ask()
            self.moderator.add_memory(self.mod_ans)
            self.mod_ans = self._parse_reply(self.mod_ans)

        if self.mod_ans["debate_answer"] != '':
            self.config.update(self.mod_ans)
            self.config['success'] = True
        else:
            # Last resort: the moderator never settled, so a judge decides.
            judge_player = DebatePlayer(model_name=self.model_name, name='Judge', temperature=self.temperature, openai_api_key=self.openai_api_key, sleep_time=self.sleep_time)
            aff_ans = self.affirmative.memory_lst[2]['content']
            neg_ans = self.negative.memory_lst[2]['content']

            judge_player.set_meta_prompt(self.config['moderator_meta_prompt'])

            # extract answer candidates
            judge_player.add_event(self.config['judge_prompt_last1'].replace('##aff_ans##', aff_ans).replace('##neg_ans##', neg_ans))
            ans = judge_player.ask()
            judge_player.add_memory(ans)

            # select one from the candidates
            judge_player.add_event(self.config['judge_prompt_last2'])
            ans = judge_player.ask()
            judge_player.add_memory(ans)

            ans = self._parse_reply(ans)
            if ans["debate_answer"] != '':
                self.config['success'] = True
            self.config.update(ans)
            self.players.append(judge_player)

        self.print_answer()
        combined_string = ''.join([
            self.config["debate_topic"],
            self.initial_debate,
            f"\n\n经过{self.max_round}轮辩论......",
            "\n\n仲裁观点:",
            self.config["debate_answer"],
            "\n\n仲裁理由:",
            self.config["Reason"]
        ])

        return combined_string