# VLog4CustomLLMsPlusQA-3 / vlog4chat.py
# dj86 — Update vlog4chat.py (commit c59556c, verified)
# NOTE(review): Hugging Face page chrome ("raw / history blame / 8.04 kB")
# converted to comments so the file parses as Python.
import os
import sys
import time
import pickle
import numpy as np
from utils import logger_creator, format_time
import together
import warnings
from together import Together
from langchain.prompts import PromptTemplate
from langchain.llms.base import LLM
from langchain.chains import ConversationalRetrievalChain
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
#from langchain.llms import HuggingFacePipeline
from langchain_community.llms import HuggingFacePipeline
#from langchain.embeddings import HuggingFaceEmbeddings
#from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
#from langchain.document_loaders import UnstructuredFileLoader
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import FAISS
from langchain.utils import get_from_dict_or_env
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from pydantic.v1 import Extra, Field, root_validator
from typing import Any, Dict, List, Mapping, Optional
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain, PromptTemplate
warnings.filterwarnings("ignore", category=UserWarning)

# Llama-2 chat formatting tags used by get_prompt() below.
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = ""

# QA instruction: {context} is the retrieved video-log text, {question} the
# user's query.  (Typo fix: "connet" -> "connect".)
instruction = """You are an AI assistant designed for answering questions about a video.
You are given a document and a question, the document records what people see and hear from this video.
Try to connect these information and provide a conversational answer.
Question: {question}
=========
{context}
=========
"""

# Condense-question prompt: rewrites a follow-up into a standalone question.
system_prompt = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
You can assume the discussion is about the video content.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

