# Hugging Face Spaces status badge: Running on Zero
from threading import Thread
from typing import Any, Dict

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
# Page heading rendered above the chat (raw HTML for gr.HTML).
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"
# Sub-heading linking to the model card on the Hugging Face Hub; the host is
# written as plain "huggingface.co" (no trailing dot) so the link is canonical.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"
# Extra stylesheet passed to gr.Blocks: elements carrying the
# "duplicate-button" class get centered, white-on-black, fully rounded styling.
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
"""
# Hub checkpoint backing this demo.
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"

# Loaded once at import time; stream_chat reuses these module-level objects
# for every request.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
def _split_image_from_history(message: Dict[str, Any], history: list):
    """Separate the image to use from the text-only turns of the chat.

    In the tuple-style history handed over by the multimodal ChatInterface, an
    uploaded file appears as its own entry whose first element is a tuple of
    paths (e.g. ``[('image-xxx.jpg',), None]``), while text exchanges are
    ``[prompt, answer]`` pairs.

    Returns:
        ``(image_path, text_turns)`` where ``image_path`` is the first file
        attached to the current ``message`` if any, otherwise the most recent
        image entry found in ``history``, otherwise ``None``; and
        ``text_turns`` is the list of ``(prompt, answer)`` pairs with every
        file entry removed.
    """
    image_path = None
    text_turns = []
    for prompt, answer in history:
        if isinstance(prompt, tuple):
            # File-upload entry: remember it, but never pass it on as chat text.
            if prompt:
                image_path = prompt[0]
            continue
        text_turns.append((prompt, answer))
    # An image attached to the current message takes precedence.
    if message["files"]:
        image_path = message["files"][0]
    return image_path, text_turns


@spaces.GPU
def stream_chat(message: Dict[str, Any], history: list):
    """Stream the model's reply to ``message``, given the prior chat ``history``.

    Args:
        message: ``{'text': str, 'files': list of file paths}`` from the
            multimodal ChatInterface.
        history: Prior turns; file uploads appear as entries whose first
            element is a tuple of paths.

    Yields:
        The reply generated so far, growing as text arrives from the streamer.
    """
    # Turn 1:
    # {'text': 'what is this', 'files': ['image-xxx.jpg']}
    # []
    # Turn 2:
    # {'text': 'continue?', 'files': []}
    # [[('image-xxx.jpg',), None], ['what is this', 'a image.']]
    image_path, text_turns = _split_image_from_history(message, history)

    if image_path is not None:
        # Context manager closes the underlying file once the pixels are decoded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    else:
        # No image anywhere in the chat: generate() below always receives
        # pixel_values, so substitute a blank white 100x100 image.
        image = Image.new("RGB", (100, 100), (255, 255, 255))

    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = []
    for prompt, answer in text_turns:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message["text"]})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

    # Prepend processor.image_seq_length copies of the "<image>" token ahead of
    # the chat prompt; presumably these are the slots the model fills with the
    # image features (PaliGemma input layout) -- confirm against the model card.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )
    # generate() blocks, so run it in a worker thread and consume its output
    # incrementally through the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    # The streamer is exhausted once generation ends; reap the worker thread.
    worker.join()
# Chat transcript widget, created up front so its height can be fixed.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Static header: title, link to the model card, and a "duplicate" button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )
    # Multimodal (text + image upload) chat box backed by stream_chat.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        multimodal=True,
        fill_height=True,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()