# Copied from a Hugging Face Spaces file view; the Space status was "Runtime error".
# File size: 3,713 bytes · commit 03f3fb4
"""
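This script creates a Gradio demo with a Transformers backend for the glm-4v-9b model, allowing users to interact with the model through a Gradio web UI.

Usage:
- Run the script to start the Gradio server (it listens on http://127.0.0.1:8911).
- Interact with the model via the web UI.
- Set the MODEL_PATH environment variable to load a checkpoint other than the default THUDM/glm-4v-9b.

Requirements:
- Gradio package
  - Type `pip install gradio` to install Gradio.
- The other packages imported below: torch, transformers, Pillow (PIL) and requests.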
"""
import os
import torch
import gradio as gr
from threading import Thread
from transformers import (
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer, AutoModel, BitsAndBytesConfig
)
from PIL import Image
import requests
from io import BytesIO
# Checkpoint to load: a Hub model id or a local directory. The MODEL_PATH
# environment variable overrides the default.
MODEL_PATH = os.getenv("MODEL_PATH", "THUDM/glm-4v-9b")

# trust_remote_code=True is passed because the checkpoint supplies its own
# tokenizer/model classes; only point MODEL_PATH at sources you trust.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    encode_special_tokens=True,
)

# device_map="auto" leaves weight placement to the library; weights are loaded
# in bfloat16.
model = AutoModel.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Switch to evaluation mode; eval() returns the module itself, so `model`
# keeps referring to the same object.
model.eval()
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation on any of the model's EOS ids."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is an EOS id.

        Args:
            input_ids: Token ids produced so far; only row 0 is inspected.
            scores: Unused; accepted to satisfy the StoppingCriteria call signature.
            **kwargs: Unused.

        Returns:
            True if the last token in ``input_ids[0]`` equals one of
            ``model.config.eos_token_id``, otherwise False.
        """
        stop_ids = model.config.eos_token_id
        # The config may hold a single id, a list of ids, or nothing at all;
        # iterating a bare int (or None) would raise TypeError.
        if stop_ids is None:
            return False
        if isinstance(stop_ids, int):
            stop_ids = (stop_ids,)

        # Convert once to a Python int so the membership test does not compare
        # tensors element by element.
        last_token = int(input_ids[0][-1])
        return last_token in stop_ids
def get_image(image_path=None, image_url=None, *, timeout=30):
    """Load an image from a local file or a URL and return it in RGB mode.

    ``image_path`` takes precedence: when it is given, ``image_url`` is ignored.

    Args:
        image_path: Path of a local image file, or a falsy value to skip it.
        image_url: URL of an image to download, or a falsy/blank value to skip it.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The image converted to RGB, or None when neither source was supplied.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
        OSError: If the file or the downloaded bytes cannot be read as an image.
    """
    if image_path:
        # convert() returns a new image, so the source handle can be closed.
        with Image.open(image_path) as source:
            return source.convert("RGB")

    # A text box can deliver whitespace only; treat that as "no URL".
    image_url = image_url.strip() if image_url else ""
    if not image_url:
        return None

    # NOTE(security): the URL is user-supplied and fetched by this process;
    # restrict it if the demo is ever exposed beyond localhost.
    response = requests.get(image_url, timeout=timeout)
    # Fail on HTTP errors here instead of handing an error page to PIL.
    response.raise_for_status()
    with Image.open(BytesIO(response.content)) as source:
        return source.convert("RGB")
def chatbot(image_path=None, image_url=None, assistant_prompt=""):
    """Run the model on an image and return the image together with its reply.

    Args:
        image_path: Local path of the image; forwarded to ``get_image``.
        image_url: URL of the image; forwarded to ``get_image``.
        assistant_prompt: Text placed in the ``assistant``-role message that
            precedes the user message carrying the image.

    Returns:
        A tuple ``(image, response)``: the value returned by ``get_image`` and
        the generated text with surrounding whitespace stripped.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
            is re-raised here once the stream has been drained.
    """
    image = get_image(image_path, image_url)

    # NOTE(review): the prompt is sent as an assistant turn and the user turn
    # has empty text; kept as-is — confirm this is the intended layout.
    messages = [
        {"role": "assistant", "content": assistant_prompt},
        {"role": "user", "content": "", "image": image},
    ]
    model_inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(next(model.parameters()).device)

    streamer = TextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=60,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": 1024,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.6,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
        "repetition_penalty": 1.2,
        # NOTE(review): hard-coded ids, presumably the checkpoint's
        # end-of-sequence tokens — confirm against model.config.eos_token_id.
        "eos_token_id": [151329, 151336, 151338],
    }

    failures = []

    def _generate():
        """Run generation; on failure record the error and unblock the reader."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the calling thread below
            failures.append(exc)
            # Without the end-of-stream signal the reader would wait for the
            # full streamer timeout and then fail with an unrelated queue.Empty.
            streamer.end()

    worker = Thread(target=_generate)
    worker.start()

    # The streamer yields decoded text chunks until generation signals its end.
    response = "".join(streamer)
    worker.join()

    if failures:
        raise failures[0]
    return image, response.strip()
# Build the UI: inputs in the left column, model reply and image preview in the
# right column.
with gr.Blocks(title="GLM-4V-9B Image Recognition Demo") as demo:
    # gr.Blocks has no ``description`` option, so text assigned to an attribute
    # of the Blocks object is never shown; render it as Markdown instead.
    # NOTE(review): wording kept verbatim from the original — consider
    # rephrasing ("to got image infomation").
    gr.Markdown("This demo uses the GLM-4V-9B model to got image infomation.")

    with gr.Row():
        with gr.Column():
            image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath")
            image_url_input = gr.Textbox(label="Image URL (Low-Priority)")
            assistant_prompt_input = gr.Textbox(
                label="Assistant Prompt (You Can Change It)",
                value="这是什么?",
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response")
            image_output = gr.Image(label="Image Preview")

    # The output order must match chatbot()'s return value: (image, response).
    submit_button.click(
        chatbot,
        inputs=[image_path_input, image_url_input, assistant_prompt_input],
        outputs=[image_output, chatbot_output],
    )

# Serve on localhost only, on a fixed port, and open a browser tab.
demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False)