File size: 7,815 Bytes
62a001a
 
 
8d42871
62a001a
075c157
62a001a
 
 
 
 
 
 
 
 
3014fbb
 
cc6b261
7b610c1
 
62a001a
d6743ee
 
62a001a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f1ad13
3014fbb
9b2ec19
62a001a
 
 
 
 
 
 
 
 
 
 
 
 
 
8eabbf8
62a001a
 
 
 
 
 
 
 
 
 
c87fef7
 
 
62a001a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import sys
import time
import pickle
import numpy as np
from utils import logger_creator, format_time

import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain.llms import HuggingFacePipeline
from langchain_community.llms import HuggingFacePipeline
#from langchain.embeddings import HuggingFaceEmbeddings
#from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
#from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate

# Silence UserWarning messages raised anywhere in the process.
warnings.filterwarnings("ignore", category=UserWarning)

# Delimiters wrapped around the instruction and the system prompt by
# get_prompt() (the "[INST]" / "<<SYS>>" chat markup).
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

DEFAULT_SYSTEM_PROMPT = ""

# Answering prompt: expects the placeholders {question} and {context}.
instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question, the document records what people see and hear from this video.
Try to connet these information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""

# Question-condensing prompt: expects {chat_history} and {question}.
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

# SECURITY(review): this credential is committed in source. It should be
# rotated and supplied through the environment instead. `setdefault` keeps the
# previous fallback behaviour but no longer overwrites a key that the
# deployment already provides via TOGETHER_API_KEY.
os.environ.setdefault("TOGETHER_API_KEY", "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af")

def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Wrap an instruction and a system prompt in the chat delimiters.

    Args:
        instruction: Text placed after the system section.
        new_system_prompt: Text placed between ``B_SYS`` and ``E_SYS``;
            defaults to ``DEFAULT_SYSTEM_PROMPT``.

    Returns:
        The concatenation
        ``B_INST + B_SYS + new_system_prompt + E_SYS + instruction + E_INST``.
    """
    pieces = (B_INST, B_SYS, new_system_prompt, E_SYS, instruction, E_INST)
    return "".join(pieces)

# Full prompt text: the question-condensing prompt goes in the system section
# and the answering prompt in the instruction section.
template = get_prompt(instruction, system_prompt)

# The declared variables must match the placeholders that actually occur in
# `template` ({chat_history}, {context}, {question}); the previous list named
# a non-existent "user_input" and omitted "question" and "context".
prompt = PromptTemplate(
    input_variables=["chat_history", "context", "question"], template=template
)

class TogetherLLM(LLM):
    """LangChain ``LLM`` wrapper that sends prompts to the Together
    completion endpoint through ``together.Complete.create``.

    The API key is taken from the ``together_api_key`` field when it is
    non-empty, otherwise from the ``TOGETHER_API_KEY`` environment variable
    (resolved by ``validate_environment`` when the model is instantiated).
    """

    model: str = "togethercomputer/llama-2-70b-chat"
    """model endpoint to use"""

    # Read with .get() so that a missing variable does not raise KeyError at
    # class-definition (import) time; validate_environment below resolves the
    # key, or reports that it is missing, when an instance is created.
    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")
    """Together API key"""

    temperature: float = 0.7
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    class Config:
        # Reject unknown constructor keyword arguments.
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set.

        Looks the key up in ``values["together_api_key"]`` first and falls
        back to the ``TOGETHER_API_KEY`` environment variable, then stores
        the resolved key back into ``values``.
        """
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> str:
        """Call to Together endpoint.

        Args:
            prompt: The fully formatted prompt text.
            **kwargs: Accepted for interface compatibility; not forwarded to
                the endpoint.

        Returns:
            The ``text`` field of the first entry in the response's
            ``choices`` list.
        """
        together.api_key = self.together_api_key
        output = together.Complete.create(prompt,
                                          model=self.model,
                                          max_tokens=self.max_tokens,
                                          temperature=self.temperature,
                                          top_p=0.7,
                                          top_k=50,
                                          repetition_penalty=1,
                                          stop=["</s>"],
                                          )
        text = output['choices'][0]['text']
        return text

