import os
import torch

from flashsloth.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    LEARNABLE_TOKEN,
    LEARNABLE_TOKEN_INDEX
)
from flashsloth.conversation import conv_templates, SeparatorStyle
from flashsloth.model.builder import load_pretrained_model
from flashsloth.utils import disable_torch_init
from flashsloth.mm_utils import (
    tokenizer_image_token,
    process_images,
    process_images_hd_inference,
    get_model_name_from_path,
    KeywordsStoppingCriteria
)

from PIL import Image
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread

disable_torch_init()

MODEL_PATH = "Tongbo/FlashSloth_HD-3.2B"

model_name = get_model_name_from_path(MODEL_PATH)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    MODEL_PATH, None, model_name
)
model.to('cuda')
model.eval()


def generate_description(image, prompt_text, temperature, top_p, max_tokens):
    keywords = ['']
    # Prepend the image placeholder and append the learnable prompt token.
    text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text
    text = text + LEARNABLE_TOKEN

    image = image.convert('RGB')
    # Use the HD preprocessing path when the checkpoint expects high-resolution input.
    if model.config.image_hd:
        image_tensor = process_images_hd_inference([image], image_processor, model.config)[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]
    image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)

    # Build the phi2-style conversation prompt.
    conv = conv_templates["phi2"].copy()
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)

    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=int(max_tokens),
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[stopping_criteria],
        streamer=streamer
    )

    def _generate():
        with torch.inference_mode():
            model.generate(**generation_kwargs)

    # Run generation in a separate thread so streaming does not block the UI.
    generation_thread = Thread(target=_generate)
    generation_thread.start()

    # Yield the accumulated text as new tokens arrive from the streamer.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    generation_thread.join()


# Custom CSS to enlarge fonts and polish the interface.
custom_css = """
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML("<h1 style='text-align: center;'>FlashSloth Multimodal LLM Demo</h1>")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            temperature_slider = gr.Slider(
                minimum=0.01, maximum=1.0, step=0.05, value=0.7,
                label="Temperature"
            )
            topp_slider = gr.Slider(
                minimum=0.01, maximum=1.0, step=0.05, value=0.9,
                label="Top-p"
            )
            maxtoken_slider = gr.Slider(
                minimum=64, maximum=3072, step=1, value=512,
                label="Max Tokens"
            )
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=3,
                placeholder="Describe this photo in detail.",
                label="Prompt"
            )
            submit_button = gr.Button("Generate Answer", variant="primary")
            output_text = gr.Textbox(
                label="Generated Answer",
                interactive=False,
                lines=15,
                elem_classes=["output_text"]
            )

    submit_button.click(
        fn=generate_description,
        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
        outputs=output_text,
        show_progress=True
    )

if __name__ == "__main__":
    demo.queue().launch()