Akshayram1 committed
Commit d6b6251 · verified · 1 Parent(s): acdf570

Create app.py

Files changed (1)
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
+ import os
+ import gradio as gr
+ import PIL.Image
+ import transformers
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
+ import torch
+ import string
+ import functools
+ import re
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+ import spaces
+
+ model_id = "gv-hf/paligemma2-10b-mix-448"
+ COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval().to(device)
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
+
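+ # Single generation pass: encode the prompt and image, decode greedily
+ # (do_sample=False), and return only the newly generated text.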
+ @spaces.GPU
+ def infer(image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
+     inputs = processor(text=text, images=image, return_tensors="pt").to(device)
+     with torch.inference_mode():
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=False
+         )
+     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
+     return result[0][len(text):].lstrip("\n")
+
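+ # Prompt the model for the requested entities, parse the <loc>/<seg> tokens in
+ # its reply, and return (image, annotations) for gr.AnnotatedImage, using the
+ # decoded mask when present and the bounding box otherwise.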
+ def parse_segmentation(input_image, input_text):
+     out = infer(input_image, input_text, max_new_tokens=200)
+     objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
+     labels = set(obj.get('name') for obj in objs if obj.get('name'))
+     color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+     highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+     annotated_img = (
+         input_image,
+         [
+             (
+                 obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+                 obj['name'] or '',
+             )
+             for obj in objs
+             if 'mask' in obj or 'xyxy' in obj
+         ],
+     )
+     has_annotations = bool(annotated_img[1])
+     return annotated_img
+
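+ # Convert the PyTorch VQ-VAE decoder checkpoint into a Flax parameter tree;
+ # conv kernels are transposed from (out, in, H, W) to Flax's (H, W, in, out).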
+ def _get_params(checkpoint):
+     def transp(kernel):
+         return np.transpose(kernel, (2, 3, 1, 0))
+
+     def conv(name):
+         return {
+             'bias': checkpoint[name + '.bias'],
+             'kernel': transp(checkpoint[name + '.weight']),
+         }
+
+     def resblock(name):
+         return {
+             'Conv_0': conv(name + '.0'),
+             'Conv_1': conv(name + '.2'),
+             'Conv_2': conv(name + '.4'),
+         }
+
+     return {
+         '_embeddings': checkpoint['_vq_vae._embedding'],
+         'Conv_0': conv('decoder.0'),
+         'ResBlock_0': resblock('decoder.2.net'),
+         'ResBlock_1': resblock('decoder.3.net'),
+         'ConvTranspose_0': conv('decoder.4'),
+         'ConvTranspose_1': conv('decoder.6'),
+         'ConvTranspose_2': conv('decoder.8'),
+         'ConvTranspose_3': conv('decoder.10'),
+         'Conv_1': conv('decoder.12'),
+     }
+
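+ # Look up the codebook embedding of each of the 16 segmentation tokens and
+ # reshape them into the 4x4 latent grid expected by the decoder.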
+ def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+     batch_size, num_tokens = codebook_indices.shape
+     assert num_tokens == 16, codebook_indices.shape
+     unused_num_embeddings, embedding_dim = embeddings.shape
+
+     encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+     encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+     return encodings
+
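+ # Build (once, via functools.cache) a jitted CPU function that decodes a batch
+ # of 16 <seg> codebook indices into 64x64 soft masks using the Flax VQ-VAE
+ # decoder below; its weights are read from _MODEL_PATH.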
+ @functools.cache
+ def _get_reconstruct_masks():
+     class ResBlock(nn.Module):
+         features: int
+
+         @nn.compact
+         def __call__(self, x):
+             original_x = x
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+             x = nn.relu(x)
+             x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+             return x + original_x
+
+     class Decoder(nn.Module):
+         @nn.compact
+         def __call__(self, x):
+             num_res_blocks = 2
+             dim = 128
+             num_upsample_layers = 4
+
+             x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+             x = nn.relu(x)
+
+             for _ in range(num_res_blocks):
+                 x = ResBlock(features=dim)(x)
+
+             for _ in range(num_upsample_layers):
+                 x = nn.ConvTranspose(
+                     features=dim,
+                     kernel_size=(4, 4),
+                     strides=(2, 2),
+                     padding=2,
+                     transpose_kernel=True,
+                 )(x)
+                 x = nn.relu(x)
+                 dim //= 2
+
+             x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+             return x
+
+     def reconstruct_masks(codebook_indices):
+         quantized = _quantized_values_from_codebook_indices(
+             codebook_indices, params['_embeddings']
+         )
+         return Decoder().apply({'params': params}, quantized)
+
+     with open(_MODEL_PATH, 'rb') as f:
+         params = _get_params(dict(np.load(f)))
+
+     return jax.jit(reconstruct_masks, backend='cpu')
+
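+ # One detection reads "<locYYYY><locXXXX><locYYYY><locXXXX> [16 <seg> tokens]
+ # label", with objects separated by "; "; locations are corner coordinates
+ # normalized to 0..1023 in y/x order.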
+ _SEGMENT_DETECT_RE = re.compile(
+     r'(.*?)' +
+     r'<loc(\d{4})>' * 4 + r'\s*' +
+     '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+     r'\s*([^;<>]+)? ?(?:; )?',
+ )
+
+ _MODEL_PATH = 'vae-oid.npz'
+
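+ # Split the model output into plain-text chunks and objects; each object gets
+ # pixel-space xyxy coordinates, an optional full-image mask, and a label kept
+ # unique by appending apostrophes when unique_labels is set.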
+ def extract_objs(text, width, height, unique_labels=False):
+     objs = []
+     seen = set()
+     while text:
+         m = _SEGMENT_DETECT_RE.match(text)
+         if not m:
+             break
+
+         gs = list(m.groups())
+         before = gs.pop(0)
+         name = gs.pop()
+         y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+
+         y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+         seg_indices = gs[4:20]
+         if seg_indices[0] is None:
+             mask = None
+         else:
+             seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+             m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+             m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+             m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+             mask = np.zeros([height, width])
+             if y2 > y1 and x2 > x1:
+                 mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+         content = m.group()
+         if before:
+             objs.append(dict(content=before))
+             content = content[len(before):]
+         while unique_labels and name in seen:
+             name = (name or '') + "'"
+         seen.add(name)
+         objs.append(dict(
+             content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+         text = text[len(before) + len(content):]
+
+     if text:
+         objs.append(dict(content=text))
+
+     return objs
+
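+ # Two-tab Gradio UI: free-form text generation, and segmentation/detection
+ # rendered onto an AnnotatedImage.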
+ with gr.Blocks() as demo:
+     with gr.Tab("Text Generation"):
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image(type="pil", width=512, height=512)
+                 text_input = gr.Text(label="Input Text")
+             with gr.Column():
+                 text_output = gr.Text(label="Text Output")
+                 chat_btn = gr.Button()
+                 tokens = gr.Slider(
+                     label="Max New Tokens",
+                     minimum=10,
+                     maximum=200,
+                     value=20,
+                     step=10,
+                 )
+
+         chat_btn.click(
+             fn=infer,
+             inputs=[image, text_input, tokens],
+             outputs=[text_output],
+         )
+
+     with gr.Tab("Segment/Detect"):
+         with gr.Row():
+             with gr.Column():
+                 image = gr.Image(type="pil")
+                 seg_input = gr.Text(label="Entities to Segment/Detect")
+                 seg_btn = gr.Button("Submit")
+             with gr.Column():
+                 annotated_image = gr.AnnotatedImage(label="Output")
+
+         seg_btn.click(
+             fn=parse_segmentation,
+             inputs=[image, seg_input],
+             outputs=[annotated_image],
+         )
+
+ if __name__ == "__main__":
+     demo.queue(max_size=10).launch(debug=True)