flashsloth / app.py
Tongbo's picture
Update app.py
b0b2441 verified
raw
history blame
5.08 kB
import os
import torch
from flashsloth.constants import (
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN, LEARNABLE_TOKEN, LEARNABLE_TOKEN_INDEX
)
from flashsloth.conversation import conv_templates, SeparatorStyle
from flashsloth.model.builder import load_pretrained_model
from flashsloth.utils import disable_torch_init
from flashsloth.mm_utils import (
tokenizer_image_token, process_images, process_images_hd_inference,
get_model_name_from_path, KeywordsStoppingCriteria
)
from PIL import Image
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
# One-time, import-time model setup. Everything below runs as a side effect of
# loading this module, and the resulting globals (tokenizer, model,
# image_processor) are used by generate_description().
# NOTE(review): disable_torch_init() presumably skips default weight
# initialisation to speed up loading -- confirm in flashsloth.utils.
disable_torch_init()
# Checkpoint identifier handed to the loader (looks like a Hugging Face Hub
# repo id -- confirm against load_pretrained_model).
MODEL_PATH = "Tongbo/FlashSloth_HD-3.2B"
model_name = get_model_name_from_path(MODEL_PATH)
# The second positional argument is passed as None; context_len is returned by
# the loader but is not used anywhere else in the visible part of this file.
tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None, model_name)
# Inference runs on the GPU only: the device string 'cuda' is hard-coded here
# and again inside generate_description().
model.to('cuda')
# Switch the model to evaluation mode before serving requests.
model.eval()
def generate_description(image, prompt_text, temperature, top_p, max_tokens):
    """Stream a model-generated answer about ``image`` for ``prompt_text``.

    This is a generator: each time the streamer produces a new chunk of text,
    the accumulated answer so far is yielded, so the UI can update
    progressively while generation is still running.

    Args:
        image: PIL image coming from the upload widget, or ``None`` when the
            user has not uploaded anything.
        prompt_text: The user's question or instruction about the image.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_tokens: Upper bound on newly generated tokens (cast to ``int``).

    Yields:
        The full text generated so far.

    Raises:
        gr.Error: If no image was supplied.
        Exception: Any exception raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has ended.
    """
    # Without this guard a missing upload fails with an opaque AttributeError
    # on ``None.convert``; gr.Error surfaces a readable message in the UI.
    if image is None:
        raise gr.Error("请先上传图片。")

    keywords = ['</s>']
    # Prompt layout: image placeholder, newline, user text, learnable token.
    text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text + LEARNABLE_TOKEN

    image = image.convert('RGB')
    # Pick the preprocessing routine according to the model's image_hd flag.
    if model.config.image_hd:
        image_tensor = process_images_hd_inference([image], image_processor, model.config)[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]
    # Add a leading batch dimension and move to the GPU in half precision.
    image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)

    # Build a single-turn conversation with an empty assistant slot.
    conv = conv_templates["phi2"].copy()
    conv.append_message(conv.roles[0], text)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)

    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=int(max_tokens),
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[stopping_criteria],
        streamer=streamer,
    )

    # Exceptions raised in the worker thread are collected here so they can be
    # re-raised in the caller instead of being lost with the thread.
    worker_errors = []

    def _generate():
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            worker_errors.append(exc)
            # generate() failed before finishing the stream; end it explicitly
            # so the consumer loop below does not block forever.
            streamer.end()

    # Run generation in a separate thread so this generator is not blocked.
    generation_thread = Thread(target=_generate)
    generation_thread.start()

    # Yield the accumulated output while generation is in progress.
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    generation_thread.join()
    if worker_errors:
        raise worker_errors[0]
# Custom CSS used to enlarge fonts and polish the look of the interface.
# NOTE(review): the string carries its own literal <style> tags. Further down
# it is handed both to gr.Blocks(css=...) and to gr.HTML(...); the in-string
# /* ... */ comments are part of the runtime text and are left unchanged.
custom_css = """
<style>
/* 增大标题字体 */
#title {
font-size: 80px !important;
text-align: center;
margin-bottom: 20px;
}
/* 增大描述文字字体 */
#description {
font-size: 24px !important;
text-align: center;
margin-bottom: 40px;
}
/* 增大标签和输入框的字体 */
.gradio-container * {
font-size: 18px !important;
}
/* 增大按钮字体 */
button {
font-size: 20px !important;
padding: 10px 20px;
}
/* 增大输出文本的字体 */
.output_text {
font-size: 20px !important;
}
</style>
"""
# Assemble the Gradio interface: a title, a left column holding the image
# upload and the sampling controls, and a right column holding the prompt,
# the submit button and the streamed answer.
with gr.Blocks(css=custom_css) as demo:
    # The stylesheet is supplied through the css argument above and is also
    # embedded as a raw HTML fragment (custom_css includes <style> tags).
    gr.HTML(custom_css)
    gr.HTML("<h1 style='font-size:70px; text-align:center;'>FlashSloth 多模态大模型 Demo</h1>")

    with gr.Row():
        # Left column: input image and generation hyper-parameters.
        with gr.Column(scale=1):
            image_input = gr.Image(label="上传图片", type="pil")
            temperature_slider = gr.Slider(
                label="Temperature", value=0.7, minimum=0.01, maximum=1.0, step=0.05
            )
            topp_slider = gr.Slider(
                label="Top-p", value=0.9, minimum=0.01, maximum=1.0, step=0.05
            )
            maxtoken_slider = gr.Slider(
                label="Max Tokens", value=512, minimum=64, maximum=3072, step=1
            )

        # Right column: prompt, trigger button and read-only answer box.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="问题提示",
                placeholder="Describe this photo in detail.",
                lines=3,
            )
            submit_button = gr.Button("生成答案", variant="primary")
            output_text = gr.Textbox(
                label="生成的答案",
                lines=15,
                interactive=False,
                elem_classes=["output_text"],
            )

    # The input order must match generate_description's parameter order:
    # image, prompt_text, temperature, top_p, max_tokens.
    submit_button.click(
        fn=generate_description,
        inputs=[
            image_input,
            prompt_input,
            temperature_slider,
            topp_slider,
            maxtoken_slider,
        ],
        outputs=output_text,
        show_progress=True,
    )
if __name__ == "__main__":
    # Enable the request queue first, then start the web server when this
    # file is run as a script.
    queued_demo = demo.queue()
    queued_demo.launch()