Spaces:
Sleeping
Sleeping
File size: 5,082 Bytes
b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 39a3261 b0b2441 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import os
import torch
from flashsloth.constants import (
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN, LEARNABLE_TOKEN, LEARNABLE_TOKEN_INDEX
)
from flashsloth.conversation import conv_templates, SeparatorStyle
from flashsloth.model.builder import load_pretrained_model
from flashsloth.utils import disable_torch_init
from flashsloth.mm_utils import (
tokenizer_image_token, process_images, process_images_hd_inference,
get_model_name_from_path, KeywordsStoppingCriteria
)
from PIL import Image
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
# One-time model setup at import time; every request below reuses these globals.
# disable_torch_init comes from flashsloth.utils and is called before the
# checkpoint is built (presumably to skip default weight initialisation —
# confirm in flashsloth.utils).
disable_torch_init()

# Hub ID (or local directory) of the checkpoint to serve. Set the
# FLASHSLOTH_MODEL_PATH environment variable to serve a different or locally
# downloaded checkpoint without editing the file; an unset or empty variable
# keeps the original default. Note: model_name is derived from this path by
# get_model_name_from_path, so a local directory should keep a similar name.
MODEL_PATH = os.environ.get("FLASHSLOTH_MODEL_PATH") or "Tongbo/FlashSloth_HD-3.2B"
model_name = get_model_name_from_path(MODEL_PATH)
# The second argument is passed as None (no separate base model is given here).
tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_PATH, None, model_name)

# Move the weights to the GPU and switch to evaluation mode for inference.
model.to('cuda')
model.eval()
def generate_description(image, prompt_text, temperature, top_p, max_tokens):
    """Stream the model's answer to ``prompt_text`` about ``image``.

    This is a generator: every yielded value is the *whole* answer decoded so
    far, so the Gradio output textbox can be re-rendered as tokens arrive.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        prompt_text: The user's question. Blank input falls back to the
            prompt shown as the textbox placeholder.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.
        max_tokens: Upper bound on newly generated tokens; cast to ``int``
            because slider values may arrive as floats.

    Yields:
        str: The accumulated generated text.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
        Exception: Whatever ``model.generate`` raised in the worker thread,
            re-raised here once the stream has been closed.
    """
    if image is None:
        raise gr.Error("请先上传一张图片 (Please upload an image first).")
    if not prompt_text or not prompt_text.strip():
        # Same text as the prompt textbox placeholder.
        prompt_text = "Describe this photo in detail."

    keywords = ['</s>']
    # Image placeholder token, newline, the question, then the learnable token.
    text = DEFAULT_IMAGE_TOKEN + '\n' + prompt_text + LEARNABLE_TOKEN

    image = image.convert('RGB')
    if model.config.image_hd:
        image_tensor = process_images_hd_inference([image], image_processor, model.config)[0]
    else:
        image_tensor = process_images([image], image_processor, model.config)[0]
    # Add a leading batch dimension and move to the GPU in float16.
    image_tensor = image_tensor.unsqueeze(0).to(dtype=torch.float16, device='cuda', non_blocking=True)

    conv = conv_templates["phi2"].copy()
    conv.append_message(conv.roles[0], text)
    # Empty assistant turn: the model generates its content.
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
    input_ids = input_ids.unsqueeze(0).to(device='cuda', non_blocking=True)
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=int(max_tokens),
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[stopping_criteria],
        streamer=streamer,
    )

    errors = []

    def _generate():
        # inference_mode is entered inside the worker because it is
        # thread-local. model.generate() only ends the stream itself when it
        # finishes normally, so on failure close it here; otherwise the
        # consumer loop below would wait on the streamer forever.
        try:
            with torch.inference_mode():
                model.generate(**generation_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    # Run generation in a separate thread so this generator can yield
    # partial text while tokens are still being produced.
    generation_thread = Thread(target=_generate)
    generation_thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    generation_thread.join()

    if errors:
        raise errors[0]
# Custom stylesheet that enlarges fonts across the demo UI.
# It stays wrapped in <style>...</style> so it can be injected as raw HTML;
# strip the wrapper before handing it to anything that expects plain CSS.
# The universal `.gradio-container *` rule is !important, so every other
# !important font-size rule here must have a MORE specific selector to win:
# a bare `button` (0,0,1) or `.output_text` on the wrapper would lose to it.
custom_css = """
<style>
/* 增大标题字体 */
#title {
    font-size: 80px !important;
    text-align: center;
    margin-bottom: 20px;
}
/* 增大描述文字字体 */
#description {
    font-size: 24px !important;
    text-align: center;
    margin-bottom: 40px;
}
/* 增大标签和输入框的字体 */
.gradio-container * {
    font-size: 18px !important;
}
/* 增大按钮字体 */
.gradio-container button {
    font-size: 20px !important;
    padding: 10px 20px;
}
/* 增大输出文本的字体 */
.output_text textarea {
    font-size: 20px !important;
}
</style>
"""
# Gradio UI: image + sampling controls on the left, prompt / answer on the right.
# `css=` expects plain stylesheet text: the <style> wrapper would be parsed as
# part of the first selector and silently drop that rule, so strip it here and
# supply the stylesheet only once (no second copy injected via gr.HTML).
with gr.Blocks(css=custom_css.replace("<style>", "").replace("</style>", "")) as demo:
    # font-size is !important so the stylesheet's universal !important
    # font-size rule does not shrink the title back to body size.
    gr.HTML(
        "<h1 style='font-size:70px !important; text-align:center;'>"
        "FlashSloth 多模态大模型 Demo</h1>"
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="上传图片")
            # Range inputs snap to minimum + k * step. With step=0.05 from
            # 0.01 the defaults (0.7 / 0.9) and the maximum 1.0 were off-grid;
            # step=0.01 keeps all of them reachable.
            temperature_slider = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.7,
                label="Temperature",
            )
            topp_slider = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.9,
                label="Top-p",
            )
            maxtoken_slider = gr.Slider(
                minimum=64,
                maximum=3072,
                step=1,
                value=512,
                label="Max Tokens",
            )
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=3,
                placeholder="Describe this photo in detail.",
                label="问题提示",
            )
            submit_button = gr.Button("生成答案", variant="primary")
            output_text = gr.Textbox(
                label="生成的答案",
                interactive=False,
                lines=15,
                elem_classes=["output_text"],
            )
    # generate_description is a generator, so the answer box updates as text streams in.
    submit_button.click(
        fn=generate_description,
        inputs=[image_input, prompt_input, temperature_slider, topp_slider, maxtoken_slider],
        outputs=output_text,
        show_progress=True,
    )
if __name__ == "__main__":
    # Script entry point: enable Gradio's request queue on the Blocks app,
    # then start the web server.
    queued_app = demo.queue()
    queued_app.launch()
|