import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
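
# Assumed runtime dependencies for this Space: gradio, transformers, torch,
# pillow, and the Hugging Face `spaces` package; Molmo's remote code is
# documented to also require einops and torchvision.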

# Flag to use GPU (set to False by default)
USE_GPU = False

device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

# Load the processor and model
processor = AutoProcessor.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,
    torch_dtype='auto',
)
model = AutoModelForCausalLM.from_pretrained(
    'allenai/MolmoE-1B-0924',
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto' if USE_GPU else None,
)
if not USE_GPU:
    model.to(device)
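
# MolmoE-1B-0924 is Ai2's mixture-of-experts Molmo variant with roughly 1B
# active parameters; CPU inference (USE_GPU = False) works but is slow.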

# Predefined prompts
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?",
]

def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    # Process the image and text into model inputs
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text,
    )

    # Move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    # Generate output
    output = model.generate_from_batch(
        inputs,
        GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            stop_strings="<|endoftext|>",
        ),
        tokenizer=processor.tokenizer,
    )

    # Keep only the newly generated tokens and decode them to text
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return generated_text
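
# Minimal standalone sanity check for process_image_and_text, kept commented
# out so it never runs alongside the app. A sketch only: "example.jpg" is a
# placeholder path, and the sampling values mirror the UI defaults below.
# if __name__ == "__main__":
#     import numpy as np
#     img = np.array(Image.open("example.jpg").convert("RGB"))
#     print(process_image_and_text(img, "Describe this image in detail",
#                                  max_new_tokens=200, temperature=1.0, top_p=0.95))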

def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    if image is None:
        # Surface the error as the bot's reply (Chatbot tuples are (user, bot))
        return history + [(text, "Please upload an image first.")]
    response = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
    history.append((text, response))
    return history

def update_textbox(prompt):
    return gr.update(value=prompt)

# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with MolmoE-1B-0924")

    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()

    with gr.Row():
        text_input = gr.Textbox(placeholder="Ask a question about the image...")
        prompt_dropdown = gr.Dropdown(choices=[""] + prompts, label="Select a premade prompt", value="")

    submit_button = gr.Button("Submit")
    clear_button = gr.ClearButton([text_input, chatbot_output])

    with gr.Accordion("Advanced options", open=False):
        max_new_tokens = gr.Slider(minimum=1, maximum=500, value=200, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    state = gr.State([])
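    # chatbot() mutates this per-session history list in place, which is why
    # the history persists even though `state` never appears in an outputs list.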

    # Raw output with a built-in clipboard copy button; show_copy_button is
    # used because a server-side click handler cannot reach the browser clipboard
    with gr.Row():
        raw_output = gr.Textbox(label="Raw Output", interactive=False, show_copy_button=True)

    def update_raw_output(history):
        if history:
            return history[-1][1]
        return ""

    submit_button.click(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output],
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output],
    )

    text_input.submit(
        chatbot,
        inputs=[image_input, text_input, state, max_new_tokens, temperature, top_p],
        outputs=[chatbot_output],
    ).then(
        update_raw_output,
        inputs=[chatbot_output],
        outputs=[raw_output],
    )

    prompt_dropdown.change(update_textbox, inputs=[prompt_dropdown], outputs=[text_input])

demo.launch()