# VLog4CustomLLMsPlusQA-3 / vlog4chat.py
import os
import cv2
import pdb
import sys
import time
import numpy as np
from PIL import Image
from transformers import logging
logging.set_verbosity_error()
from models.kts_model import VideoSegmentor
from models.clip_model import FeatureExtractor
from models.blip2_model import ImageCaptioner
from models.grit_model import DenseCaptioner
from models.whisper_model import AudioTranslator
from models.gpt_model import LlmReasoner
from utils.utils import logger_creator, format_time
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate
from paddleocr import PaddleOCR, draw_ocr
sys.path.append('/root/autodl-tmp/recognize-anything')
from ram.models import ram
from ram.models import tag2text
from ram import inference_ram as inference
#from ram import inference_tag2text as inference
from ram import get_transform
warnings.filterwarnings("ignore", category=UserWarning)
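# Llama-2 style chat-prompt delimiters used to wrap the system prompt and the instruction.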
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = ""
instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question, the document records what people see and hear from this video.
Try to connet these information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
return prompt_template
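# The assembled template has the shape: "[INST]<<SYS>>\n{system_prompt}\n<</SYS>>\n\n{instruction}[/INST]".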
template = get_prompt(instruction, system_prompt)
prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
class TogetherLLM(LLM):
"""Together large language models."""
model: str = "togethercomputer/llama-2-70b-chat"
"""model endpoint to use"""
together_api_key: str = os.environ["TOGETHER_API_KEY"]
"""Together API key"""
temperature: float = 0.7
"""What sampling temperature to use."""
max_tokens: int = 512
"""The maximum number of tokens to generate in the completion."""
class Config:
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the API key is set."""
api_key = get_from_dict_or_env(
values, "together_api_key", "TOGETHER_API_KEY"
)
values["together_api_key"] = api_key
return values
@property
def _llm_type(self) -> str:
"""Return type of LLM."""
return "together"
def _call(
self,
prompt: str,
**kwargs: Any,
) -> str:
"""Call to Together endpoint."""
together.api_key = self.together_api_key
output = together.Complete.create(prompt,
model=self.model,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=0.7,
top_k=50,
repetition_penalty=1,
stop=["</s>"],
)
text = output['choices'][0]['text']
return text
class Vlogger4chat:
def __init__(self, args):
self.args = args
self.alpha = args.alpha
self.beta = args.beta
        self.data_dir = args.data_dir
self.tmp_dir = args.tmp_dir
self.models_flag = False
self.init_llm()
self.init_tag2txt()
self.history = []
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
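    # Load the heavy perception models: PaddleOCR, the CLIP feature extractor, the KTS video
    # segmenter, the BLIP-2 image captioner, the GRiT dense captioner, the Whisper audio
    # translator, a Helsinki-NLP en->zh translation model, and the BAAI/bge-m3 embeddings.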
def init_models(self):
        print('\033[1;34m' + "Welcome to our Vlog toolbox...".center(50, '-') + '\033[0m')
print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This may be time-consuming, please wait...".center(50, '-') + '\033[0m')
self.ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory
self.feature_extractor = FeatureExtractor(self.args)
self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
        # Translation setup
        # Initialize the tokenizer and model for English -> Chinese translation
        model_name = 'Helsinki-NLP/opus-mt-en-zh' #'Helsinki-NLP/opus-mt-zh-en' for the reverse direction
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'} ,encode_kwargs={'normalize_embeddings': True})
def init_llm(self):
print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
self.llm = TogetherLLM(
model= "microsoft/WizardLM-2-8x22B",
temperature=0.1,
max_tokens=768
)
print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
if not self.models_flag:
self.init_models()
self.models_flag = True
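    # Load the RAM (Recognize Anything) image-tagging model; the Tag2Text variant is kept
    # below, commented out, as an alternative tagger.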
def init_tag2txt(self):
self.transform = get_transform(image_size=384)
# delete some tags that may disturb captioning
# 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
#######load model
#self.tag2txt_model = tag2text(pretrained='/root/autodl-tmp/recognize-anything/pretrained/tag2text_swin_14m.pth',
# image_size=384, vit='swin_b', delete_tag_index=delete_tag_index)
self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
image_size=384,
vit='swin_l')
#self.tag2txt_model.threshold = 0.68 # threshold for tagging
#self.tag2txt_model.eval()
self.ram_model.eval()
#self.tag2txt_model = self.tag2txt_model.to(device=self.args.dense_captioner_device)
self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)
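    # If a log for this video already exists, load it, split it into chunks, index the chunks
    # with FAISS over the BGE-M3 embeddings, build the ConversationalRetrievalChain, and
    # return True; otherwise return False.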
def exist_videolog(self, video_id):
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]  # defensive: unwrap if a tuple slipped in
        if isinstance(video_id, tuple):
            video_id = video_id[0]  # defensive: unwrap if a tuple slipped in
log_path = os.path.join(self.data_dir, f"{video_id}.log")
#print(f"log_path: {log_path}\n")
if os.path.exists(log_path):
#print("existing log path!!!\n")
loader = UnstructuredFileLoader(log_path)
raw_documents = loader.load()
            if not raw_documents:
                print("The log file is empty or could not be loaded.")
                return False  # return early if the log is empty or only contains whitespace
# Split text
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = text_splitter.split_documents(raw_documents)
self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
return True
return False
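    # Build the FAISS index and the retrieval chain from a freshly written video log.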
def create_videolog(self, video_id):
video_id = os.path.basename(self.video_path).split('.')[0]
log_path = os.path.join(self.data_dir, video_id + '.log')
loader = UnstructuredFileLoader(log_path)
raw_documents = loader.load()
# Split text
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = text_splitter.split_documents(raw_documents)
self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
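    # Turn a video into a textual log: reuse an existing log if one is found, otherwise segment
    # the video with KTS over CLIP features and, for the middle frame of each segment, run
    # BLIP-2 captioning, GRiT dense captioning, RAM tagging and OCR, match the Whisper
    # transcript to the segment, and write everything (translated to Chinese) to the log.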
def video2log(self, video_path):
self.video_path = video_path
video_id = os.path.basename(video_path).split('.')[0]
if self.exist_videolog(video_id):
return self.printlog(video_id)
if not self.models_flag:
self.init_models()
self.models_flag = True
logger = logger_creator(video_id)
clip_features, video_length = self.feature_extractor(video_path, video_id)
seg_windows = self.video_segmenter(clip_features, video_length)
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
audio_results = self.audio_translator(video_path)
for start_sec, end_sec in seg_windows:
middle_sec = (start_sec + end_sec) // 2
middle_frame_idx = int(middle_sec * fps)
cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
ret, frame = cap.read()
if ret:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
image_caption = self.image_captioner.image_caption(frame)
dense_caption = self.dense_captioner.image_dense_caption(frame)
image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
#tag2txt = inference(image, self.tag2txt_model, 'None')
ram = inference(image, self.ram_model)
audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
OCR_result = self.ocr.ocr(frame)
                # Extract the text from every line of every OCR block
                texts = []
                for block in OCR_result:
                    if block is not None:  # skip empty OCR blocks
                        for line in block:
                            if line is not None:  # skip empty lines
                                text = line[1][0]  # the recognized text of this line
                                texts.append(text)
                # Join all recognized text into a single string
                OCR_result_str = ' '.join(texts)
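                # The log entries below are written in Chinese: "I see these frames: ...",
                # "I found this content: ...", "I detected these tags: ...",
                # "I recognized this text: ..." and "I heard someone say: ...".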
logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}")
chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model)
#chinese_tag2txt = self.translate_text(tag2txt[2], self.tokenizer, self.model)
chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model)
logger.info(f"我看到这些画面:\"{chinese_image_caption}\"")
#logger.info(f"我看见 {chinese_tag2txt}.")
logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"")
logger.info(f"我检测到这些标签:\"{ram[1]}.\"")
logger.info(f"我识别到这些文字:\"{OCR_result_str}\"")
if len(audio_transcript) > 0:
#english_audio_text = self.translate_text(audio_transcript, self.tokenizer, self.model)
logger.info(f"我听到有人说:\"{audio_transcript}\"")
logger.info("\n")
cap.release()
self.create_videolog(video_id)
return self.printlog(video_id)
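    # Read the per-video log back as a list of stripped lines.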
def printlog(self, video_id):
log_list = []
log_path = os.path.join(self.data_dir, video_id + '.log')
with open(log_path, 'r', encoding='utf-8') as f:
for line in f:
log_list.append(line.strip())
return log_list
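    # Translate English text to Chinese with the MarianMT model loaded in init_models.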
def translate_text(self, text, tokenizer, model):
        # Encode the input text
encoded_text = tokenizer.prepare_seq2seq_batch([text], return_tensors='pt')
        # Generate the translation
translated = model.generate(**encoded_text)
        # Decode the translated text
translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
return translated_text
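    # Answer a question about the current video via the retrieval chain. The question is
    # prefixed with "请用中文回答:" ("Please answer in Chinese:") so the model replies in Chinese.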
def chat2video(self, question):
print(f"Question: {question}")
response = self.chain({"question": "请用中文回答:"+question, "chat_history": self.history})['answer']
self.history.append((question, response))
return response
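    # Reset the chat history.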
def clean_history(self):
self.history = []
return
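# A minimal usage sketch (not part of the original file), assuming an `args` namespace with the
# fields referenced above; all paths, devices, and model names below are hypothetical placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        alpha=10, beta=1,                          # KTS segmentation parameters
        data_dir="./examples", tmp_dir="./tmp",    # where logs and temporary files are written
        captioner_base_model="blip2",              # placeholder BLIP-2 variant name
        image_captioner_device="cuda:0",
        dense_captioner_device="cuda:0",
        audio_translator="large",                  # placeholder Whisper model size
        audio_translator_device="cuda:0",
    )
    vlogger = Vlogger4chat(args)
    # Build (or reuse) the textual log for a video, then ask a question about it.
    for line in vlogger.video2log("./examples/demo.mp4"):
        print(line)
    print(vlogger.chat2video("视频里发生了什么?"))  # "What happens in the video?"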