Upload vlog4chat.py
vlog4chat.py  ADDED  +214 -0

import os
import sys
import time
import pickle
import numpy as np
from utils import logger_creator, format_time

#import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain.llms import HuggingFacePipeline
from langchain_community.llms import HuggingFacePipeline
#from langchain.embeddings import HuggingFaceEmbeddings
#from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
#from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate

warnings.filterwarnings("ignore", category=UserWarning)
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = ""

instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question; the document records what people see and hear in this video.
Try to connect this information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""

system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    # Wrap the instruction in Llama-2 style [INST] ... [/INST] tags with a <<SYS>> block.
    SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
    return prompt_template

template = get_prompt(instruction, system_prompt)

prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
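
# For reference, get_prompt(instruction, system_prompt) assembles roughly:
#   [INST]<<SYS>>\n<system_prompt>\n<</SYS>>\n\n<instruction>[/INST]
# so the follow-up-rephrasing text sits in the <<SYS>> block and the QA
# instruction (with its {question} and {context} slots) forms the user turn.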

class TogetherLLM(LLM):
    """Together large language models."""

    model: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint."""
        client = Together(api_key=self.together_api_key)
        messages = [
            {
                "role": "user",
                "content": prompt
            },
            {
                "role": "assistant",
                "content": ""
            }
        ]

        # Ensure the 'prompt' is passed as part of a structured 'messages' list
        output = client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,      # use the values configured on this wrapper
            temperature=self.temperature,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=["</s>"]
        )
        text = output.choices[0].message.content
        print(f"Answer: {text}")
        return text
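
# Quick standalone check of the wrapper (sketch only; the model name below is just
# an example -- pick one enabled on your Together account -- and TOGETHER_API_KEY
# must be set in the environment):
#   llm = TogetherLLM(model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
#                     temperature=0.2, max_tokens=128)
#   print(llm.invoke("Summarize what a video log is in one sentence."))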

class Vlogger4chat:
    def __init__(self, args):
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        self.init_llm()
        self.history = []

        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cpu'}, encode_kwargs={'normalize_embeddings': True})

    def init_llm(self):
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model=os.getenv("YOUR_MODEL_NAME"),
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')

    def exist_videolog(self, video_id):
        # Defensive unpacking in case a tuple was passed in.
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]  # or pick the appropriate index for your setup
        if isinstance(video_id, tuple):
            video_id = video_id[0]  # or pick the appropriate index for your setup
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        #print(f"log_path: {log_path}\n")

        if os.path.exists(log_path):
            #print("existing log path!!!\n")
            loader = UnstructuredFileLoader(log_path)
            raw_documents = loader.load()
            if not raw_documents:
                print("The log file is empty or could not be loaded.")
                return False  # return early if raw_documents is empty or entirely blank
            # Split text
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(raw_documents)
            self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
            self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
            return True
        return False

    def create_videolog(self, video_id):
        video_id = os.path.basename(self.video_path).split('.')[0]
        log_path = os.path.join(self.data_dir, video_id + '.log')
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        # Split text
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)

    def video2log(self, video_path):
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]

        # The <video_id>.log transcript is expected to already exist; either way the
        # log text is read back and returned.
        if self.exist_videolog(video_id):
            return self.printlog(video_id)

        return self.printlog(video_id)

    def printlog(self, video_id):
        log_list = []
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            for line in f:
                log_list.append(line.strip())
        return log_list

    def chat2video(self, question):
        print(f"Question: {question}")

        # The "请用中文回答:" prefix asks the model to answer in Chinese.
        #response = self.chain({"question": "请用中文回答:"+question, "chat_history": self.history})['answer']
        response = self.chain.invoke({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']

        self.history.append((question, response))
        return response

    def clean_history(self):
        self.history = []
        return
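
A minimal driver sketch for the class above, assuming an argparse-style args object with alpha, beta, data_dir and tmp_dir fields, a pre-generated <video_id>.log transcript under data_dir, and TOGETHER_API_KEY plus YOUR_MODEL_NAME set in the environment; the paths and weights below are only illustrative:

from types import SimpleNamespace
from vlog4chat import Vlogger4chat

args = SimpleNamespace(alpha=0.5, beta=0.5, data_dir="examples", tmp_dir="tmp")  # hypothetical settings
vlogger = Vlogger4chat(args)

log_lines = vlogger.video2log("examples/demo.mp4")   # reads examples/demo.log and builds the FAISS index
print("\n".join(log_lines))
print(vlogger.chat2video("What happens at the beginning of the video?"))
vlogger.clean_history()                               # reset the chat history before switching videos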