# SECURITY(review): a live API key was committed here — it should be revoked
# and supplied via the environment only.  setdefault keeps a key the user has
# already exported instead of clobbering it with the committed one.
os.environ.setdefault("TOGETHER_API_KEY", "48bf2536f85b599c7d5d7f9921cc9ee7056f40ed535fd2174d061e1b9abcf8af")
def get_prompt(instruction, new_system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Wrap *instruction* and a system prompt in Llama-2 chat tags.

    Produces "[INST]<<SYS>>\\n{system}\\n<</SYS>>\\n\\n{instruction}[/INST]".
    """
    return f"{B_INST}{B_SYS}{new_system_prompt}{E_SYS}{instruction}{E_INST}"
template = get_prompt(instruction, system_prompt)
# BUG FIX(review): the assembled template contains the placeholders
# {chat_history}, {question} and {context}; the original declared
# input_variables=["chat_history", "user_input"], and "user_input" does not
# occur in the template at all.  Declare the variables that actually exist.
# NOTE(review): passing the condense-question prompt as the *system* prompt
# here looks suspicious — confirm against the intended chain design.
prompt = PromptTemplate(
    input_variables=["chat_history", "question", "context"], template=template
)
class TogetherLLM(LLM):
    """LangChain LLM wrapper around the Together completions API."""

    # Model endpoint to use.
    model: str = "togethercomputer/llama-2-70b-chat"
    # Together API key.  .get() instead of [] so an unset environment
    # variable does not raise KeyError at class-definition time; the
    # root_validator below still enforces that a key is present.
    together_api_key: str = os.environ.get("TOGETHER_API_KEY", "")
    # Sampling temperature to use.
    temperature: float = 0.7
    # Maximum number of tokens to generate in the completion.
    max_tokens: int = 512

    class Config:
        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key is set (field or TOGETHER_API_KEY env)."""
        api_key = get_from_dict_or_env(
            values, "together_api_key", "TOGETHER_API_KEY"
        )
        values["together_api_key"] = api_key
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Call the Together endpoint and return the completion text.

        FIX(review): caller-supplied *stop* sequences were previously
        swallowed by **kwargs and '</s>' was always used; they are now
        honored, with '</s>' kept as the default.
        """
        together.api_key = self.together_api_key
        output = together.Complete.create(
            prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=0.7,
            top_k=50,
            repetition_penalty=1,
            stop=stop or ["</s>"],
        )
        # Together returns an OpenAI-style payload; take the first choice.
        return output['choices'][0]['text']
class Vlogger4chat:
    """Question answering over a video's pre-generated text log.

    The log for a video is expected at ``<data_dir>/<video_id>.log`` with a
    matching pickled vector store at ``<data_dir>/<video_id>.pkl``; questions
    are answered through a ConversationalRetrievalChain over that store.
    """

    def __init__(self, args):
        self.args = args
        self.alpha = args.alpha
        self.beta = args.beta
        # BUG FIX(review): the original line ended with a stray comma, which
        # made self.data_dir a 1-tuple; the tuple unwrapping in
        # exist_videolog() is kept below only as a defensive fallback.
        self.data_dir = args.data_dir
        self.tmp_dir = args.tmp_dir
        self.models_flag = False
        self.init_llm()
        self.history = []  # (question, answer) pairs fed back to the chain
        self.my_embedding = HuggingFaceEmbeddings(
            model_name='BAAI/bge-m3',
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True},
        )

    def init_llm(self):
        """Create the Together-hosted LLM used by the retrieval chain."""
        print('\033[1;33m' + "Initializing LLM Reasoner...".center(50, '-') + '\033[0m')
        self.llm = TogetherLLM(
            model="microsoft/WizardLM-2-8x22B",
            temperature=0.1,
            max_tokens=768
        )
        print('\033[1;32m' + "LLM initialization finished!".center(50, '-') + '\033[0m')

    def exist_videolog(self, video_id):
        """Load an existing log and vector store for *video_id*, if present.

        On success, sets self.vector_storage and self.chain and returns True;
        returns False when the log is missing or empty.
        """
        # Defensive unwrapping: older code accidentally stored 1-tuples here.
        if isinstance(self.data_dir, tuple):
            self.data_dir = self.data_dir[0]
        if isinstance(video_id, tuple):
            video_id = video_id[0]
        log_path = os.path.join(self.data_dir, f"{video_id}.log")
        if not os.path.exists(log_path):
            return False
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        if not raw_documents:
            # Empty or unreadable log: nothing to index.
            print("The log file is empty or could not be loaded.")
            return False
        pkl_path = os.path.join(self.data_dir, f"{video_id}.pkl")
        # SECURITY(review): pickle.load executes arbitrary code if the .pkl is
        # untrusted — only load pickles this application produced itself.
        with open(pkl_path, 'rb') as f:
            vector_storage = pickle.load(f)
        self.vector_storage = map_to_cpu(vector_storage)
        self.chain = ConversationalRetrievalChain.from_llm(
            self.llm, self.vector_storage.as_retriever(), return_source_documents=True
        )
        return True

    def create_videolog(self, video_id):
        """Build a fresh FAISS store and chain from <data_dir>/<video_id>.log."""
        # NOTE(review): the parameter is shadowed by the basename of
        # self.video_path — confirm the argument is actually needed.
        video_id = os.path.basename(self.video_path).split('.')[0]
        log_path = os.path.join(self.data_dir, video_id + '.log')
        loader = UnstructuredFileLoader(log_path)
        raw_documents = loader.load()
        # Split text into overlapping chunks before embedding.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        chunks = text_splitter.split_documents(raw_documents)
        self.vector_storage = FAISS.from_documents(chunks, self.my_embedding)
        self.chain = ConversationalRetrievalChain.from_llm(
            self.llm, self.vector_storage.as_retriever(), return_source_documents=True
        )

    def video2log(self, video_path):
        """Load the chain for *video_path* and return its log lines."""
        self.video_path = video_path
        video_id = os.path.basename(video_path).split('.')[0]
        # Called for its side effect (loading the store and chain); the
        # original returned printlog() on both branches, so the branch is
        # collapsed here without changing behavior.
        self.exist_videolog(video_id)
        return self.printlog(video_id)

    def printlog(self, video_id):
        """Return the stripped lines of <data_dir>/<video_id>.log as a list."""
        log_path = os.path.join(self.data_dir, video_id + '.log')
        with open(log_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f]

    def chat2video(self, question):
        """Answer *question* about the loaded video and record the turn."""
        print(f"Question: {question}")
        response = self.chain(
            {"question": "请用中文回答:" + question, "chat_history": self.history}
        )['answer']
        self.history.append((question, response))
        return response

    def clean_history(self):
        """Drop all recorded conversation turns."""
        self.history = []
        return