diff --git a/app.py b/app.py
index 5cc265e4965b21fd51f9accf7aafbebd89067761..f0e7b442ce0f4bfa76d4e55dc085c1820e97910a 100644
--- a/app.py
+++ b/app.py
@@ -1,14 +1,485 @@
+import torch
+import os
+import argparse
+import numpy as np
+import copy
 import gradio as gr
+import re
+import torchaudio
+import io
+import cv2
+import math
 import spaces
-import torch
+from numba import jit
+from huggingface_hub import snapshot_download
+
+from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, IMAGE_TOKEN_INDEX, AUDIO_TOKEN_INDEX
+from vita.conversation import conv_templates, SeparatorStyle
+from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
+from PIL import Image
+from decord import VideoReader, cpu
+from vita.model.builder import load_pretrained_model
+from vita.model.vita_tts.decoder.llm2tts import llm2TTS
+from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM
+
+decoder_topk = 2
+codec_chunk_size = 40
+codec_padding_size = 10
+
+PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
+
+MODEL_NAME = "VITA-MLLM/VITA-1.5"
+model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
+tokenizer, model, feature_extractor, context_len = load_pretrained_model(
+    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
+)
+llm_embedding = model.get_input_embeddings().cuda()
+tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))
+
+@jit
+def float_to_int16(audio: np.ndarray) -> np.ndarray:
+    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
+    am = 32767 * 32768 // am
+    return np.multiply(audio, am).astype(np.int16)
+
+
+def remove_special_characters(input_str):
+    # Remove special tokens
+    special_tokens = ['☞', '☟', '☜', '', '<|im_end|>']
+    for token in special_tokens:
+        input_str = input_str.replace(token, '')
+    return input_str
+
+
+def replace_equation(sentence):
+    special_notations = {
+        "sin": " sine ",
+        "cos": " cosine ",
+        "tan": " tangent ",
+        "cot": " cotangent ",
+        "sec": " secant ",
+        "csc": " cosecant ",
+        "log": " logarithm ",
+        "exp": "e^",
+        "sqrt": "根号 ",
+        "abs": "绝对值 ",
+    }
+
+    special_operators = {
+        "+": "加",
+        "-": "减",
+        "*": "乘",
+        "/": "除",
+        "=": "等于",
+        '!=': '不等于',
+        '>': '大于',
+        '<': '小于',
+        '>=': '大于等于',
+        '<=': '小于等于',
+    }
+
+    greek_letters = {
+        "α": "alpha ",
+        "β": "beta ",
+        "γ": "gamma ",
+        "δ": "delta ",
+        "ε": "epsilon ",
+        "ζ": "zeta ",
+        "η": "eta ",
+        "θ": "theta ",
+        "ι": "iota ",
+        "κ": "kappa ",
+        "λ": "lambda ",
+        "μ": "mu ",
+        "ν": "nu ",
+        "ξ": "xi ",
+        "ο": "omicron ",
+        "π": "派 ",
+        "ρ": "rho ",
+        "σ": "sigma ",
+        "τ": "tau ",
+        "υ": "upsilon ",
+        "φ": "phi ",
+        "χ": "chi ",
+        "ψ": "psi ",
+        "ω": "omega "
+    }
+
+    sentence = sentence.replace('**', ' ')
+
+    sentence = re.sub(r'(?
+        if start_time > end_time:
+            start_time, end_time = end_time, start_time
+        elif start_time == end_time:
+            end_time = start_time + 1
+
+    if os.path.exists(video_path):
+        vreader = VideoReader(video_path, ctx=cpu(0))
+    else:
+        raise FileNotFoundError
+
+    fps = vreader.get_avg_fps()
+    f_start = 0 if start_time is None else int(start_time * fps)
+    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
+    num_frames = f_end - f_start + 1
+
+    if num_frames > 0:
+        sample_fps = int(video_framerate)
+        t_stride = int(round(float(fps) / sample_fps))
+        all_pos = list(range(f_start, f_end + 1, t_stride))
+
+        if len(all_pos) > max_frames:
+            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
+        elif len(all_pos) < min_frames:
+            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
+        else:
+            sample_pos = all_pos
+
+        patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
+        return patch_images, len(patch_images)
+    else:
+        print(f"video path: {video_path} error.")
 
-zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- 'cpu' 🤔
 
+def _parse_text(text):
+    # Render model output as HTML for the Gradio chatbot: wrap fenced code blocks and escape markup.
+    lines = text.split("\n")
+    lines = [line for line in lines if line != ""]
+    count = 0
+
+    for i, line in enumerate(lines):
+        if "```" in line:
+            count += 1
+            items = line.split("`")
+            if count % 2 == 1:
+                lines[i] = f'<pre><code class="language-{items[-1]}">'
+            else:
+                lines[i] = "<br></code></pre>"
+        else:
+            if i > 0 and count % 2 == 1:
+                line = line.replace("`", r"\`")
+                line = line.replace("<", "&lt;")
+                line = line.replace(">", "&gt;")
+                line = line.replace(" ", "&nbsp;")
+                line = line.replace("*", "&ast;")
+                line = line.replace("_", "&lowbar;")
+                line = line.replace("-", "&#45;")
+                line = line.replace(".", "&#46;")
+                line = line.replace("!", "&#33;")
+                line = line.replace("(", "&#40;")
+                line = line.replace(")", "&#41;")
+                line = line.replace("$", "&#36;")
+            lines[i] = "<br>" + line
+
+    return "".join(lines)
+
+
+@spaces.GPU
+def predict(_chatbot, task_history):
+    # Assemble the multimodal prompt (text, images, video frames, audio) from the chat history and run VITA generation.
+    chat_query = task_history[-1][0]
+    print(task_history)
+
+    conv_mode = "qwen2p5_instruct"
+    conv = conv_templates[conv_mode].copy()
+
+    all_audio_path = []
+    all_visual_tensor = []
+
+    qs = ''
+    input_mode = 'lang'
+    for i, (q, a) in enumerate(task_history):
+        if isinstance(q, (tuple, list)):
+            if is_image(q[0]):
+                images = [Image.open(q[0]).convert("RGB")]
+                all_visual_tensor.extend(images)
+                input_mode = 'image'
+                qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
+            elif is_video(q[0]):
+                video_frames, slice_len = _get_rawvideo_dec(q[0])
+                all_visual_tensor.extend(video_frames)
+                input_mode = 'video'
+                qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
+            elif is_wav(q[0]):
+                if a is not None and a.startswith('☜'):
+                    continue
+                else:
+                    all_audio_path.append(q[0])
+                    new_q = qs + DEFAULT_AUDIO_TOKEN
+                    qs = ''
+                    conv.append_message(conv.roles[0], new_q)
+                    conv.append_message(conv.roles[1], a)
+        else:
+            new_q = qs + q
+            qs = ''
+            conv.append_message(conv.roles[0], new_q)
+            conv.append_message(conv.roles[1], a)
+
+    prompt = conv.get_prompt(input_mode)
+
+    audio_list = []  # stays empty when the turn carries no audio inputs
+    if all_audio_path != []:
+        input_ids = tokenizer_image_audio_token(
+            prompt, tokenizer,
+            image_token_index=IMAGE_TOKEN_INDEX,
+            audio_token_index=AUDIO_TOKEN_INDEX
+        )
+        for single_audio_path in all_audio_path:
+            try:
+                audio, original_sr = torchaudio.load(single_audio_path)
+                target_sr = 16000
+                if original_sr != target_sr:
+                    resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
+                    audio = resampler(audio)
+                audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
+                audio_list.append(audio_features.squeeze(0))
+            except Exception as e:
+                print(f"Error processing {single_audio_path}: {e}")
+    else:
+        input_ids = tokenizer_image_token(
+            prompt, tokenizer,
+            image_token_index=IMAGE_TOKEN_INDEX
+        )
+
+    if all_visual_tensor == [] and all_audio_path == []:
+        datapromt = {
+            "prompt_token_ids": input_ids,
+        }
+    elif all_visual_tensor != [] and all_audio_path == []:
+        datapromt = {
+            "prompt_token_ids": input_ids,
+            "multi_modal_data": {
+                "image": all_visual_tensor
+            },
+        }
+    elif all_visual_tensor == [] and all_audio_path != []:
+        datapromt = {
+            "prompt_token_ids": input_ids,
+            "multi_modal_data": {
+                "audio": audio_list
+            },
+        }
+    else:
+        datapromt = {
+            "prompt_token_ids": input_ids,
+            "multi_modal_data": {
+                "image": all_visual_tensor,
+                "audio": audio_list
+            },
+        }
+
+    print(datapromt)
+
+    with torch.inference_mode():
+        output_ids = model.generate(
+            input_ids,
+            images=all_visual_tensor,
+            audios=audio_list,
+            do_sample=False,
+            temperature=0.01,
+            top_p=None,
+            num_beams=1,
+            output_scores=True,
+            return_dict_in_generate=True,
+            max_new_tokens=1024,
+            use_cache=True,
+        )
+
+    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
+    outputs = outputs.strip()
+
+    task_history[-1] = (chat_query, outputs)
+    remove_special_characters_output = remove_special_characters(outputs)
+    _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
+    print("query", chat_query)
+    print("task_history", task_history)
+    print(_chatbot)
+    print("answer: ", outputs)
+    yield _chatbot
+
+
+def add_text(history, task_history, text):
+    task_text = text
+    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
+        task_text = text[:-1]
+    history = history + [(_parse_text(text), None)]
+    task_history = task_history + [(task_text, None)]
+    return history, task_history, ""
+
+def add_file(history, task_history, file):
+    history = history + [((file.name,), None)]
+    task_history = task_history + [((file.name,), None)]
+    return history, task_history
+
+def add_audio(history, task_history, file):
+    print(file)
+    if file is None:
+        return history, task_history
+    history = history + [((file,), None)]
+    task_history = task_history + [((file,), None)]
+    return history, task_history
+
+def add_video(history, task_history, file):
+    print(file)
+    if file is None:
+        return history, task_history
+    new_file_name = file.replace(".webm", ".mp4")
+    if file.endswith(".webm"):
+        convert_webm_to_mp4(file, new_file_name)
+    task_history = task_history + [((new_file_name,), None)]
+    return history, task_history
+
+
+def reset_user_input():
+    return gr.update(value="")
+
+def reset_state(task_history):
+    task_history.clear()
+    return []
 
 @spaces.GPU
-def greet(n):
-    print(zero.device) # <-- 'cuda:0' 🤗
-    return f"Hello {zero + n} Tensor"
+def stream_audio_output(history, task_history):
+    text = task_history[-1][-1]
+    if not text:
+        # import pdb;pdb.set_trace()
+        yield None, None
+    llm_resounse = replace_equation(remove_special_characters(text))
+    # print('tts_text', llm_resounse)
+    for idx, text in enumerate(split_into_sentences(llm_resounse)):
+        embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
+        for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
+                           None,
+                           codec_chunk_size, codec_padding_size):
+            if idx == 0:
+                try:
+                    split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
+                    seg = seg[:, :, split_idx:]
+                except:
+                    print('Do not need to split')
+                    pass
+
+            if seg is not None and len(seg) > 0:
+                seg = seg.to(torch.float32).cpu().numpy()
+                yield 24000, float_to_int16(seg).T
+
+
+with gr.Blocks(title="VideoMLLM") as demo:
+    gr.Markdown("""
+VITA
""") + chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500) + query = gr.Textbox(lines=2, label='Text Input') + task_history = gr.State([]) + with gr.Row(): + add_text_button = gr.Button("Submit Text (提交文本)") + add_audio_button = gr.Button("Submit Audio (提交音频)") + with gr.Row(): + with gr.Column(scale=2): + addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"]) + video_input = gr.Video(sources=[ "webcam"], height=400, width=700, container=True, interactive=True, show_download_button=True, label="📹 Video Recording (视频录制)") + + with gr.Column(scale=1): + empty_bin = gr.Button("🧹 Clear History (清除历史)") + record_btn = gr.Audio(sources=[ "microphone","upload"], type="filepath", label="🎤 Record or Upload Audio (录音或上传音频)", show_download_button=True, waveform_options=gr.WaveformOptions(sample_rate=16000)) + audio_output = gr.Audio( + label="Output Audio", + value=None, + format= "wav", + autoplay=True, + streaming=True, + interactive=False, + show_label=True, + waveform_options=gr.WaveformOptions( + sample_rate=24000, + ), + ) + + + add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then( + reset_user_input, [], [query] + ).then( + predict, [chatbot, task_history], [chatbot], show_progress=True + ).then( + stream_audio_output,[chatbot, task_history], [audio_output], + ) + + + video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True) + empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True) + addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True) + + + + add_audio_button.click(add_audio, [chatbot, task_history,record_btn], [chatbot, task_history], show_progress=True).then( + predict, [chatbot, task_history], [chatbot], show_progress=True + ).then( + stream_audio_output,[chatbot, task_history], [audio_output], + ) + -demo = gr.Interface(fn=greet, inputs=gr.Number(), outputs=gr.Text()) -demo.launch() +demo.launch(server_port=18806) diff --git a/vita/config/__init__.py b/vita/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57a4e3cd2742b36618ee8838e7573fdfa06b3939 --- /dev/null +++ b/vita/config/__init__.py @@ -0,0 +1,10 @@ +from .dataset_config import * + +NaturalCap0 = [ShareGPT4V0] +NaturalCap = [ShareGPT4V] + +DataConfig = { + "Pretrain_video": NaturalCap0, +} + +NoPatchSets = ["khair", "jester"] diff --git a/vita/config/dataset_config.py b/vita/config/dataset_config.py new file mode 100644 index 0000000000000000000000000000000000000000..eac21c8e3422ba71a6184dd0f45743b6be2c9153 --- /dev/null +++ b/vita/config/dataset_config.py @@ -0,0 +1,8 @@ +AudioFolder = "" +FolderDict = { + #### NaturalCap + "sharegpt4": "", +} +#### NaturalCap +ShareGPT4V = {"chat_path": ""} +ShareGPT4V0 = {"chat_path": ""} diff --git a/vita/constants.py b/vita/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8302610d120eec46970b3d174fedc395ea2f376b --- /dev/null +++ b/vita/constants.py @@ -0,0 +1,14 @@ +# Model Constants +MAX_IMAGE_LENGTH = 16 # 8#16#32#64 +MIN_IMAGE_LENGTH = 4 +IGNORE_INDEX = -100 +IMAGE_TOKEN_INDEX = -200 +AUDIO_TOKEN_INDEX = -500 +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_VIDEO_TOKEN = "