MolMo-7B-D-0924 / app.py
zqu2004's picture
Update app.py
8b1dbc7 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
import torch
import spaces
# Inference runs on CPU unless this flag is switched on *and* CUDA is present.
USE_GPU = False

# Checkpoint shared by the processor and the model.
_MODEL_ID = 'allenai/MolmoE-1B-0924'

# Resolve the target device once, up front.
if USE_GPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the processor and model
processor = AutoProcessor.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    torch_dtype='auto',
)

model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    torch_dtype='auto',
    # Placement is delegated to `device_map` only on the GPU path.
    device_map='auto' if USE_GPU else None,
)

# On the CPU path no device map was given, so move the model explicitly.
if not USE_GPU:
    model.to(device)
# Ready-made questions offered in the "Select a premade prompt" dropdown.
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?",
]
def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    """Generate the model's reply to *text* about *image*.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray`` (the UI's
            image component is configured with ``type="numpy"``).
        text: Prompt or question to put to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded continuation only (the prompt tokens are sliced off),
        with special tokens skipped.
    """
    # Process the image and text
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text,
    )

    # Move inputs to the correct device and make a batch of size 1
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    generation_config = GenerationConfig(
        # Guard against a float arriving from the UI slider.
        max_new_tokens=int(max_new_tokens),
        # temperature/top_p only take effect when sampling is enabled; with
        # the default do_sample=False decoding is greedy and both settings
        # are ignored, leaving the UI sliders without any effect.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        stop_strings="<|endoftext|>",
    )

    # Generate output (stop_strings requires the tokenizer to be passed along)
    output = model.generate_from_batch(
        inputs,
        generation_config,
        tokenizer=processor.tokenizer,
    )

    # Only get generated tokens; decode them to text
    prompt_length = inputs['input_ids'].size(1)
    generated_tokens = output[0, prompt_length:]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)
def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    """Run one chat turn and return the updated history.

    Args:
        image: Uploaded image, or ``None`` when nothing has been uploaded.
        text: The user's question.
        history: List of ``(user_message, bot_message)`` pairs; ``None`` is
            treated as an empty history.
        max_new_tokens: Forwarded to ``process_image_and_text``.
        temperature: Forwarded to ``process_image_and_text``.
        top_p: Forwarded to ``process_image_and_text``.

    Returns:
        The history to display. On a successful turn this is the same list
        object that was passed in, extended in place.
    """
    if history is None:
        history = []

    if image is None:
        # The notice is a reply *from* the app, so it belongs in the bot slot
        # of the (user, bot) pair. A new list is returned so the notice is
        # shown without being recorded in the stored history.
        return history + [(None, "Please upload an image first.")]

    response = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
    # Append in place: the caller's list is what carries the conversation
    # from one turn to the next.
    history.append((text, response))
    return history
def update_textbox(prompt):
    """Build the component update that sets the textbox value to *prompt*."""
    textbox_update = gr.update(value=prompt)
    return textbox_update
# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with MolmoE-1B-0924")

    with gr.Row():
        image_input = gr.Image(type="numpy")
        chatbot_output = gr.Chatbot()

    with gr.Row():
        text_input = gr.Textbox(placeholder="Ask a question about the image...")
        prompt_dropdown = gr.Dropdown(
            choices=[""] + prompts,
            label="Select a premade prompt",
            value="",
        )

    submit_button = gr.Button("Submit")
    clear_button = gr.ClearButton([text_input, chatbot_output])

    with gr.Accordion("Advanced options", open=False):
        max_new_tokens = gr.Slider(minimum=1, maximum=500, value=200, step=1, label="Max new tokens")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
        top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    # Conversation history passed to `chatbot` as its `history` argument.
    state = gr.State([])

    # Add copy button for raw output
    with gr.Row():
        raw_output = gr.Textbox(label="Raw Output", interactive=False)
        copy_button = gr.Button("Copy Raw Output")

    def update_raw_output(history):
        """Return the bot half of the last history entry, or "" if absent."""
        if history:
            return history[-1][1] or ""
        return ""

    # The Submit button and pressing Enter in the textbox run the same
    # pipeline, so both triggers are wired from one definition.
    generation_inputs = [image_input, text_input, state, max_new_tokens, temperature, top_p]
    for register_trigger in (submit_button.click, text_input.submit):
        register_trigger(
            chatbot,
            inputs=generation_inputs,
            outputs=[chatbot_output],
        ).then(
            update_raw_output,
            inputs=[chatbot_output],
            outputs=[raw_output],
        )

    prompt_dropdown.change(update_textbox, inputs=[prompt_dropdown], outputs=[text_input])

    # Clearing must also reset the stored history; otherwise the next submit
    # re-renders every turn the user just cleared.
    clear_button.click(lambda: [], inputs=None, outputs=[state])

    # Copy on the client side: the clipboard lives in the browser, so a
    # server-side handler writing to a hidden textbox cannot reach it.
    copy_button.click(
        None,
        inputs=[raw_output],
        outputs=None,
        js="(text) => { navigator.clipboard.writeText(text); }",
    )

demo.launch()