Update vlog4chat.py

This commit strips the heavy multimodal pipeline out of vlog4chat.py: the cv2/PIL imports, the recognize-anything (RAM) tagger, the per-segment captioning and OCR loop, and the en-zh translation helper are commented out or deleted, leaving only the Together-hosted LLM chat path active.

vlog4chat.py  CHANGED  (+20 -129)
@@ -1,21 +1,21 @@
 import os
-import cv2
-import pdb
+#import cv2
+#import pdb
 import sys
 import time
 import numpy as np
-from PIL import Image
-from transformers import logging
-logging.set_verbosity_error()
-
-from models.kts_model import VideoSegmentor
-from models.clip_model import FeatureExtractor
-from models.blip2_model import ImageCaptioner
-from models.grit_model import DenseCaptioner
-from models.whisper_model import AudioTranslator
-from models.gpt_model import LlmReasoner
+#from PIL import Image
+#from transformers import logging
+#logging.set_verbosity_error()
+
+#from models.kts_model import VideoSegmentor
+#from models.clip_model import FeatureExtractor
+#from models.blip2_model import ImageCaptioner
+#from models.grit_model import DenseCaptioner
+#from models.whisper_model import AudioTranslator
+#from models.gpt_model import LlmReasoner
 from utils.utils import logger_creator, format_time
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
+#from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM

 import together
 import warnings
@@ -38,13 +38,13 @@ from langchain.memory import ConversationBufferMemory
 from langchain import LLMChain, PromptTemplate
 from paddleocr import PaddleOCR, draw_ocr

-sys.path.append('/root/autodl-tmp/recognize-anything')
+#sys.path.append('/root/autodl-tmp/recognize-anything')

-from ram.models import ram
-from ram.models import tag2text
-from ram import inference_ram as inference
+#from ram.models import ram
+#from ram.models import tag2text
+#from ram import inference_ram as inference
 #from ram import inference_tag2text as inference
-from ram import get_transform
+#from ram import get_transform

 warnings.filterwarnings("ignore", category=UserWarning)
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -68,7 +68,7 @@ Chat History:
 Follow Up Input: {question}
 Standalone question:"""

-os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
+#os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
 os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"

 def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT ):
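The hunk above comments out the HF_HOME override but leaves the Together API key hardcoded in the source. A common alternative is to read it from the environment; a minimal sketch, assuming the key has been exported in the shell (or set as a Space secret) before launch:

```python
import os

# fail fast if the key was not provided via the environment
api_key = os.environ.get("TOGETHER_API_KEY")
if api_key is None:
    raise RuntimeError("TOGETHER_API_KEY is not set; export it before starting the app")
```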
@@ -142,29 +142,10 @@ class Vlogger4chat :
         self.tmp_dir = args.tmp_dir
         self.models_flag = False
         self.init_llm()
-        self.init_tag2txt()
         self.history = []
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
-
-    def init_models(self):
-        print('\033[1;34m' + "Welcome to the our Vlog toolbox...".center(50, '-') + '\033[0m')
-        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
-        print('\033[1;31m' + "This may time-consuming, please wait...".center(50, '-') + '\033[0m')
-        self.ocr = PaddleOCR(lang='ch')  # need to run only once to download and load model into memory
-        self.feature_extractor = FeatureExtractor(self.args)
-        self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
-        self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
-        self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
-        self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
-        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
-        # translate the documents
-        # initialize the tokenizer and model
-        model_name = 'Helsinki-NLP/opus-mt-en-zh'  # 'Helsinki-NLP/opus-mt-zh-en'
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'}, encode_kwargs={'normalize_embeddings': True})
-
+
     def init_llm(self):
         print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
         self.llm = TogetherLLM(
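The deleted `init_models` was only ever called lazily, behind the `self.models_flag` guard that the next two hunks also remove. A minimal sketch of that deferred-loading pattern; `LazyModels` and its placeholder attribute are hypothetical stand-ins for the real captioners and OCR engine:

```python
class LazyModels:
    """Build heavy models on first use instead of at startup."""

    def __init__(self):
        self.models_flag = False  # nothing loaded yet

    def init_models(self):
        # stand-in for the expensive part: downloading weights, moving to GPU
        self.ocr = object()

    def process(self, video_path):
        if not self.models_flag:  # only the first call pays the loading cost
            self.init_models()
            self.models_flag = True
        return f"processed {video_path}"


pipeline = LazyModels()
print(pipeline.process("a.mp4"))  # triggers init_models() exactly once
print(pipeline.process("b.mp4"))  # reuses the already-loaded models
```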
@@ -173,31 +154,7 @@ class Vlogger4chat :
             max_tokens=768
         )
         print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
-
-        if not self.models_flag:
-            self.init_models()
-            self.models_flag = True

-    def init_tag2txt(self):
-        self.transform = get_transform(image_size=384)
-
-        # delete some tags that may disturb captioning
-        # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
-        delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]
-
-        ####### load model
-        #self.tag2txt_model = tag2text(pretrained='/root/autodl-tmp/recognize-anything/pretrained/tag2text_swin_14m.pth',
-        #                              image_size=384, vit='swin_b', delete_tag_index=delete_tag_index)
-        self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
-                             image_size=384,
-                             vit='swin_l')
-        #self.tag2txt_model.threshold = 0.68  # threshold for tagging
-        #self.tag2txt_model.eval()
-        self.ram_model.eval()
-
-        #self.tag2txt_model = self.tag2txt_model.to(device=self.args.dense_captioner_device)
-        self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)
-
     def exist_videolog(self, video_id):
         if isinstance(self.data_dir, tuple):
             self.data_dir = self.data_dir[0]  # or pick whichever index fits your setup
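The deleted `init_tag2txt` wired up the Recognize Anything (RAM) tagger. A self-contained sketch of the same load-and-tag flow, following the recognize-anything repo's documented API; the checkpoint and image paths here are placeholders:

```python
import torch
from PIL import Image
from ram import get_transform, inference_ram
from ram.models import ram

device = 'cuda' if torch.cuda.is_available() else 'cpu'

transform = get_transform(image_size=384)
model = ram(pretrained='pretrained/ram_swin_large_14m.pth',  # placeholder path
            image_size=384, vit='swin_l')
model.eval()
model = model.to(device)

image = transform(Image.open('frame.jpg')).unsqueeze(0).to(device)  # placeholder image
with torch.no_grad():
    tags = inference_ram(image, model)
print(tags[0])  # English tag string; tags[1] is the Chinese one logged in the removed loop below
```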
@@ -239,61 +196,6 @@ class Vlogger4chat :
         if self.exist_videolog(video_id):
             return self.printlog(video_id)

-        if not self.models_flag:
-            self.init_models()
-            self.models_flag = True
-
-        logger = logger_creator(video_id)
-        clip_features, video_length = self.feature_extractor(video_path, video_id)
-        seg_windows = self.video_segmenter(clip_features, video_length)
-
-        cap = cv2.VideoCapture(video_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        audio_results = self.audio_translator(video_path)
-
-        for start_sec, end_sec in seg_windows:
-            middle_sec = (start_sec + end_sec) // 2
-            middle_frame_idx = int(middle_sec * fps)
-            cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
-            ret, frame = cap.read()
-
-            if ret:
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                image_caption = self.image_captioner.image_caption(frame)
-                dense_caption = self.dense_captioner.image_dense_caption(frame)
-                image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
-                #tag2txt = inference(image, self.tag2txt_model, 'None')
-                ram = inference(image, self.ram_model)
-                audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
-                OCR_result = self.ocr.ocr(frame)
-                # pull the text out of every line of every OCR block
-                texts = []
-                for block in OCR_result:
-                    if block is not None:  # block may be None
-                        for line in block:
-                            if line is not None:  # line may be None
-                                text = line[1][0]  # take just the text portion
-                                texts.append(text)
-                # join all recognized text into a single string
-                OCR_result_str = ' '.join(texts)
-
-                logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}")
-                chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model)
-                #chinese_tag2txt = self.translate_text(tag2txt[2], self.tokenizer, self.model)
-                chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model)
-                logger.info(f"我看到这些画面:\"{chinese_image_caption}\"")
-                #logger.info(f"我看见 {chinese_tag2txt}.")
-                logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"")
-                logger.info(f"我检测到这些标签:\"{ram[1]}.\"")
-                logger.info(f"我识别到这些文字:\"{OCR_result_str}\"")
-
-                if len(audio_transcript) > 0:
-                    #english_audio_text = self.translate_text(audio_transcript, self.tokenizer, self.model)
-                    logger.info(f"我听到有人说:\"{audio_transcript}\"")
-            logger.info("\n")
-
-        cap.release()
-        self.create_videolog(video_id)
         return self.printlog(video_id)

     def printlog(self, video_id):
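The removed loop flattens PaddleOCR output with nested None checks: `ocr()` returns a list of blocks, each block a list of `(bounding_box, (text, confidence))` lines, and either level can be None when nothing is detected. A standalone sketch of just that flattening step, on an invented sample result:

```python
# Shape of a PaddleOCR result: blocks -> lines -> (box, (text, score)).
ocr_result = [
    [
        ([[0, 0], [80, 0], [80, 20], [0, 20]], ("hello", 0.99)),
        ([[0, 25], [80, 25], [80, 45], [0, 45]], ("world", 0.97)),
    ],
    None,  # frames with no detected text can yield None blocks
]

texts = []
for block in ocr_result:
    if block is None:
        continue
    for line in block:
        if line is not None:
            texts.append(line[1][0])  # keep the text, drop the box and score

print(' '.join(texts))  # -> "hello world"
```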
@@ -303,17 +205,6 @@ class Vlogger4chat :
             for line in f:
                 log_list.append(line.strip())
         return log_list
-
-    def translate_text(self, text, tokenizer, model):
-        # encode the text
-        encoded_text = tokenizer.prepare_seq2seq_batch([text], return_tensors='pt')
-
-        # generate the translation
-        translated = model.generate(**encoded_text)
-
-        # decode the translated text
-        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
-        return translated_text

     def chat2video(self, question):
         print(f"Question: {question}")
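The deleted `translate_text` built its batch with `prepare_seq2seq_batch`, which newer transformers releases deprecate in favor of calling the tokenizer directly. A sketch of the equivalent helper under that assumption, using the same Helsinki-NLP checkpoint the old `init_models` loaded:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def translate_text(text: str) -> str:
    # tokenizer(...) replaces the deprecated prepare_seq2seq_batch(...)
    encoded = tokenizer([text], return_tensors='pt', truncation=True)
    translated = model.generate(**encoded)
    return tokenizer.batch_decode(translated, skip_special_tokens=True)[0]

print(translate_text("I can see a person riding a bicycle."))
```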