hiyouga's picture
Update app.py
617cca9 verified
raw
history blame
3.11 kB
from threading import Thread
from typing import Any, Dict, Iterator, List, Optional, Tuple

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, TextIteratorStreamer
# Page heading rendered at the top of the app.
TITLE = "<h1><center>Chat with PaliGemma-3B-Chat-v0.2</center></h1>"
# Sub-heading linking to the model card. The host is plain "huggingface.co";
# a stray trailing dot ("huggingface.co.") was a typo in the URL.
DESCRIPTION = "<h3><center>Visit <a href='https://huggingface.co/BUAADreamer/PaliGemma-3B-Chat-v0.2' target='_blank'>our model page</a> for details.</center></h3>"
# CSS for elements carrying the `duplicate-button` class: auto margins,
# white text on a black background, fully rounded corners.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Hugging Face Hub repository supplying the processor, tokenizer and weights.
model_id = "BUAADreamer/PaliGemma-3B-Chat-v0.2"

# The processor turns images into pixel_values; the tokenizer renders the chat template.
processor = AutoProcessor.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# "auto" lets transformers choose the weight dtype and the device placement.
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype="auto",
)
def _split_history(
    message: Dict[str, Any], history: list
) -> Tuple[Optional[str], List[Tuple[str, str]]]:
    """Pick the image for this request and collect the text-only turns of ``history``.

    In the tuple-style multimodal history an uploaded file appears as its own
    entry whose first element is a tuple, e.g. ``[('image-xxx.jpg',), None]``;
    text exchanges are ``[prompt, answer]`` pairs.

    Args:
        message: Current user input, ``{'text': str, 'files': [path, ...]}``.
        history: Prior turns in Gradio's tuple format.

    Returns:
        ``(image_path, turns)``. ``image_path`` is the first file attached to the
        current message if there is one, otherwise the most recent image found in
        ``history``, otherwise ``None``. ``turns`` holds only the text
        ``(prompt, answer)`` pairs, in their original order.
    """
    image_path = None
    turns = []
    for prompt, answer in history:
        if isinstance(prompt, tuple):
            # Upload entry (can occur at any turn, not only the first): remember
            # the file but keep it out of the text sent to apply_chat_template.
            image_path = prompt[0]
        else:
            # Guard against a missing reply (e.g. an interrupted generation) so the
            # conversation keeps alternating user/assistant roles.
            turns.append((prompt, "" if answer is None else answer))
    if message["files"]:
        # A file attached to the current message wins over earlier uploads.
        image_path = message["files"][0]
    return image_path, turns


def _load_image(image_path: Optional[str]) -> Image.Image:
    """Open ``image_path`` as an RGB image, or return a blank white 100x100 image if ``None``.

    stream_chat always passes pixel_values to the model, so text-only chats
    still need some image.
    """
    if image_path is None:
        return Image.new("RGB", (100, 100), (255, 255, 255))
    # The context manager closes the file; convert() returns a separate image.
    with Image.open(image_path) as img:
        return img.convert("RGB")


@spaces.GPU
def stream_chat(message: Dict[str, Any], history: list) -> Iterator[str]:
    """Stream the model's reply to ``message``, yielding the accumulated text so far.

    Example inputs (multimodal ChatInterface, tuple-style history):

        Turn 1: message = {'text': 'what is this', 'files': ['image-xxx.jpg']}
                history = []
        Turn 2: message = {'text': 'continue?', 'files': []}
                history = [[('image-xxx.jpg',), None], ['what is this', 'a image.']]

    Args:
        message: Current user input with a ``'text'`` string and a ``'files'`` list of paths.
        history: Prior turns as supplied by Gradio.

    Yields:
        The partial response text, extended by one decoded chunk per yield.
    """
    image_path, turns = _split_history(message, history)
    image = _load_image(image_path)
    pixel_values = processor(images=[image], return_tensors="pt").to(model.device)["pixel_values"]

    conversation = []
    for prompt, answer in turns:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    conversation.append({"role": "user", "content": message["text"]})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")

    # Prepend processor.image_seq_length "<image>" placeholder tokens to the chat prompt.
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    image_prefix = torch.full((1, processor.image_seq_length), image_token_id, dtype=input_ids.dtype)
    input_ids = torch.cat((image_prefix, input_ids), dim=-1).to(model.device)

    # Run generate() on a worker thread and read decoded text back through the streamer.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        pixel_values=pixel_values,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    output = ""
    for new_token in streamer:
        output += new_token
        yield output
    thread.join()
# Transcript widget, built first and then handed to the ChatInterface.
chatbot = gr.Chatbot(height=450)

with gr.Blocks(css=CSS) as demo:
    # Header: heading, model-card link, and a styled "Duplicate Space" button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        "Duplicate Space for private use",
        elem_classes="duplicate-button",
    )
    # Multimodal chat (text + uploaded files) backed by the streaming generator.
    gr.ChatInterface(
        stream_chat,
        chatbot=chatbot,
        multimodal=True,
        cache_examples=False,
        fill_height=True,
    )

if __name__ == "__main__":
    demo.launch()