dj86 committed
Commit 9b2ec19
1 Parent(s): 62a001a

Update vlog4chat.py

Files changed (1):
  1. vlog4chat.py +20 -129
vlog4chat.py CHANGED
@@ -1,21 +1,21 @@
 import os
-import cv2
-import pdb
+#import cv2
+#import pdb
 import sys
 import time
 import numpy as np
-from PIL import Image
-from transformers import logging
-logging.set_verbosity_error()
-
-from models.kts_model import VideoSegmentor
-from models.clip_model import FeatureExtractor
-from models.blip2_model import ImageCaptioner
-from models.grit_model import DenseCaptioner
-from models.whisper_model import AudioTranslator
-from models.gpt_model import LlmReasoner
+#from PIL import Image
+#from transformers import logging
+#logging.set_verbosity_error()
+
+#from models.kts_model import VideoSegmentor
+#from models.clip_model import FeatureExtractor
+#from models.blip2_model import ImageCaptioner
+#from models.grit_model import DenseCaptioner
+#from models.whisper_model import AudioTranslator
+#from models.gpt_model import LlmReasoner
 from utils.utils import logger_creator, format_time
-from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
+#from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
 
 import together
 import warnings
@@ -38,13 +38,13 @@ from langchain.memory import ConversationBufferMemory
 from langchain import LLMChain, PromptTemplate
 from paddleocr import PaddleOCR, draw_ocr
 
-sys.path.append('/root/autodl-tmp/recognize-anything')
+#sys.path.append('/root/autodl-tmp/recognize-anything')
 
-from ram.models import ram
-from ram.models import tag2text
-from ram import inference_ram as inference
+#from ram.models import ram
+#from ram.models import tag2text
+#from ram import inference_ram as inference
 #from ram import inference_tag2text as inference
-from ram import get_transform
+#from ram import get_transform
 
 warnings.filterwarnings("ignore", category=UserWarning)
 B_INST, E_INST = "[INST]", "[/INST]"
@@ -68,7 +68,7 @@ Chat History:
 Follow Up Input: {question}
 Standalone question:"""
 
-os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
+#os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
 os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"
 
 def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT ):
@@ -142,29 +142,10 @@ class Vlogger4chat :
         self.tmp_dir = args.tmp_dir
         self.models_flag = False
        self.init_llm()
-        self.init_tag2txt()
         self.history = []
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
-
-    def init_models(self):
-        print('\033[1;34m' + "Welcome to our Vlog toolbox...".center(50, '-') + '\033[0m')
-        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
-        print('\033[1;31m' + "This may be time-consuming, please wait...".center(50, '-') + '\033[0m')
-        self.ocr = PaddleOCR(lang='ch')  # needs to run only once to download and load the model into memory
-        self.feature_extractor = FeatureExtractor(self.args)
-        self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
-        self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
-        self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
-        self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
-        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
-        # Document translation:
-        # initialize the tokenizer and model
-        model_name = 'Helsinki-NLP/opus-mt-en-zh'  #'Helsinki-NLP/opus-mt-zh-en'
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'}, encode_kwargs={'normalize_embeddings': True})
-
+
     def init_llm(self):
         print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
         self.llm = TogetherLLM(
@@ -173,31 +154,7 @@ class Vlogger4chat :
             max_tokens=768
         )
         print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
-
-        if not self.models_flag:
-            self.init_models()
-            self.models_flag = True
 
-    def init_tag2txt(self):
-        self.transform = get_transform(image_size=384)
-
-        # delete some tags that may disturb captioning
-        # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
-        delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]
-
-        ####### load model
-        #self.tag2txt_model = tag2text(pretrained='/root/autodl-tmp/recognize-anything/pretrained/tag2text_swin_14m.pth',
-        #                              image_size=384, vit='swin_b', delete_tag_index=delete_tag_index)
-        self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
-                             image_size=384,
-                             vit='swin_l')
-        #self.tag2txt_model.threshold = 0.68  # threshold for tagging
-        #self.tag2txt_model.eval()
-        self.ram_model.eval()
-
-        #self.tag2txt_model = self.tag2txt_model.to(device=self.args.dense_captioner_device)
-        self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)
-
     def exist_videolog(self, video_id):
         if isinstance(self.data_dir, tuple):
             self.data_dir = self.data_dir[0]  # or pick the appropriate index for your setup
@@ -239,61 +196,6 @@ class Vlogger4chat :
         if self.exist_videolog(video_id):
             return self.printlog(video_id)
 
-        if not self.models_flag:
-            self.init_models()
-            self.models_flag = True
-
-        logger = logger_creator(video_id)
-        clip_features, video_length = self.feature_extractor(video_path, video_id)
-        seg_windows = self.video_segmenter(clip_features, video_length)
-
-        cap = cv2.VideoCapture(video_path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        audio_results = self.audio_translator(video_path)
-
-        for start_sec, end_sec in seg_windows:
-            middle_sec = (start_sec + end_sec) // 2
-            middle_frame_idx = int(middle_sec * fps)
-            cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
-            ret, frame = cap.read()
-
-            if ret:
-                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                image_caption = self.image_captioner.image_caption(frame)
-                dense_caption = self.dense_captioner.image_dense_caption(frame)
-                image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
-                #tag2txt = inference(image, self.tag2txt_model, 'None')
-                ram = inference(image, self.ram_model)  # note: rebinds the imported name 'ram'
-                audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
-                OCR_result = self.ocr.ocr(frame)
-                # extract the text from every line in every OCR block
-                texts = []
-                for block in OCR_result:
-                    if block is not None:  # check whether the block is None
-                        for line in block:
-                            if line is not None:  # check whether the line is None
-                                text = line[1][0]  # keep only the recognized text
-                                texts.append(text)
-                # join all recognized text into a single string
-                OCR_result_str = ' '.join(texts)
-
-                logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}")
-                chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model)
-                #chinese_tag2txt = self.translate_text(tag2txt[2], self.tokenizer, self.model)
-                chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model)
-                logger.info(f"我看到这些画面:\"{chinese_image_caption}\"")  # "I see these scenes: ..."
-                #logger.info(f"我看见 {chinese_tag2txt}.")  # "I see ..."
-                logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"")  # "I found this content: ..."
-                logger.info(f"我检测到这些标签:\"{ram[1]}.\"")  # "I detected these tags: ..."
-                logger.info(f"我识别到这些文字:\"{OCR_result_str}\"")  # "I recognized this text: ..."
-
-                if len(audio_transcript) > 0:
-                    #english_audio_text = self.translate_text(audio_transcript, self.tokenizer, self.model)
-                    logger.info(f"我听到有人说:\"{audio_transcript}\"")  # "I heard someone say: ..."
-                logger.info("\n")
-
-        cap.release()
-        self.create_videolog(video_id)
         return self.printlog(video_id)
 
     def printlog(self, video_id):
@@ -303,17 +205,6 @@ class Vlogger4chat :
             for line in f:
                 log_list.append(line.strip())
         return log_list
-
-    def translate_text(self, text, tokenizer, model):
-        # encode the text
-        encoded_text = tokenizer.prepare_seq2seq_batch([text], return_tensors='pt')
-
-        # generate the translation
-        translated = model.generate(**encoded_text)
-
-        # decode the translated text
-        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
-        return translated_text
 
     def chat2video(self, question):
         print(f"Question: {question}")
 
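The deleted frame loop parses PaddleOCR's nested output, where each result element is a block, each block is a list of lines, and each line is a (bounding_box, (text, score)) pair. A minimal standalone sketch of that same parsing, assuming paddleocr is installed; the image path 'frame.jpg' is hypothetical:

from paddleocr import PaddleOCR

# Load the Chinese OCR model once; weights are downloaded on first use.
ocr = PaddleOCR(lang='ch')

# Each result element is a block (or None); each block holds lines of
# the form (bounding_box, (recognized_text, confidence)).
result = ocr.ocr('frame.jpg')  # hypothetical input image

texts = []
for block in result:
    if block is None:  # nothing detected in this block
        continue
    for line in block:
        if line is not None:
            texts.append(line[1][0])  # keep only the recognized text

print(' '.join(texts))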