import os
#import cv2
#import pdb
import sys
import time
import numpy as np
#from PIL import Image
#from transformers import logging
#logging.set_verbosity_error()
#from models.kts_model import VideoSegmentor
#from models.clip_model import FeatureExtractor
#from models.blip2_model import ImageCaptioner
#from models.grit_model import DenseCaptioner
#from models.whisper_model import AudioTranslator
#from models.gpt_model import LlmReasoner
from utils import logger_creator, format_time
#from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoModelForSeq2SeqLM
import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import HuggingFacePipeline
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate
from paddleocr import PaddleOCR, draw_ocr
#sys.path.append('/root/autodl-tmp/recognize-anything')
#from ram.models import ram
#from ram.models import tag2text
#from ram import inference_ram as inference
#from ram import inference_tag2text as inference
#from ram import get_transform

warnings.filterwarnings("ignore", category=UserWarning)
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = ""
instruction = """You are an AI assistant designed for answering questions about a video. | |
You are given a document and a question, the document records what people see and hear from this video. | |
Try to connet these information and provide a conversational answer. | |
Question: {question} | |
========= | |
{context} | |
========= | |
""" | |
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. | |
You can assume the discussion is about the video content. | |
Chat History: | |
{chat_history} | |
Follow Up Input: {question} | |
Standalone question:""" | |
#os.environ['HF_HOME'] = '/root/autodl-tmp/cache/'
os.environ["TOGETHER_API_KEY"] = "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af"
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template


template = get_prompt(instruction, system_prompt)
prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
class TogetherLLM(LLM):
    """Together large language models."""

    model: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    together_api_key: str = os.environ["TOGETHER_API_KEY"]
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint."""
        together.api_key = self.together_api_key
        output = together.Complete.create(
            prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=["</s>"],
        )
        text = output['choices'][0]['text']
        return text
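# A minimal standalone usage sketch for TogetherLLM, assuming the legacy
# `together.Complete` completion endpoint is available and TOGETHER_API_KEY is
# set as above; the prompt text is only an example:
#
#   llm = TogetherLLM(model="togethercomputer/llama-2-70b-chat",
#                     temperature=0.1, max_tokens=256)
#   print(llm("Summarize what a retrieval-augmented chatbot does."))
#
# Calling the instance directly goes through LangChain's LLM base class, which
# dispatches to the `_call` method defined above.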
class Vlogger4chat:
    def __init__(self, args):
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        # Embedding model used to build the FAISS index in exist_videolog /
        # create_videolog; a default HuggingFaceEmbeddings instance is assumed here.
        self.my_embedding = HuggingFaceEmbeddings()
        self.init_llm()
        self.history = []
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def init_llm(self):
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model="microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
    def exist_videolog(self, video_id):
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]  # unwrap a tuple if one was passed; pick the appropriate element as needed
        if isinstance(video_id, tuple):
            video_id = video_id[0]  # unwrap a tuple if one was passed; pick the appropriate element as needed
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        #print(f"log_path: {log_path}\n")
        if os.path.exists(log_path):
            #print("existing log path!!!\n")
            loader = UnstructuredFileLoader(log_path)
            raw_documents = loader.load()
            if not raw_documents:
                print("The log file is empty or could not be loaded.")
                return False  # return immediately if raw_documents is empty or contains only whitespace
            # Split text
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(raw_documents)
            self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
            self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
            return True
        return False
    def create_videolog(self, video_id):
        video_id = os.path.basename(self.video_path).split('.')[0]
        log_path = os.path.join(self.data_dir, video_id + '.log')
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        # Split text
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
    def video2log(self, video_path):
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]
        # The per-video log is expected to already exist on disk; the captioning
        # models that would generate a new log are commented out above, so the
        # log is simply read back in either case.
        if self.exist_videolog(video_id):
            return self.printlog(video_id)
        return self.printlog(video_id)

    def printlog(self, video_id):
        log_list = []
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            for line in f:
                log_list.append(line.strip())
        return log_list
    def chat2video(self, question):
        print(f"Question: {question}")
        # The prefix "请用中文回答:" asks the model to answer in Chinese.
        response = self.chain({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']
        self.history.append((question, response))
        return response

    def clean_history(self):
        self.history = []
        return
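if __name__ == "__main__":
    # Minimal usage sketch. The argument names (alpha, beta, data_dir, tmp_dir)
    # follow the attributes read in __init__; the concrete values and file paths
    # below are placeholders, not part of the original configuration, and a
    # "<data_dir>/<video_id>.log" file is assumed to exist already.
    from argparse import Namespace

    args = Namespace(alpha=10, beta=1, data_dir="examples", tmp_dir="tmp")
    vlogger = Vlogger4chat(args)

    # Load the existing log for a video, build the retrieval index, and print it.
    for line in vlogger.video2log("examples/demo.mp4"):
        print(line)

    # Ask a question about the video; the chat history is kept for follow-ups.
    print(vlogger.chat2video("What happens at the beginning of the video?"))
    vlogger.clean_history()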