ahmed-masry's picture
Update app.py
9a05abc verified
raw
history blame
2.26 kB
import gradio as gr
from transformers import AutoProcessor, LlavaForConditionalGeneration
import requests
from PIL import Image
import torch, os, re, json
import spaces
def _download_if_missing(url, filename):
    """Download ``url`` to ``filename`` unless that file already exists locally.

    Avoids re-fetching the example charts on every restart and lets the app
    start even if the remote host is temporarily unreachable.
    """
    if not os.path.exists(filename):
        torch.hub.download_url_to_file(url, filename)


# Example charts from the ChartQA repository, used as clickable examples in the UI.
_download_if_missing('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png', 'chart_example_1.png')
_download_if_missing('https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png', 'chart_example_2.png')

# Hugging Face Hub checkpoint shared by the model weights and the processor.
MODEL_ID = "ahmed-masry/ChartInstruct-LLama2"
# Weights are loaded in float16; predict() casts pixel_values to the same dtype.
model = LlavaForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
processor = AutoProcessor.from_pretrained(MODEL_ID)
@spaces.GPU
def predict(image, input_text):
    """Answer a question about a chart image with the ChartInstruct model.

    Args:
        image: Chart from the Gradio ``Image`` component (``type="pil"``), or
            ``None`` when the user submits without uploading an image.
        input_text: The user's question or instruction about the chart.

    Returns:
        The decoded generated text, with the prompt tokens stripped off.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if image is None:
        raise gr.Error("Please upload a chart image.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Moved on every call; presumably because @spaces.GPU only provides a GPU
    # inside decorated calls -- confirm against the `spaces` (ZeroGPU) docs.
    model.to(device)

    # Prompt format used by this demo; ``<image>`` is assumed to be the
    # placeholder the processor expands into image tokens (LLaVA convention).
    input_prompt = f"<image>\n Question: {input_text} Answer: "
    # Normalise RGBA / palette / grayscale uploads to 3-channel RGB.
    image = image.convert("RGB")
    inputs = processor(text=input_prompt, images=image, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # The model weights are float16, so the image tensor must match that dtype.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.float16)

    # The generated ids start with the prompt; slice it off so only the new
    # answer tokens are decoded.
    prompt_length = inputs["input_ids"].shape[1]
    generate_ids = model.generate(**inputs, num_beams=4, max_new_tokens=512)
    return processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# Gradio UI: a chart image plus a free-text prompt in, the model's answer out.
image = gr.components.Image(type="pil", label="Chart Image")
input_prompt = gr.components.Textbox(label="Input Prompt")
model_output = gr.components.Textbox(label="Model Output")

# Example rows: each pairs one of the downloaded chart files with a prompt.
examples = [
    ["chart_example_1.png", "Describe the trend of the mortality rates for the Neonatal"],
    ["chart_example_2.png", "What is the share of respondents who prefer Facebook Messenger in the 30-59 age group?"],
]
title = "Interactive Gradio Demo for ChartInstruct-Llama2 model"

interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    theme="gradio/soft",
)
interface.launch()