class Vlogger4chat:
    """Question answering over a video's text log.

    A ``<video_id>.log`` file inside ``data_dir`` is loaded, split into
    chunks, embedded into a FAISS index and exposed through a
    ``ConversationalRetrievalChain``; ``chat2video`` then answers questions
    against that chain while keeping the running chat history.
    """

    def __init__(self, args):
        """Store the configuration and create the LLM and the embedding model.

        Args:
            args: Object exposing ``alpha``, ``beta``, ``data_dir`` and
                ``tmp_dir`` attributes. ``alpha``, ``beta`` and ``tmp_dir``
                are stored unchanged and are not used by the methods of this
                class.
        """
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        # A stray trailing comma used to turn this attribute into a 1-tuple,
        # which broke every os.path.join() on it; store the plain value.
        self.data_dir = self._unwrap(args.data_dir)
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        # Populated by _build_chain() once a log has been indexed.
        self.vector_storage = None
        self.chain = None
        self.init_llm()
        self.history = []

        self.my_embedding = HuggingFaceEmbeddings(model_name='BAAI/bge-m3', model_kwargs={'device': 'cpu'} ,encode_kwargs={'normalize_embeddings': True})

    @staticmethod
    def _unwrap(value):
        """Return the first element of a tuple, or ``value`` itself otherwise."""
        return value[0] if isinstance(value, tuple) else value

    def init_llm(self):
        """Instantiate the Together-backed LLM and store it on ``self.llm``."""
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model= "microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')

    def _build_chain(self, log_path):
        """Index the log at ``log_path`` and build the retrieval chain.

        Returns:
            True when ``self.vector_storage`` and ``self.chain`` were built;
            False when the loader returned no documents or splitting them
            produced no chunks (nothing could be indexed).
        """
        raw_documents = UnstructuredFileLoader(log_path).load()
        chunks = []
        if raw_documents:
            # Split text
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
            chunks = text_splitter.split_documents(raw_documents)
        if not chunks:
            # Nothing to embed: building a FAISS index from zero chunks fails.
            print("The log file is empty or could not be loaded.")
            return False
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vector_storage.as_retriever(), return_source_documents=True)
        return True

    def exist_videolog(self, video_id):
        """Build the retrieval chain from an existing log, if there is one.

        Args:
            video_id: Log file stem; a tuple is accepted and its first
                element is used.

        Returns:
            True when ``<data_dir>/<video_id>.log`` exists and was indexed,
            False when the file is missing or yields no content.
        """
        # Tolerate tuple-valued inputs by taking their first element.
        self.data_dir = self._unwrap(self.data_dir)
        video_id = self._unwrap(video_id)
        log_path = os.path.join(self.data_dir, f"{video_id}.log")

        if not os.path.exists(log_path):
            return False
        return self._build_chain(log_path)

    def create_videolog(self, video_id):
        """Index the log of the current video and build the retrieval chain.

        The id is derived from ``self.video_path`` when ``video2log`` has set
        it; otherwise the ``video_id`` argument is used.

        Raises:
            ValueError: If the log yields no content to index.
        """
        video_path = getattr(self, "video_path", None)
        if video_path is not None:
            video_id = os.path.basename(video_path).split('.')[0]
        else:
            video_id = self._unwrap(video_id)
        log_path = os.path.join(self._unwrap(self.data_dir), video_id + '.log')
        if not self._build_chain(log_path):
            raise ValueError(f"No indexable content in log file: {log_path}")

    def video2log(self, video_path):
        """Select a video, index its log if present, and return the log lines.

        Args:
            video_path: Path of the video; the file name up to its first
                ``.`` is used as the video id.

        Returns:
            The stripped lines of ``<data_dir>/<video_id>.log``.

        Raises:
            FileNotFoundError: If that log file does not exist.
        """
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]

        # Called for its side effect of building the chain; the log lines are
        # returned whether or not indexing succeeded.
        self.exist_videolog(video_id)
        return self.printlog(video_id)

    def printlog(self, video_id):
        """Return the lines of ``<data_dir>/<video_id>.log``, each stripped."""
        log_path = os.path.join(self._unwrap(self.data_dir), video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def chat2video(self, question):
        """Answer ``question`` with the retrieval chain and record the turn.

        The question is sent with a fixed prefix that asks for an answer in
        Chinese; the history stores the question without that prefix.

        Raises:
            RuntimeError: If no retrieval chain has been built yet.
        """
        print(f"Question: {question}")

        chain = getattr(self, "chain", None)
        if chain is None:
            raise RuntimeError(
                "No video log has been indexed yet; call video2log() first."
            )
        response = chain.invoke({"question": "请用中文回答:"+question, "chat_history": self.history})['answer']
        self.history.append((question, response))
        return response

    def clean_history(self):
        """Forget all previous question/answer turns."""
        self.history = []
        return