import os
import cv2
import pdb
import sys
import time
import numpy as np
from PIL import Image
from transformers import logging
logging.set_verbosity_error()

from models.kts_model import VideoSegmentor
from models.clip_model import FeatureExtractor
from models.blip2_model import ImageCaptioner
from models.grit_model import DenseCaptioner
from models.whisper_model import AudioTranslator
from models.gpt_model import LlmReasoner
from utils.utils import logger_creator, format_time

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate
from paddleocr import PaddleOCR, draw_ocr

sys.path.append('/root/autodl-tmp/recognize-anything')
from ram.models import ram
from ram.models import tag2text
from ram import inference_ram as inference
# from ram import inference_tag2text as inference
from ram import get_transform

warnings.filterwarnings("ignore", category=UserWarning)

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = ""
instruction = """You are an AI assistant designed for answering questions about a video. | |
You are given a document and a question, the document records what people see and hear from this video. | |
Try to connet these information and provide a conversational answer. | |
Question: {question} | |
========= | |
{context} | |
========= | |
""" | |
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. | |
You can assume the discussion is about the video content. | |
Chat History: | |
{chat_history} | |
Follow Up Input: {question} | |
Standalone question:""" | |
os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template
template = get_prompt(instruction, system_prompt)
prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
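# For reference, get_prompt() simply wraps `system_prompt` and `instruction` in
# Llama-2 chat delimiters, so `template` expands to:
#
#   [INST]<<SYS>>
#   {system_prompt}
#   <</SYS>>
#
#   {instruction}[/INST]
#
# Note that `prompt` is defined here, but the ConversationalRetrievalChain built below
# is created with its default prompts; wiring this template in would presumably require
# passing it explicitly (for example via combine_docs_chain_kwargs={"prompt": prompt}).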
class TogetherLLM(LLM):
    """Together large language models."""

    model: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    together_api_key: str = os.environ["TOGETHER_API_KEY"]
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint."""
        together.api_key = self.together_api_key
        output = together.Complete.create(
            prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=["</s>"],
        )
        text = output['choices'][0]['text']
        return text
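# Minimal usage sketch for TogetherLLM on its own, outside the retrieval chain
# (illustrative only; assumes TOGETHER_API_KEY is set and the model name is available):
#
#   llm = TogetherLLM(model="togethercomputer/llama-2-70b-chat", temperature=0.1, max_tokens=256)
#   print(llm("Describe what a video log is in one sentence."))  # or llm.invoke(...) on newer LangChain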
class Vlogger4chat:
    def __init__(self, args):
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        self.init_llm()
        self.init_tag2txt()
        self.history = []
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
    def init_models(self):
        print('\033[1;34m' + "Welcome to our Vlog toolbox...".center(50, '-') + '\033[0m')
        print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
        print('\033[1;31m' + "This may be time-consuming, please wait...".center(50, '-') + '\033[0m')
        self.ocr = PaddleOCR(lang='ch')  # needs to run only once to download and load the model into memory
        self.feature_extractor = FeatureExtractor(self.args)
        self.video_segmenter = VideoSegmentor(alpha=self.alpha, beta=self.beta)
        self.image_captioner = ImageCaptioner(model_name=self.args.captioner_base_model, device=self.args.image_captioner_device)
        self.dense_captioner = DenseCaptioner(device=self.args.dense_captioner_device)
        self.audio_translator = AudioTranslator(model=self.args.audio_translator, device=self.args.audio_translator_device)
        print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')

        # Translate the generated captions into Chinese:
        # initialize the tokenizer and model for English-to-Chinese translation
        model_name = 'Helsinki-NLP/opus-mt-en-zh'  # 'Helsinki-NLP/opus-mt-zh-en' for the reverse direction
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cuda'}, encode_kwargs={'normalize_embeddings': True})
    def init_llm(self):
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model="microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
        if not self.models_flag:
            self.init_models()
            self.models_flag = True
    def init_tag2txt(self):
        self.transform = get_transform(image_size=384)
        # delete some tags that may disturb captioning
        # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
        delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

        # load model
        # self.tag2txt_model = tag2text(pretrained='/root/autodl-tmp/recognize-anything/pretrained/tag2text_swin_14m.pth',
        #                               image_size=384, vit='swin_b', delete_tag_index=delete_tag_index)
        self.ram_model = ram(pretrained='/root/autodl-tmp/recognize-anything/pretrained/ram_swin_large_14m.pth',
                             image_size=384,
                             vit='swin_l')
        # self.tag2txt_model.threshold = 0.68  # threshold for tagging
        # self.tag2txt_model.eval()
        self.ram_model.eval()
        # self.tag2txt_model = self.tag2txt_model.to(device=self.args.dense_captioner_device)
        self.ram_model = self.ram_model.to(device=self.args.dense_captioner_device)
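    # Note: in the recognize-anything repo, `inference_ram(image, model)` (imported above
    # as `inference`) returns a pair of tag strings, which we take to be
    # (english_tags, chinese_tags); video2log() below logs index [1] to get the Chinese tags.
    # Treat the exact return format as an assumption to verify against the installed library version.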
    def exist_videolog(self, video_id):
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]  # or pick the appropriate index for your setup
        if isinstance(video_id, tuple):
            video_id = video_id[0]  # or pick the appropriate index for your setup
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        # print(f"log_path: {log_path}\n")
        if os.path.exists(log_path):
            # print("existing log path!!!\n")
            loader = UnstructuredFileLoader(log_path)
            raw_documents = loader.load()
            if not raw_documents:
                print("The log file is empty or could not be loaded.")
                return False  # return early if raw_documents is empty or contains only blank content
            # Split text
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(raw_documents)
            self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
            self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
            return True
        return False
    def create_videolog(self, video_id):
        video_id = os.path.basename(self.video_path).split('.')[0]
        log_path = os.path.join(self.data_dir, video_id + '.log')
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        # Split text
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
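    # Both exist_videolog() and create_videolog() build the same retrieval pipeline:
    # load the per-video .log file, split it into 500-character chunks with 50-character
    # overlap, embed the chunks with BAAI/bge-m3, index them in FAISS, and wrap the index
    # in a ConversationalRetrievalChain so chat2video() can answer questions about the video.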
    def video2log(self, video_path):
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]
        if self.exist_videolog(video_id):
            return self.printlog(video_id)

        if not self.models_flag:
            self.init_models()
            self.models_flag = True

        logger = logger_creator(video_id)
        clip_features, video_length = self.feature_extractor(video_path, video_id)
        seg_windows = self.video_segmenter(clip_features, video_length)

        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        audio_results = self.audio_translator(video_path)
        for start_sec, end_sec in seg_windows:
            middle_sec = (start_sec + end_sec) // 2
            middle_frame_idx = int(middle_sec * fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, middle_frame_idx)
            ret, frame = cap.read()

            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image_caption = self.image_captioner.image_caption(frame)
                dense_caption = self.dense_captioner.image_dense_caption(frame)
                image = self.transform(Image.fromarray(frame)).unsqueeze(0).to(device=self.args.dense_captioner_device)
                # tag2txt = inference(image, self.tag2txt_model, 'None')
                ram_tags = inference(image, self.ram_model)  # renamed from `ram` to avoid shadowing the imported ram() constructor
                audio_transcript = self.audio_translator.match(audio_results, start_sec, end_sec)
                OCR_result = self.ocr.ocr(frame)

                # Extract the text of every line in every OCR block
                texts = []
                for block in OCR_result:
                    if block is not None:  # skip blocks that are None
                        for line in block:
                            if line is not None:  # skip lines that are None
                                text = line[1][0]  # keep only the recognized text of this line
                                texts.append(text)
                # Join all recognized text into a single string
                OCR_result_str = ' '.join(texts)
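                # Structure note (an assumption to verify against the installed PaddleOCR):
                # ocr.ocr(frame) is expected to return one result list per input image, with
                # each entry shaped like [bounding_box, (text, confidence)], which is why the
                # loop above reads line[1][0] to collect only the recognized text.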
logger.info(f"When {format_time(start_sec)} - {format_time(end_sec)}") | |
chinese_image_caption = self.translate_text(image_caption, self.tokenizer, self.model) | |
#chinese_tag2txt = self.translate_text(tag2txt[2], self.tokenizer, self.model) | |
chinese_dense_caption = self.translate_text(dense_caption, self.tokenizer, self.model) | |
logger.info(f"我看到这些画面:\"{chinese_image_caption}\"") | |
#logger.info(f"我看见 {chinese_tag2txt}.") | |
logger.info(f"我发现这些内容:\"{chinese_dense_caption}\"") | |
logger.info(f"我检测到这些标签:\"{ram[1]}.\"") | |
logger.info(f"我识别到这些文字:\"{OCR_result_str}\"") | |
if len(audio_transcript) > 0: | |
#english_audio_text = self.translate_text(audio_transcript, self.tokenizer, self.model) | |
logger.info(f"我听到有人说:\"{audio_transcript}\"") | |
logger.info("\n") | |
        cap.release()
        self.create_videolog(video_id)
        return self.printlog(video_id)
    def printlog(self, video_id):
        log_list = []
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            for line in f:
                log_list.append(line.strip())
        return log_list
    def translate_text(self, text, tokenizer, model):
        # Encode the source text
        encoded_text = tokenizer.prepare_seq2seq_batch([text], return_tensors='pt')
        # Generate the translation
        translated = model.generate(**encoded_text)
        # Decode the translated text
        translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
        return translated_text
    def chat2video(self, question):
        print(f"Question: {question}")
        # "请用中文回答:" prefixes the question with "Please answer in Chinese:"
        response = self.chain({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']
        self.history.append((question, response))
        return response
    def clean_history(self):
        self.history = []
        return
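# A minimal, hypothetical driver for this module (not part of the original file):
# the argument names mirror the attributes accessed on `args` above, but every default
# value below is an illustrative assumption and should be adapted to your own setup.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Chat with a video via its generated log.")
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--alpha", type=int, default=10)    # KTS segmentation hyperparameter (assumed default)
    parser.add_argument("--beta", type=int, default=1)      # KTS segmentation hyperparameter (assumed default)
    parser.add_argument("--data_dir", type=str, default="./examples")
    parser.add_argument("--tmp_dir", type=str, default="./tmp")
    parser.add_argument("--captioner_base_model", type=str, default="blip2")
    parser.add_argument("--image_captioner_device", type=str, default="cuda:0")
    parser.add_argument("--dense_captioner_device", type=str, default="cuda:0")
    parser.add_argument("--audio_translator", type=str, default="large")
    parser.add_argument("--audio_translator_device", type=str, default="cuda:0")
    args = parser.parse_args()

    vlogger = Vlogger4chat(args)
    vlogger.video2log(args.video_path)              # build (or load) the video log and its FAISS index
    print(vlogger.chat2video("视频里发生了什么?"))   # "What happens in the video?"
    vlogger.clean_history()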