Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from flashsloth.constants import ( | |
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, | |
DEFAULT_IM_END_TOKEN, LEARNABLE_TOKEN, LEARNABLE_TOKEN_INDEX | |
) | |
from flashsloth.conversation import conv_templates, SeparatorStyle | |
from flashsloth.model.builder import load_pretrained_model | |
from flashsloth.utils import disable_torch_init | |
from flashsloth.mm_utils import ( | |
tokenizer_image_token, process_images, process_images_hd_inference, | |
get_model_name_from_path, KeywordsStoppingCriteria | |
) | |
from PIL import Image | |
import gradio as gr | |
from transformers import TextIteratorStreamer | |
from threading import Thread | |
disable_torch_init() | |
MODEL_PATH = "Tongbo/FlashSloth_HD-3.2B" | |
model_name = get_model_name_from_path(MODEL_PATH) | |
tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None, model_name) | |
model.to('cuda') | |
model.eval() | |
def generate_description(image, prompt_text, temperature, top_p, max_tokens): | |
keywords = ['</s>'] | |
text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text | |
text = text + LEARNABLE_TOKEN | |
image = image.convert('RGB') | |
if model.config.image_hd: | |
image_tensor = process_images_hd_inference([image], image_processor, model.config)[0] | |
else: | |
image_tensor = process_images([image], image_processor, model.config)[0] | |
image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True) | |
conv = conv_templates["phi2"].copy() | |
conv.append_message(conv.roles[0], text) | |
conv.append_message(conv.roles[1], None) | |
prompt = conv.get_prompt() | |
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') | |
input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True) | |
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) | |
streamer = TextIteratorStreamer( | |
tokenizer=tokenizer, | |
skip_prompt=True, | |
skip_special_tokens=True | |
) | |
generation_kwargs = dict( | |
inputs=input_ids, | |
images=image_tensor, | |
do_sample=True, | |
temperature=temperature, | |
top_p=top_p, | |
max_new_tokens=int(max_tokens), | |
use_cache=True, | |
eos_token_id=tokenizer.eos_token_id, | |
stopping_criteria=[stopping_criteria], | |
streamer=streamer | |
) | |
def _generate(): | |
with torch.inference_mode(): | |
model.generate(**generation_kwargs) | |
# 在单独线程中运行生成,防止阻塞 | |
generation_thread = Thread(target=_generate) | |
generation_thread.start() | |
# 边生成边yield输出 | |
partial_text = "" | |
for new_text in streamer: | |
partial_text += new_text | |
yield partial_text | |
generation_thread.join() | |
# 自定义CSS样式,用于增大字体和美化界面 | |
custom_css = """ | |
<style> | |
/* 增大标题字体 */ | |
#title { | |
font-size: 80px !important; | |
text-align: center; | |
margin-bottom: 20px; | |
} | |
/* 增大描述文字字体 */ | |
#description { | |
font-size: 24px !important; | |
text-align: center; | |
margin-bottom: 40px; | |
} | |
/* 增大标签和输入框的字体 */ | |
.gradio-container * { | |
font-size: 18px !important; | |
} | |
/* 增大按钮字体 */ | |
button { | |
font-size: 20px !important; | |
padding: 10px 20px; | |
} | |
/* 增大输出文本的字体 */ | |
.output_text { | |
font-size: 20px !important; | |
} | |
</style> | |
""" | |
with gr.Blocks(css=custom_css) as demo: | |
gr.HTML(custom_css) | |
gr.HTML("<h1 style='font-size:70px; text-align:center;'>FlashSloth 多模态大模型 Demo</h1>") | |
with gr.Row(): | |
with gr.Column(scale=1): | |
image_input = gr.Image(type="pil", label="上传图片") | |
temperature_slider = gr.Slider( | |
minimum=0.01, | |
maximum=1.0, | |
step=0.05, | |
value=0.7, | |
label="Temperature" | |
) | |
topp_slider = gr.Slider( | |
minimum=0.01, | |
maximum=1.0, | |
step=0.05, | |
value=0.9, | |
label="Top-p" | |
) | |
maxtoken_slider = gr.Slider( | |
minimum=64, | |
maximum=3072, | |
step=1, | |
value=512, | |
label="Max Tokens" | |
) | |
with gr.Column(scale=1): | |
prompt_input = gr.Textbox( | |
lines=3, | |
placeholder="Describe this photo in detail.", | |
label="问题提示" | |
) | |
submit_button = gr.Button("生成答案", variant="primary") | |
output_text = gr.Textbox( | |
label="生成的答案", | |
interactive=False, | |
lines=15, | |
elem_classes=["output_text"] | |
) | |
submit_button.click( | |
fn=generate_description, | |
inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider], | |
outputs=output_text, | |
show_progress=True | |
) | |
if __name__ == "__main__": | |
demo.queue().launch() | |