dj86 commited on
Commit
fbfd07e
1 Parent(s): 656ef42

Upload vlog4chat.py

Browse files
Files changed (1) hide show
  1. vlog4chat.py +214 -0
vlog4chat.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import pickle
5
+ import numpy as np
6
+ from utils import logger_creator, format_time
7
+
8
+ #import together
9
+ import warnings
10
+ from together import Together
11
+ from langchain.prompts import PromptTemplate
12
+ from langchain.llms.base import LLM
13
+ from langchain.chains import ConversationalRetrievalChain
14
+ from langchain_core.output_parsers import StrOutputParser
15
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
16
+ #from langchain.llms import HuggingFacePipeline
17
+ from langchain_community.llms import HuggingFacePipeline
18
+ #from langchain.embeddings import HuggingFaceEmbeddings
19
+ #from langchain_community.embeddings import HuggingFaceEmbeddings
20
+ from langchain_huggingface import HuggingFaceEmbeddings
21
+ from langchain_community.document_loaders import TextLoader
22
+ #from langchain.document_loaders import UnstructuredFileLoader
23
+ from langchain_community.document_loaders import UnstructuredFileLoader
24
+ from langchain_community.vectorstores import FAISS
25
+ from langchain.utils import get_from_dict_or_env
26
+ from langchain_core.runnables import RunnablePassthrough, RunnableParallel
27
+ from pydantic.v1 import Extra, Field, root_validator
28
+ from typing import Any, Dict, List, Mapping, Optional
29
+ from langchain.memory import ConversationBufferMemory
30
+ from langchain import LLMChain, PromptTemplate
31
+
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+ B_INST, E_INST = "[INST]", "[/INST]"
34
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
35
+
36
+ DEFAULT_SYSTEM_PROMPT = ""
37
+
38
+ instruction = """You are an AI assistant designed for answering questions about a video.
39
+ You are given a document and a question, the document records what people see and hear from this video.
40
+ Try to connet these information and provide a conversational answer.
41
+ Question: {question}
42
+ =========
43
+ {context}
44
+ =========
45
+ """
46
+
47
+ system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
48
+ You can assume the discussion is about the video content.
49
+ Chat History:
50
+ {chat_history}
51
+ Follow Up Input: {question}
52
+ Standalone question:"""
53
+
54
+ def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT ):
55
+ SYSTEM_PROMPT = B_SYS + new_system_prompt + E_SYS
56
+ prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
57
+ return prompt_template
58
+
59
+ template = get_prompt(instruction, system_prompt)
60
+
61
+ prompt = PromptTemplate(
62
+ input_variables=["chat_history", "user_input"], template=template
63
+ )
64
+
65
+ class TogetherLLM(LLM):
66
+ """Together large language models."""
67
+
68
+ model: str = "togethercomputer/llama-2-70b-chat"
69
+ """model endpoint to use"""
70
+
71
+ together_api_key: str = os.environ["TOGETHER_API_KEY"]
72
+ """Together API key"""
73
+
74
+ temperature: float = 0.7
75
+ """What sampling temperature to use."""
76
+
77
+ max_tokens: int = 512
78
+ """The maximum number of tokens to generate in the completion."""
79
+
80
+ class Config:
81
+ extra = Extra.forbid
82
+
83
+ @root_validator()
84
+ def validate_environment(cls, values: Dict) -> Dict:
85
+ """Validate that the API key is set."""
86
+ api_key = get_from_dict_or_env(
87
+ values, "together_api_key", "TOGETHER_API_KEY"
88
+ )
89
+ values["together_api_key"] = api_key
90
+ return values
91
+
92
+ @property
93
+ def _llm_type(self) -> str:
94
+ """Return type of LLM."""
95
+ return "together"
96
+
97
+ def _call(
98
+ self,
99
+ prompt: str,
100
+ **kwargs: Any,
101
+ ) -> str:
102
+ """Call to Together endpoint."""
103
+ client = Together(api_key=self.together_api_key)
104
+ messages = [
105
+ {
106
+ "role": "user",
107
+ "content": prompt
108
+ },
109
+ {
110
+ "role": "assistant",
111
+ "content": ""
112
+ }
113
+ ]
114
+
115
+ # Ensure the 'prompt' is passed as part of a structured 'messages' list
116
+ output = client.chat.completions.create(
117
+ model=self.model,
118
+ messages=messages,
119
+ max_tokens=512,
120
+ temperature=0.7,
121
+ top_p=0.7,
122
+ top_k=50,
123
+ repetition_penalty=1,
124
+ stop=["</s>"]
125
+ )
126
+ text = output.choices[0].message.content
127
+ print(f"Answer: {text}")
128
+ return text
129
+
130
+ class Vlogger4chat :
131
+ def __init__(self, args):
132
+ self.args = args
133
+ self.alpha = args.alpha
134
+ self.beta = args.beta
135
+ self.data_dir = args.data_dir,
136
+ self.tmp_dir = args.tmp_dir
137
+ self.models_flag = False
138
+ self.init_llm()
139
+ self.history = []
140
+
141
+ self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cpu'} ,encode_kwargs={'normalize_embeddings': True})
142
+
143
+ def init_llm(self):
144
+ print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
145
+ self.llm = TogetherLLM(
146
+ model= os.getenv("YOUR_MODEL_NAME"),
147
+ temperature=0.1,
148
+ max_tokens=768
149
+ )
150
+ print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')
151
+
152
+ def exist_videolog(self, video_id):
153
+ if isinstance(self.data_dir, tuple):
154
+ self.data_dir = self.data_dir[0] # 或者根据实际情况选择合适的索引
155
+ if isinstance(video_id, tuple):
156
+ video_id = video_id[0] # 或者根据实际情况选择合适的索引
157
+ log_path = os.path.join(self.data_dir, f"{video_id}.log")
158
+ #print(f"log_path: {log_path}\n")
159
+
160
+ if os.path.exists(log_path):
161
+ #print("existing log path!!!\n")
162
+ loader = UnstructuredFileLoader(log_path)
163
+ raw_documents = loader.load()
164
+ if not raw_documents:
165
+ print("The log file is empty or could not be loaded.")
166
+ return False # 如果 raw_documents 为空或所有内容都为空白,直接返回
167
+ # Split text
168
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
169
+ chunks = text_splitter.split_documents(raw_documents)
170
+ self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
171
+ self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
172
+ return True
173
+ return False
174
+
175
+ def create_videolog(self, video_id):
176
+ video_id = os.path.basename(self.video_path).split('.')[0]
177
+ log_path = os.path.join(self.data_dir, video_id + '.log')
178
+ loader = UnstructuredFileLoader(log_path)
179
+ raw_documents = loader.load()
180
+ # Split text
181
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
182
+ chunks = text_splitter.split_documents(raw_documents)
183
+ self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
184
+ self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
185
+
186
+ def video2log(self, video_path):
187
+ self.video_path = video_path
188
+ video_id = os.path.basename(video_path).split('.')[0]
189
+
190
+ if self.exist_videolog(video_id):
191
+ return self.printlog(video_id)
192
+
193
+ return self.printlog(video_id)
194
+
195
+ def printlog(self, video_id):
196
+ log_list = []
197
+ log_path = os.path.join(self.data_dir, video_id + '.log')
198
+ with open(log_path, 'r', encoding='utf-8') as f:
199
+ for line in f:
200
+ log_list.append(line.strip())
201
+ return log_list
202
+
203
+ def chat2video(self, question):
204
+ print(f"Question: {question}")
205
+
206
+ #response = self.chain({"question": "请用中文回答:"+question, "chat_history": self.history})['answer']
207
+ response = self.chain.invoke({"question": "请用中文回答:" + question, "chat_history": self.history})['answer']
208
+
209
+ self.history.append((question, response))
210
+ return response
211
+
212
+ def clean_history(self):
213
+ self.history = []
214
+ return