Upload 3 files
- utils.py +54 -0
- vlog4chat.py +327 -0
- vlog4debate.py +235 -0
utils.py
ADDED
@@ -0,0 +1,54 @@
import os
import re
import pdb
import logging
import subprocess

def format_time(seconds):
    hours = seconds // 3600
    minutes = (seconds // 60) % 60
    seconds = seconds % 60
    return f"{hours}:{minutes:02d}:{seconds:02d}"

def extract_info_from_url(url):
    if 'bilibili.com/video/' in url:
        # For Bilibili video links, extract the BV id
        match = re.search(r'/video/(BV\w+)', url)
        if match:
            return match.group(1)
        else:
            return "BV ID not found!"
    else:
        # For blog links and other URLs, extract the last path segment
        match = re.search(r'/([^/]+)/?$', url)
        if match:
            return match.group(1)
        else:
            return "URL address is weird!"

def download_video(url, save_dir='./examples', size=768):
    filename = extract_info_from_url(url)
    save_path = f'{save_dir}/{filename}.mp4'
    #cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}'
    # $ you-get -o ~/Videos -O zoo.webm 'https://www.youtube.com/watch?v=jNQXAC9IVRw'
    cmd = f'you-get -o {save_dir} -O {filename} {url}'
    if not os.path.exists(save_path):
        try:
            subprocess.call(cmd, shell=True)
        except Exception:
            return None
    return save_path

def logger_creator(video_id):
    # set up a logger that writes one message per line to ./examples/<video_id>.log
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler(f'./examples/{video_id}.log', mode='w')
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    return logger

if __name__ == "__main__":
    download_video('outcGtbnMuQ', save_dir='./examples', size=768)
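For quick reference, extract_info_from_url returns the BV id for Bilibili video links and the last path segment for any other URL, and format_time renders seconds as H:MM:SS. Expected behavior (illustrative input values):

    >>> extract_info_from_url('https://www.bilibili.com/video/BV1xx411c7mD')
    'BV1xx411c7mD'
    >>> extract_info_from_url('https://example.com/blog/my-post/')
    'my-post'
    >>> format_time(3725)
    '1:02:05'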
vlog4chat.py
ADDED
@@ -0,0 +1,327 @@
import os
import cv2
import pdb
import sys
import time
import numpy as np
from PIL import Image
from transformers import logging
logging.set_verbosity_error()

from models.kts_model import VideoSegmentor
from models.clip_model import FeatureExtractor
from models.blip2_model import ImageCaptioner
from models.grit_model import DenseCaptioner
from models.whisper_model import AudioTranslator
from models.gpt_model import LlmReasoner
from utils.utils import logger_creator, format_time
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM

import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate
from paddleocr import PaddleOCR, draw_ocr

sys.path.append('/root/autodl-tmp/recognize-anything')

from ram.models import ram
from ram.models import tag2text
from ram import inference_ram as inference
#from ram import inference_tag2text as inference
from ram import get_transform

warnings.filterwarnings("ignore", category=UserWarning)
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = ""

instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question; the document records what people see and hear in this video.
Try to connect this information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""

system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"

def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template

template = get_prompt(instruction, system_prompt)

prompt = PromptTemplate(
    # the assembled template contains {chat_history}, {question}, and {context}
    input_variables=["chat_history", "question", "context"], template=template
)
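
# Note: get_prompt wraps the question-rephrasing system prompt and the QA
# instruction in Llama-2 chat markers, so `template` has the shape
#   [INST] <<SYS>> {system_prompt} <</SYS>> {instruction} [/INST]
# and exposes the placeholders {chat_history}, {question}, and {context}.
# A minimal rendering sketch (illustrative values only):
#   filled = template.format(chat_history="", question="What happens first?",
#                            context="(video log excerpt)")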
class TogetherLLM(LLM):
    """Together large language models."""

    model: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    together_api_key: str = os.environ["TOGETHER_API_KEY"]
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint."""
        together.api_key = self.together_api_key
        output = together.Complete.create(prompt,
                                          model=self.model,
                                          max_tokens=self.max_tokens,
                                          temperature=self.temperature,
                                          top_p=0.7,
                                          top_k=50,
                                          repetition_penalty=1,
                                          stop=["</s>"],
                                          )
        text = output['choices'][0]['text']
        return text
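
# A minimal usage sketch for TogetherLLM (assumes a valid TOGETHER_API_KEY and an
# available Together endpoint; illustrative only):
#   llm = TogetherLLM(model="togethercomputer/llama-2-70b-chat",
#                     temperature=0.1, max_tokens=256)
#   print(llm("Briefly, what is a vlog?"))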
class Vlogger4chat:
    def __init__(self, args):
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        self.init_llm()
        self.init_tag2txt()
        self.history = []
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def init_models(self):
        print('\033[1;34m' + "Welcome to our Vlog toolbox...".center(50, '-') + '\033[0m')
        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This may be time-consuming, please wait...".center(50, '-') + '\033[0m')
        self.ocr = PaddleOCR(lang='ch')  # needs to run only once to download and load the model into memory
        self.feature_extractor = FeatureExtractor(self.args)
        self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
        self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
        self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
        self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
        # Document translation: initialize the tokenizer and model (English -> Chinese)
        model_name = 'Helsinki-NLP/opus-mt-en-zh'  # 'Helsinki-NLP/opus-mt-zh-en'
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'}, encode_kwargs={'normalize_embeddings': True})

    def init_llm(self):
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model="microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')

        if not self.models_flag:
            self.init_models()
            self.models_flag = True

    def init_tag2txt(self):
        self.transform = get_transform(image_size=384)

        # delete some tags that may disturb captioning
        # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
        delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

        # load model
        #self.tag2txt_model = tag2text(pretrained='/root/autodl-tmp/recognize-anything/pretrained/tag2text_swin_14m.pth',
        #                              image_size=384, vit='swin_b', delete_tag_index=delete_tag_index)
        self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
                             image_size=384,
                             vit='swin_l')
        #self.tag2txt_model.threshold = 0.68  # threshold for tagging
        #self.tag2txt_model.eval()
        self.ram_model.eval()

        #self.tag2txt_model = self.tag2txt_model.to(device=self.args.dense_captioner_device)
        self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)

    def exist_videolog(self, video_id):
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]  # or pick the appropriate index for your setup
        if isinstance(video_id, tuple):
            video_id = video_id[0]  # or pick the appropriate index for your setup
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        #print(f"log_path: {log_path}\n")

        if os.path.exists(log_path):
            #print("existing log path!!!\n")
            loader = UnstructuredFileLoader(log_path)
            raw_documents = loader.load()
            if not raw_documents:
                print("The log file is empty or could not be loaded.")
                return False  # return early if raw_documents is empty or all whitespace
            # Split text
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(raw_documents)
            self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
            self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
            return True
        return False

    def create_videolog(self, video_id):
        video_id = os.path.basename(self.video_path).split('.')[0]
        log_path = os.path.join(self.data_dir, video_id + '.log')
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        # Split text
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)

    def video2log(self, video_path):
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]

        if self.exist_videolog(video_id):
            return self.printlog(video_id)

        if not self.models_flag:
            self.init_models()
            self.models_flag = True

        logger = logger_creator(video_id)
        clip_features, video_length = self.feature_extractor(video_path, video_id)
        seg_windows = self.video_segmenter(clip_features, video_length)

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        audio_results = self.audio_translator(video_path)

        for start_sec, end_sec in seg_windows:
            middle_sec = (start_sec + end_sec) // 2
            middle_frame_idx = int(middle_sec * fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
            ret, frame = cap.read()

            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_caption = self.image_captioner.image_caption(frame)
                dense_caption = self.dense_captioner.image_dense_caption(frame)
                image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
                #tag2txt = inference(image, self.tag2txt_model, 'None')
                ram_tags = inference(image, self.ram_model)  # local name avoids shadowing the imported `ram` constructor
                audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
                OCR_result = self.ocr.ocr(frame)
                # extract the text of every line in every OCR block
                texts = []
                for block in OCR_result:
                    if block is not None:  # skip None blocks
                        for line in block:
                            if line is not None:  # skip None lines
                                text = line[1][0]  # take the recognized text portion
                                texts.append(text)
                # join all extracted text into a single string
                OCR_result_str = ' '.join(texts)

                # The log entries below are written in Chinese because the chat chain answers in Chinese.
                logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}")
                chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model)
                #chinese_tag2txt = self.translate_text(tag2txt[2], self.tokenizer, self.model)
                chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model)
                logger.info(f"我看到这些画面:\"{chinese_image_caption}\"")  # "I see these scenes: ..."
                #logger.info(f"我看见 {chinese_tag2txt}.")
                logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"")  # "I find this content: ..."
                logger.info(f"我检测到这些标签:\"{ram_tags[1]}.\"")  # "I detect these tags: ..."
                logger.info(f"我识别到这些文字:\"{OCR_result_str}\"")  # "I recognize this text: ..."

                if len(audio_transcript) > 0:
                    #english_audio_text = self.translate_text(audio_transcript, self.tokenizer, self.model)
                    logger.info(f"我听到有人说:\"{audio_transcript}\"")  # "I hear someone say: ..."
                logger.info("\n")

        cap.release()
        self.create_videolog(video_id)
        return self.printlog(video_id)

    def printlog(self, video_id):
        log_list = []
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            for line in f:
                log_list.append(line.strip())
        return log_list

    def translate_text(self, text, tokenizer, model):
        # encode the text
        encoded_text = tokenizer.prepare_seq2seq_batch([text], return_tensors='pt')

        # generate the translation
        translated = model.generate(**encoded_text)

        # decode the translated text
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
        return translated_text

    def chat2video(self, question):
        print(f"Question: {question}")

        # "请用中文回答:" prefixes the question with "Please answer in Chinese:"
        response = self.chain({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']
        self.history.append((question, response))
        return response

    def clean_history(self):
        self.history = []
        return
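A minimal driver sketch for Vlogger4chat (illustrative: the Namespace fields mirror the `args` attributes read in __init__ and init_models, but the values and the demo video path are assumptions; the real argument parsing lives elsewhere in the repo):

    from argparse import Namespace
    from vlog4chat import Vlogger4chat

    args = Namespace(
        alpha=10, beta=1,                      # segmentation hyperparameters (assumed values)
        data_dir='./examples', tmp_dir='./tmp',
        captioner_base_model='blip2',          # assumed model identifier
        image_captioner_device='cuda',
        dense_captioner_device='cuda',
        audio_translator='large',              # assumed Whisper model size
        audio_translator_device='cuda',
    )
    vlogger = Vlogger4chat(args)
    vlogger.video2log('./examples/demo.mp4')       # build (or load) the video log
    print(vlogger.chat2video('视频里发生了什么?'))  # "What happens in the video?"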
vlog4debate.py
ADDED
@@ -0,0 +1,235 @@
"""
MAD: Multi-Agent Debate with Large Language Models
Copyright (C) 2023 The MAD Team

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""

import os
import json
import random
# random.seed(0)
from Agent import Agent

os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"

NAME_LIST = [
    "Affirmative side",
    "Negative side",
    "Moderator",
]

class DebatePlayer(Agent):
    def __init__(self, model_name: str, name: str, temperature: float, openai_api_key: str, sleep_time: float) -> None:
        """Create a player in the debate

        Args:
            model_name (str): model name
            name (str): name of this player
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            openai_api_key (str): as the parameter name suggests
            sleep_time (float): sleep because of rate limits
        """
        super(DebatePlayer, self).__init__(model_name, name, temperature, sleep_time)
        self.openai_api_key = openai_api_key


class Debate:
    def __init__(self,
                 model_name: str = 'Qwen/Qwen1.5-72B-Chat',
                 temperature: float = 0,
                 num_players: int = 3,
                 openai_api_key: str = os.environ["TOGETHER_API_KEY"],
                 config: dict = None,
                 max_round: int = 3,
                 sleep_time: float = 0
                 ) -> None:
        """Create a debate

        Args:
            model_name (str): model name
            temperature (float): higher values make the output more random, while lower values make it more focused and deterministic
            num_players (int): number of players
            openai_api_key (str): as the parameter name suggests
            max_round (int): maximum rounds of debate
            sleep_time (float): sleep because of rate limits
        """

        self.model_name = model_name
        self.temperature = temperature
        self.num_players = num_players
        self.openai_api_key = openai_api_key
        self.config = config
        self.max_round = max_round
        self.sleep_time = sleep_time
        self.initial_debate = ''
        self.init_prompt()

        # create & init agents
        self.creat_agents()
        self.init_agents()


    def init_prompt(self):
        def prompt_replace(key):
            self.config[key] = self.config[key].replace("##debate_topic##", self.config["debate_topic"])
        prompt_replace("player_meta_prompt")
        prompt_replace("moderator_meta_prompt")
        prompt_replace("affirmative_prompt")
        prompt_replace("judge_prompt_last2")

    def creat_agents(self):
        # create players
        self.players = [
            DebatePlayer(model_name=self.model_name, name=name, temperature=self.temperature, openai_api_key=self.openai_api_key, sleep_time=self.sleep_time) for name in NAME_LIST
        ]
        self.affirmative = self.players[0]
        self.negative = self.players[1]
        self.moderator = self.players[2]

    def init_agents(self):
        # start: set meta prompt
        self.affirmative.set_meta_prompt(self.config['player_meta_prompt'])
        self.negative.set_meta_prompt(self.config['player_meta_prompt'])
        self.moderator.set_meta_prompt(self.config['moderator_meta_prompt'])

        # start: first round of debate, state opinions
        print(f"===== Start Debate Round =====\n")
        self.affirmative.add_event(self.config['affirmative_prompt'])
        self.aff_ans = self.affirmative.ask()
        self.affirmative.add_memory(self.aff_ans)
        self.config['base_answer'] = self.aff_ans
        affirm_side = "\n\n正方观点:" + self.aff_ans  # "Affirmative side's view:"
        self.initial_debate += affirm_side

        self.negative.add_event(self.config['negative_prompt'].replace('##aff_ans##', self.aff_ans))
        self.neg_ans = self.negative.ask()
        self.negative.add_memory(self.neg_ans)
        neg_side = "\n\n反方观点:" + self.neg_ans  # "Negative side's view:"
        self.initial_debate += neg_side

        self.moderator.add_event(self.config['moderator_prompt'].replace('##aff_ans##', self.aff_ans).replace('##neg_ans##', self.neg_ans).replace('##round##', 'first'))
        self.mod_ans = self.moderator.ask()
        self.moderator.add_memory(self.mod_ans)
        self.mod_ans = eval(self.mod_ans)  # assumes the moderator replies with a Python dict literal
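    # Note: the eval() above expects the moderator's reply to be a dict literal,
    # e.g. (illustrative): {'debate_answer': '', 'Reason': '...'}. An empty
    # 'debate_answer' means no verdict yet, which is what run() checks each round.
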
    def round_dct(self, num: int):
        dct = {
            1: 'first', 2: 'second', 3: 'third', 4: 'fourth', 5: 'fifth', 6: 'sixth', 7: 'seventh', 8: 'eighth', 9: 'ninth', 10: 'tenth'
        }
        return dct[num]

    def print_answer(self):
        print("\n\n===== Debate Done! =====")
        print("\n----- Debate Topic -----")
        print(self.config["debate_topic"])
        print("\n----- Base Answer -----")
        print(self.config["base_answer"])
        print("\n----- Debate Answer -----")
        print(self.config["debate_answer"])
        print("\n----- Debate Reason -----")
        print(self.config["Reason"])

    def broadcast(self, msg: str):
        """Broadcast a message to all players.
        Typical use is for the host to announce public information

        Args:
            msg (str): the message
        """
        # print(msg)
        for player in self.players:
            player.add_event(msg)

    def speak(self, speaker: str, msg: str):
        """The speaker broadcasts a message to all other players.

        Args:
            speaker (str): name of the speaker
            msg (str): the message
        """
        if not msg.startswith(f"{speaker}: "):
            msg = f"{speaker}: {msg}"
        # print(msg)
        for player in self.players:
            if player.name != speaker:
                player.add_event(msg)

    def ask_and_speak(self, player: DebatePlayer):
        ans = player.ask()
        player.add_memory(ans)
        self.speak(player.name, ans)


    def run(self):
        for round in range(self.max_round - 1):

            if self.mod_ans["debate_answer"] != '':
                break
            else:
                print(f"===== Debate Round-{round+1} =====\n")
                self.affirmative.add_event(self.config['debate_prompt'].replace('##oppo_ans##', self.neg_ans))
                self.aff_ans = self.affirmative.ask()
                self.affirmative.add_memory(self.aff_ans)

                self.negative.add_event(self.config['debate_prompt'].replace('##oppo_ans##', self.aff_ans))
                self.neg_ans = self.negative.ask()
                self.negative.add_memory(self.neg_ans)

                self.moderator.add_event(self.config['moderator_prompt'].replace('##aff_ans##', self.aff_ans).replace('##neg_ans##', self.neg_ans).replace('##round##', self.round_dct(round+2)))
                self.mod_ans = self.moderator.ask()
                self.moderator.add_memory(self.mod_ans)
                self.mod_ans = eval(self.mod_ans)

        if self.mod_ans["debate_answer"] != '':
            self.config.update(self.mod_ans)
            self.config['success'] = True

        # ultimate deadly technique.
        else:
            judge_player = DebatePlayer(model_name=self.model_name, name='Judge', temperature=self.temperature, openai_api_key=self.openai_api_key, sleep_time=self.sleep_time)
            aff_ans = self.affirmative.memory_lst[2]['content']
            neg_ans = self.negative.memory_lst[2]['content']

            judge_player.set_meta_prompt(self.config['moderator_meta_prompt'])

            # extract answer candidates
            judge_player.add_event(self.config['judge_prompt_last1'].replace('##aff_ans##', aff_ans).replace('##neg_ans##', neg_ans))
            ans = judge_player.ask()
            judge_player.add_memory(ans)

            # select one from the candidates
            judge_player.add_event(self.config['judge_prompt_last2'])
            ans = judge_player.ask()
            judge_player.add_memory(ans)

            ans = eval(ans)
            if ans["debate_answer"] != '':
                self.config['success'] = True
                # save file
            self.config.update(ans)
            self.players.append(judge_player)

        self.print_answer()
        combined_string = ''.join([
            self.config["debate_topic"],
            self.initial_debate,
            f"\n\n经过{self.max_round}轮辩论......",  # "After {max_round} rounds of debate..."
            "\n\n仲裁观点:",  # "Arbiter's verdict:"
            self.config["debate_answer"],
            "\n\n仲裁理由:",  # "Arbiter's reasoning:"
            self.config["Reason"]
        ])

        return combined_string
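
Debate leaves the prompt engineering to the caller: `config` must supply the template strings the class reads, with `##debate_topic##`, `##aff_ans##`, `##neg_ans##`, `##oppo_ans##`, and `##round##` as substitution markers, and the moderator/judge must answer with a Python dict literal containing `debate_answer` and `Reason`. A minimal sketch (the key names are the ones referenced in the code above; the wording of each template is illustrative):

    config = {
        "debate_topic": "Is the claim made in this video sound?",
        "player_meta_prompt": "You are a debater. The topic is: ##debate_topic##",
        "moderator_meta_prompt": "You moderate a debate on: ##debate_topic##. "
                                 "Reply with a Python dict: {'debate_answer': '', 'Reason': ''}",
        "affirmative_prompt": "Argue in favor of: ##debate_topic##",
        "negative_prompt": "Rebut the affirmative side's argument: ##aff_ans##",
        "debate_prompt": "Respond to your opponent's last statement: ##oppo_ans##",
        "moderator_prompt": "This is the ##round## round. Affirmative: ##aff_ans## "
                            "Negative: ##neg_ans## Give your verdict dict.",
        "judge_prompt_last1": "Summarize the answer candidates. Affirmative: ##aff_ans## Negative: ##neg_ans##",
        "judge_prompt_last2": "Decide the final answer for ##debate_topic## and reply with the verdict dict.",
    }

    debate = Debate(config=config, max_round=3)
    print(debate.run())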