File size: 3,500 Bytes
8f558df
21fcfe6
ffcf8f2
21fcfe6
8f558df
02558d9
a533ef3
425e364
a533ef3
02558d9
5ca3297
ffcf8f2
5ca3297
ffcf8f2
5ca3297
ffcf8f2
5ca3297
 
ffcf8f2
 
5ca3297
 
 
5c6a1a7
5ca3297
ffcf8f2
 
 
5ca3297
8f558df
02558d9
1ac43cd
 
02558d9
 
 
 
 
 
 
 
 
 
 
 
 
 
8f558df
 
21fcfe6
1a981c9
02558d9
8f558df
6d7705b
dcf6d05
 
 
 
 
 
6d7705b
dcf6d05
1cc7126
6d7705b
 
1cc7126
dcf6d05
 
 
6d7705b
dcf6d05
 
 
 
6d7705b
 
 
 
 
dcf6d05
6d7705b
dcf6d05
 
 
 
 
 
 
2406bfd
dcf6d05
 
 
 
 
 
 
df30ad6
8f558df
 
 
 
 
 
 
 
 
 
 
ffcf8f2
8f558df
 
7890490
8f558df
 
 
 
21fcfe6
1a981c9
8f558df
755339c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import spaces
from transformers import Idefics3ForConditionalGeneration, AutoProcessor
import torch
from PIL import Image
from datetime import datetime
import numpy as np
import os


# Markdown header rendered at the top of the Gradio page (see gr.Markdown below).
DESCRIPTION = """
# SmolVLM-trl-sft-ChartQA Demo

This is a demo Space for a fine-tuned version of [SmolVLM](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) trained using [ChartQA dataset](https://huggingface.co/datasets/HuggingFaceM4/ChartQA).

The corresponding model is located [here](https://huggingface.co/sergiopaniego/smolvlm-instruct-trl-sft-ChartQA).
"""

# Hub identifiers: the base SmolVLM checkpoint, and the ChartQA fine-tuned
# adapter that is attached on top of it once the base model is loaded.
model_id = "HuggingFaceTB/SmolVLM-Instruct"
adapter_path = "sergiopaniego/smolvlm-instruct-trl-sft-ChartQA"

# Base weights in bfloat16; device placement is delegated to device_map="auto".
# NOTE: `_attn_implementation="flash_attention_2"` can be passed here as well
# when the runtime has flash-attn available (it is intentionally left off).
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

# The processor is used below both for the chat template and for turning
# text + images into model inputs.
processor = AutoProcessor.from_pretrained(model_id)

# Must run after the base model exists: layers the fine-tuned adapter onto it.
model.load_adapter(adapter_path)

def array_to_image_path(image_array, directory=None):
    """Save an image array as a uniquely named PNG file and return its absolute path.

    Args:
        image_array: Image data (e.g. the NumPy array produced by ``gr.Image``),
            or ``None`` when the user did not upload anything.
        directory: Folder to write the PNG into. Defaults to the current
            working directory, matching the original behaviour.

    Returns:
        The absolute path of the written PNG file.

    Raises:
        ValueError: If ``image_array`` is ``None``.
    """
    import uuid  # local import: only this helper needs it

    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    # Cast to 8-bit so PIL can build the image from the array.
    img = Image.fromarray(np.uint8(image_array))

    # A seconds-resolution timestamp alone collides when two requests land in
    # the same second (the later save silently overwrites the earlier file),
    # so append a short random suffix to keep every filename unique.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}_{uuid.uuid4().hex[:8]}.png"

    target_dir = os.getcwd() if directory is None else directory
    full_path = os.path.abspath(os.path.join(target_dir, filename))
    img.save(full_path)
    return full_path


@spaces.GPU
def run_example(image, text_input=None, max_new_tokens=1024):
    """Answer a question about a chart image with the fine-tuned SmolVLM model.

    Args:
        image: Image from the ``gr.Image`` component as an array, or ``None``
            when nothing was uploaded.
        text_input: The question to ask about the image; ``None`` is sent as
            an empty question.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded answer text (prompt tokens stripped).

    Raises:
        ValueError: If no image was provided.
    """
    # Validate directly rather than via array_to_image_path(): that helper
    # writes a PNG to disk on every request, and its path was never used here.
    if image is None:
        raise ValueError("No image provided. Please upload an image before submitting.")
    # convert("RGB") already guarantees RGB mode; no second check is needed.
    pil_image = Image.fromarray(image).convert("RGB")

    # Guard against None so the chat template always receives a string question.
    question = text_input if text_input is not None else ""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "text": None},
                {"type": "text", "text": question},
            ],
        }
    ]

    # Render the chat prompt as text; tokenization happens in processor() below.
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Images are passed as one list per prompt (a single prompt here).
    inputs = processor(
        text=prompt,
        images=[[pil_image]],
        padding=True,
        return_tensors="pt",
    )
    # Follow wherever device_map="auto" placed the model instead of hard-coding "cuda".
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # The generated sequences start with the prompt tokens; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    answers = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return answers[0]

# Stylesheet for the answer box: fixed height, scrolling for long answers, light border.
css = """
  #output {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="SmolVLM-trl-sft-ChartQA Input"):
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input Picture")
                text_input = gr.Textbox(label="Question")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                # elem_id links this textbox to the "#output" rule in `css`;
                # without it the stylesheet above matched no element.
                output_text = gr.Textbox(label="Output Text", elem_id="output")

        # The two inputs map positionally to run_example(image, text_input).
        submit_btn.click(run_example, [input_img, text_input], [output_text])

# api_open=False: do not expose the queued endpoints over the REST API.
demo.queue(api_open=False)
demo.launch(debug=True)