import torch
import os
import argparse
import numpy as np
import copy
import gradio as gr
import re
import torchaudio
import io
import cv2
import math
import spaces
from numba import jit
from huggingface_hub import snapshot_download
from vita.constants import DEFAULT_AUDIO_TOKEN, DEFAULT_IMAGE_TOKEN, MAX_IMAGE_LENGTH, MIN_IMAGE_LENGTH, IMAGE_TOKEN_INDEX, AUDIO_TOKEN_INDEX
from vita.conversation import conv_templates, SeparatorStyle
from vita.util.mm_utils import tokenizer_image_token, tokenizer_image_audio_token
from PIL import Image
from decord import VideoReader, cpu
from vita.model.builder import load_pretrained_model
from vita.model.vita_tts.decoder.llm2tts import llm2TTS
from vita.model.language_model.vita_qwen2 import VITAQwen2Config, VITAQwen2ForCausalLM

# TTS decoding parameters
decoder_topk = 2
codec_chunk_size = 40
codec_padding_size = 10

# Full-width/CJK punctuation, used to strip a single trailing punctuation mark from text queries.
PUNCTUATION = "!?。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

# Download the checkpoint and build the model, tokenizer, audio feature extractor and TTS decoder.
MODEL_NAME = "VITA-MLLM/VITA-1.5"
model_path = snapshot_download(MODEL_NAME, local_dir="VITA_ckpt")
tokenizer, model, feature_extractor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name="VITA-1.5", model_type="qwen2p5_instruct"
)
llm_embedding = model.get_input_embeddings().cuda()
tts = llm2TTS(os.path.join(model_path, 'vita_tts_ckpt/'))


@jit
def float_to_int16(audio: np.ndarray) -> np.ndarray:
    am = int(math.ceil(float(np.abs(audio).max())) * 32768)
    am = 32767 * 32768 // am
    return np.multiply(audio, am).astype(np.int16)


def remove_special_characters(input_str):
    # Remove special tokens
    special_tokens = ['☞', '☟', '☜', '<unk>', '<|im_end|>']
    for token in special_tokens:
        input_str = input_str.replace(token, '')
    return input_str


def replace_equation(sentence):
    special_notations = {
        "sin": " sine ",
        "cos": " cosine ",
        "tan": " tangent ",
        "cot": " cotangent ",
        "sec": " secant ",
        "csc": " cosecant ",
        "log": " logarithm ",
        "exp": "e^",
        "sqrt": "根号 ",
        "abs": "绝对值 ",
    }

    special_operators = {
        "+": "加",
        "-": "减",
        "*": "乘",
        "/": "除",
        "=": "等于",
        '!=': '不等于',
        '>': '大于',
        '<': '小于',
        '>=': '大于等于',
        '<=': '小于等于',
    }

    greek_letters = {
        "α": "alpha ",
        "β": "beta ",
        "γ": "gamma ",
        "δ": "delta ",
        "ε": "epsilon ",
        "ζ": "zeta ",
        "η": "eta ",
        "θ": "theta ",
        "ι": "iota ",
        "κ": "kappa ",
        "λ": "lambda ",
        "μ": "mu ",
        "ν": "nu ",
        "ξ": "xi ",
        "ο": "omicron ",
        "π": "派 ",
        "ρ": "rho ",
        "σ": "sigma ",
        "τ": "tau ",
        "υ": "upsilon ",
        "φ": "phi ",
        "χ": "chi ",
        "ψ": "psi ",
        "ω": "omega ",
    }

    sentence = sentence.replace('**', ' ')

    # Read a leading minus sign on a number as 负 before '-' is replaced with 减 below.
    sentence = re.sub(r'(?<!\d)-(\d+)', r'负\1', sentence)

    for key in special_notations:
        sentence = sentence.replace(key, special_notations[key])
    for key in special_operators:
        sentence = sentence.replace(key, special_operators[key])
    for key in greek_letters:
        sentence = sentence.replace(key, greek_letters[key])

    return sentence
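
# ---------------------------------------------------------------------------
# The helpers below are referenced later in this script (predict, add_video,
# stream_audio_output) but their definitions are missing from this copy, so the
# following are minimal stand-ins; the exact original implementations may differ.
# ---------------------------------------------------------------------------

def is_image(file_path):
    # Treat common image extensions as image input.
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.webp'}
    return os.path.splitext(file_path)[1].lower() in image_extensions


def is_video(file_path):
    # Treat common video extensions as video input.
    video_extensions = {'.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg'}
    return os.path.splitext(file_path)[1].lower() in video_extensions


def is_wav(file_path):
    # Recorded or uploaded audio is passed around as .wav files.
    return os.path.splitext(file_path)[1].lower() == '.wav'


def convert_webm_to_mp4(input_file, output_file):
    # Re-encode a webcam .webm recording as .mp4 so decord can decode it.
    cap = cv2.VideoCapture(input_file)
    if not cap.isOpened():
        raise IOError(f"Cannot open video file {input_file}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        writer.write(frame)
    cap.release()
    writer.release()


def split_into_sentences(text):
    # Used by stream_audio_output below: cut the LLM response on common
    # Chinese/English sentence-ending punctuation so TTS can stream per sentence.
    sentences = re.split(r'[。！？!?\.\n]+', text)
    return [s.strip() for s in sentences if s.strip()]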
def _get_rawvideo_dec(video_path, max_frames=MAX_IMAGE_LENGTH, min_frames=MIN_IMAGE_LENGTH,
                      video_framerate=1, s=None, e=None):
    # Decode a video with decord and uniformly sample between min_frames and max_frames frames.
    if s is None or e is None:
        start_time, end_time = None, None
    else:
        start_time, end_time = int(s), int(e)
        start_time = max(start_time, 0)
        end_time = max(end_time, 0)
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        raise FileNotFoundError

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))

        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
        elif len(all_pos) < min_frames:
            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=min_frames, dtype=int)]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f).convert("RGB") for f in vreader.get_batch(sample_pos).asnumpy()]
        return patch_images, len(patch_images)
    else:
        print(f"video path: {video_path} error.")


def _parse_text(text):
    # Escape model output for display in the gr.Chatbot HTML component.
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0 and count % 2 == 1:
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    return "".join(lines)


@spaces.GPU
def predict(_chatbot, task_history):
    chat_query = task_history[-1][0]
    print(task_history)

    conv_mode = "qwen2p5_instruct"
    conv = conv_templates[conv_mode].copy()

    all_audio_path = []
    all_visual_tensor = []

    qs = ''
    input_mode = 'lang'
    for i, (q, a) in enumerate(task_history):
        if isinstance(q, (tuple, list)):
            if is_image(q[0]):
                images = [Image.open(q[0]).convert("RGB")]
                all_visual_tensor.extend(images)
                input_mode = 'image'
                qs += DEFAULT_IMAGE_TOKEN * len(images) + '\n'
            elif is_video(q[0]):
                video_frames, slice_len = _get_rawvideo_dec(q[0])
                all_visual_tensor.extend(video_frames)
                input_mode = 'video'
                qs += DEFAULT_IMAGE_TOKEN * slice_len + '\n'
            elif is_wav(q[0]):
                if a is not None and a.startswith('☜'):
                    continue
                else:
                    all_audio_path.append(q[0])
                    new_q = qs + DEFAULT_AUDIO_TOKEN
                    qs = ''
                    conv.append_message(conv.roles[0], new_q)
                    conv.append_message(conv.roles[1], a)
        else:
            new_q = qs + q
            qs = ''
            conv.append_message(conv.roles[0], new_q)
            conv.append_message(conv.roles[1], a)

    prompt = conv.get_prompt(input_mode)

    if all_audio_path != []:
        input_ids = tokenizer_image_audio_token(
            prompt, tokenizer,
            image_token_index=IMAGE_TOKEN_INDEX,
            audio_token_index=AUDIO_TOKEN_INDEX
        )
        audio_list = []
        for single_audio_path in all_audio_path:
            try:
                audio, original_sr = torchaudio.load(single_audio_path)
                target_sr = 16000
                if original_sr != target_sr:
                    resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=target_sr)
                    audio = resampler(audio)
                audio_features = feature_extractor(audio, sampling_rate=target_sr, return_tensors="pt")["input_features"]
                audio_list.append(audio_features.squeeze(0))
            except Exception as e:
                print(f"Error processing {single_audio_path}: {e}")
    else:
        input_ids = tokenizer_image_token(
            prompt, tokenizer,
            image_token_index=IMAGE_TOKEN_INDEX
        )

    # Build a request-style payload for logging; the generate() call below
    # consumes input_ids/images/audios directly.
    if all_visual_tensor == [] and all_audio_path == []:
        data_prompt = {
            "prompt_token_ids": input_ids,
        }
    elif all_visual_tensor != [] and all_audio_path == []:
        data_prompt = {
            "prompt_token_ids": input_ids,
            "multi_modal_data": {
                "image": all_visual_tensor
            },
        }
    elif all_visual_tensor == [] and all_audio_path != []:
        data_prompt = {
            "prompt_token_ids": input_ids,
            "multi_modal_data": {
                "audio": audio_list
            },
        }
    else:
        data_prompt = {
            "prompt_token_ids": input_ids,
            "multi_modal_data": {
                "image": all_visual_tensor,
                "audio": audio_list
            },
        }
    print(data_prompt)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=all_visual_tensor if all_visual_tensor else None,
            audios=audio_list if audio_list else None,
            do_sample=False,
            temperature=0.01,
            top_p=None,
            num_beams=1,
            output_scores=True,
            return_dict_in_generate=True,
            max_new_tokens=1024,
            use_cache=True,
        )
    # generate() returns a GenerateOutput object when return_dict_in_generate=True;
    # decode the token ids stored in its .sequences field.
    output_ids = output_ids.sequences
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    outputs = outputs.strip()

    task_history[-1] = (chat_query, outputs)
    remove_special_characters_output = remove_special_characters(outputs)
    _chatbot[-1] = (chat_query, _parse_text(remove_special_characters_output))
    print("query", chat_query)
    print("task_history", task_history)
    print(_chatbot)
    print("answer: ", outputs)
    yield _chatbot


def add_text(history, task_history, text):
    # Drop a single trailing punctuation mark from the query used for inference.
    task_text = text
    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
        task_text = text[:-1]
    history = history + [(_parse_text(text), None)]
    task_history = task_history + [(task_text, None)]
    return history, task_history, ""


def add_file(history, task_history, file):
    history = history + [((file.name,), None)]
    task_history = task_history + [((file.name,), None)]
    return history, task_history
def add_audio(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    history = history + [((file,), None)]
    task_history = task_history + [((file,), None)]
    return history, task_history


def add_video(history, task_history, file):
    print(file)
    if file is None:
        return history, task_history
    new_file_name = file.replace(".webm", ".mp4")
    if file.endswith(".webm"):
        convert_webm_to_mp4(file, new_file_name)
    task_history = task_history + [((new_file_name,), None)]
    return history, task_history


def reset_user_input():
    return gr.update(value="")


def reset_state(task_history):
    task_history.clear()
    return []


@spaces.GPU
def stream_audio_output(history, task_history):
    text = task_history[-1][-1]
    if not text:
        # import pdb;pdb.set_trace()
        yield None, None
        return
    llm_response = replace_equation(remove_special_characters(text))
    # print('tts_text', llm_response)
    for idx, text in enumerate(split_into_sentences(llm_response)):
        embeddings = llm_embedding(torch.tensor(tokenizer.encode(text)).cuda())
        for seg in tts.run(embeddings.reshape(-1, 896).unsqueeze(0), decoder_topk,
                           None,
                           codec_chunk_size,
                           codec_padding_size):
            # Trim leading silence from the first segment.
            if idx == 0:
                try:
                    split_idx = torch.nonzero(seg.abs() > 0.03, as_tuple=True)[-1][0]
                    seg = seg[:, :, split_idx:]
                except:
                    print('Do not need to split')
                    pass
            if seg is not None and len(seg) > 0:
                seg = seg.to(torch.float32).cpu().numpy()
                yield 24000, float_to_int16(seg).T


with gr.Blocks(title="VideoMLLM") as demo:
    gr.Markdown("""
VITA
""")
    chatbot = gr.Chatbot(label='VITA', elem_classes="control-height", height=500)
    query = gr.Textbox(lines=2, label='Text Input')
    task_history = gr.State([])

    with gr.Row():
        add_text_button = gr.Button("Submit Text (提交文本)")
        add_audio_button = gr.Button("Submit Audio (提交音频)")

    with gr.Row():
        with gr.Column(scale=2):
            addfile_btn = gr.UploadButton("📁 Upload (上传文件[视频,图片])", file_types=["video", "image"])
            video_input = gr.Video(sources=["webcam"], height=400, width=700, container=True,
                                   interactive=True, show_download_button=True,
                                   label="📹 Video Recording (视频录制)")
        with gr.Column(scale=1):
            empty_bin = gr.Button("🧹 Clear History (清除历史)")
            record_btn = gr.Audio(sources=["microphone", "upload"], type="filepath",
                                  label="🎤 Record or Upload Audio (录音或上传音频)",
                                  show_download_button=True,
                                  waveform_options=gr.WaveformOptions(sample_rate=16000))
            audio_output = gr.Audio(
                label="Output Audio",
                value=None,
                format="wav",
                autoplay=True,
                streaming=True,
                interactive=False,
                show_label=True,
                waveform_options=gr.WaveformOptions(
                    sample_rate=24000,
                ),
            )

    add_text_button.click(add_text, [chatbot, task_history, query], [chatbot, task_history], show_progress=True).then(
        reset_user_input, [], [query]
    ).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

    video_input.stop_recording(add_video, [chatbot, task_history, video_input], [chatbot, task_history], show_progress=True)
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

    add_audio_button.click(add_audio, [chatbot, task_history, record_btn], [chatbot, task_history], show_progress=True).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    ).then(
        stream_audio_output, [chatbot, task_history], [audio_output],
    )

demo.launch()