flocolombari committed on
Commit 0b3abbe · 1 Parent(s): 3d5f136

Update app.py

Files changed (1)
  1. app.py +3 -312
app.py CHANGED
@@ -1,313 +1,4 @@
- import gradio as gr
- import random
- import numpy as np
- import os
- import requests
- import torch
- import torchvision.transforms as T
- from PIL import Image
- from transformers import AutoProcessor, AutoModelForVision2Seq
- import cv2
- import ast
-
- colors = [
-     (0, 255, 0),
-     (0, 0, 255),
-     (255, 255, 0),
-     (255, 0, 255),
-     (0, 255, 255),
-     (114, 128, 250),
-     (0, 165, 255),
-     (0, 128, 0),
-     (144, 238, 144),
-     (238, 238, 175),
-     (255, 191, 0),
-     (0, 128, 0),
-     (226, 43, 138),
-     (255, 0, 255),
-     (0, 215, 255),
-     (255, 0, 0),
- ]
-
- color_map = {
-     f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for color_id, color in enumerate(colors)
- }
-
-
- def is_overlapping(rect1, rect2):
-     x1, y1, x2, y2 = rect1
-     x3, y3, x4, y4 = rect2
-     return not (x2 < x3 or x1 > x4 or y2 < y3 or y1 > y4)
-
-
- def draw_entity_boxes_on_image(image, entities, show=False, save_path=None, entity_index=-1):
-     """_summary_
-     Args:
-         image (_type_): image or image path
-         collect_entity_location (_type_): _description_
-     """
-     if isinstance(image, Image.Image):
-         image_h = image.height
-         image_w = image.width
-         image = np.array(image)[:, :, [2, 1, 0]]
-     elif isinstance(image, str):
-         if os.path.exists(image):
-             pil_img = Image.open(image).convert("RGB")
-             image = np.array(pil_img)[:, :, [2, 1, 0]]
-             image_h = pil_img.height
-             image_w = pil_img.width
-         else:
-             raise ValueError(f"invaild image path, {image}")
-     elif isinstance(image, torch.Tensor):
-         # pdb.set_trace()
-         image_tensor = image.cpu()
-         reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
-         reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
-         image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
-         pil_img = T.ToPILImage()(image_tensor)
-         image_h = pil_img.height
-         image_w = pil_img.width
-         image = np.array(pil_img)[:, :, [2, 1, 0]]
-     else:
-         raise ValueError(f"invaild image format, {type(image)} for {image}")
-
-     if len(entities) == 0:
-         return image
-
-     indices = list(range(len(entities)))
-     if entity_index >= 0:
-         indices = [entity_index]
-
-     # Not to show too many bboxes
-     entities = entities[:len(color_map)]
-
-     new_image = image.copy()
-     previous_bboxes = []
-     # size of text
-     text_size = 1
-     # thickness of text
-     text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
-     box_line = 3
-     (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
-     base_height = int(text_height * 0.675)
-     text_offset_original = text_height - base_height
-     text_spaces = 3
-
-     # num_bboxes = sum(len(x[-1]) for x in entities)
-     used_colors = colors # random.sample(colors, k=num_bboxes)
-
-     color_id = -1
-     for entity_idx, (entity_name, (start, end), bboxes) in enumerate(entities):
-         color_id += 1
-         if entity_idx not in indices:
-             continue
-         for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
-             # if start is None and bbox_id > 0:
-             #     color_id += 1
-             orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm * image_w), int(y1_norm * image_h), int(x2_norm * image_w), int(y2_norm * image_h)
-
-             # draw bbox
-             # random color
-             color = used_colors[color_id] # tuple(np.random.randint(0, 255, size=3).tolist())
-             new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
-
-             l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
-
-             x1 = orig_x1 - l_o
-             y1 = orig_y1 - l_o
-
-             if y1 < text_height + text_offset_original + 2 * text_spaces:
-                 y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
-                 x1 = orig_x1 + r_o
-
-             # add text background
-             (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
-             text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
-
-             for prev_bbox in previous_bboxes:
-                 while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox):
-                     text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
-                     text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
-                     y1 += (text_height + text_offset_original + 2 * text_spaces)
-
-                     if text_bg_y2 >= image_h:
-                         text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
-                         text_bg_y2 = image_h
-                         y1 = image_h
-                         break
-
-             alpha = 0.5
-             for i in range(text_bg_y1, text_bg_y2):
-                 for j in range(text_bg_x1, text_bg_x2):
-                     if i < image_h and j < image_w:
-                         if j < text_bg_x1 + 1.35 * c_width:
-                             # original color
-                             bg_color = color
-                         else:
-                             # white
-                             bg_color = [255, 255, 255]
-                         new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(np.uint8)
-
-             cv2.putText(
-                 new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces), cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
-             )
-             # previous_locations.append((x1, y1))
-             previous_bboxes.append((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2))
-
-     pil_image = Image.fromarray(new_image[:, :, [2, 1, 0]])
-     if save_path:
-         pil_image.save(save_path)
-     if show:
-         pil_image.show()
-
-     return pil_image
-
-
- def main():
-
-     ckpt = "ydshieh/kosmos-2-patch14-224"
-
-     model = AutoModelForVision2Seq.from_pretrained(ckpt, trust_remote_code=True).to("cuda")
-     processor = AutoProcessor.from_pretrained(ckpt, trust_remote_code=True)
-
-     def generate_predictions(image_input, text_input):
-
-         # Save the image and load it again to match the original Kosmos-2 demo.
-         # (https://github.com/microsoft/unilm/blob/f4695ed0244a275201fff00bee495f76670fbe70/kosmos-2/demo/gradio_app.py#L345-L346)
-         user_image_path = "/tmp/user_input_test_image.jpg"
-         image_input.save(user_image_path)
-         # This might give different results from the original argument `image_input`
-         image_input = Image.open(user_image_path)
-
-         if text_input == "Brief":
-             text_input = "<grounding>An image of"
-         elif text_input == "Detailed":
-             text_input = "<grounding>Describe this image in detail:"
-         else:
-             text_input = f"<grounding>{text_input}"
-
-         inputs = processor(text=text_input, images=image_input, return_tensors="pt")
-
-         generated_ids = model.generate(
-             pixel_values=inputs["pixel_values"].to("cuda"),
-             input_ids=inputs["input_ids"][:, :-1].to("cuda"),
-             attention_mask=inputs["attention_mask"][:, :-1].to("cuda"),
-             img_features=None,
-             img_attn_mask=inputs["img_attn_mask"][:, :-1].to("cuda"),
-             use_cache=True,
-             max_new_tokens=128,
-         )
-         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-         # By default, the generated text is cleanup and the entities are extracted.
-         processed_text, entities = processor.post_process_generation(generated_text)
-
-         annotated_image = draw_entity_boxes_on_image(image_input, entities, show=False)
-
-         color_id = -1
-         entity_info = []
-         filtered_entities = []
-         for entity in entities:
-             entity_name, (start, end), bboxes = entity
-             if start == end:
-                 # skip bounding bbox without a `phrase` associated
-                 continue
-             color_id += 1
-             # for bbox_id, _ in enumerate(bboxes):
-             #     if start is None and bbox_id > 0:
-             #         color_id += 1
-             entity_info.append(((start, end), color_id))
-             filtered_entities.append(entity)
-
-         colored_text = []
-         prev_start = 0
-         end = 0
-         for idx, ((start, end), color_id) in enumerate(entity_info):
-             if start > prev_start:
-                 colored_text.append((processed_text[prev_start:start], None))
-             colored_text.append((processed_text[start:end], f"{color_id}"))
-             prev_start = end
-
-         if end < len(processed_text):
-             colored_text.append((processed_text[end:len(processed_text)], None))
-
-         return annotated_image, colored_text, str(filtered_entities)
-
-     term_of_use = """
-     ### Terms of use
-     By using this model, users are required to agree to the following terms:
-     The model is intended for academic and research purposes.
-     The utilization of the model to create unsuitable material is strictly forbidden and not endorsed by this work.
-     The accountability for any improper or unacceptable application of the model rests exclusively with the individuals who generated such content.
-
-     ### License
-     This project is licensed under the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct).
-     """
-
-     with gr.Blocks(title="Kosmos-2", theme=gr.themes.Base()).queue() as demo:
-         gr.Markdown(("""
-         # Kosmos-2: Grounding Multimodal Large Language Models to the World
-         [[Paper]](https://arxiv.org/abs/2306.14824) [[Code]](https://github.com/microsoft/unilm/blob/master/kosmos-2)
-         """))
-         with gr.Row():
-             with gr.Column():
-                 image_input = gr.Image(type="pil", label="Test Image")
-                 text_input = gr.Radio(["Brief", "Detailed"], label="Description Type", value="Brief")
-
-                 run_button = gr.Button(label="Run", visible=True)
-
-             with gr.Column():
-                 image_output = gr.Image(type="pil")
-                 text_output1 = gr.HighlightedText(
-                     label="Generated Description",
-                     combine_adjacent=False,
-                     show_legend=True,
-                 ).style(color_map=color_map)
-
-         with gr.Row():
-             with gr.Column():
-                 gr.Examples(examples=[
-                     ["images/two_dogs.jpg", "Detailed"],
-                     ["images/snowman.png", "Brief"],
-                     ["images/man_ball.png", "Detailed"],
-                 ], inputs=[image_input, text_input])
-             with gr.Column():
-                 gr.Examples(examples=[
-                     ["images/six_planes.png", "Brief"],
-                     ["images/quadrocopter.jpg", "Brief"],
-                     ["images/carnaby_street.jpg", "Brief"],
-                 ], inputs=[image_input, text_input])
-         gr.Markdown(term_of_use)
-
-         # record which text span (label) is selected
-         selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
-
-         # record the current `entities`
-         entity_output = gr.Textbox(visible=False)
-
-         # get the current selected span label
-         def get_text_span_label(evt: gr.SelectData):
-             if evt.value[-1] is None:
-                 return -1
-             return int(evt.value[-1])
-         # and set this information to `selected`
-         text_output1.select(get_text_span_label, None, selected)
-
-         # update output image when we change the span (enity) selection
-         def update_output_image(img_input, image_output, entities, idx):
-             entities = ast.literal_eval(entities)
-             updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
-             return updated_image
-         selected.change(update_output_image, [image_input, image_output, entity_output, selected], [image_output])
-
-         run_button.click(fn=generate_predictions,
-                          inputs=[image_input, text_input],
-                          outputs=[image_output, text_output1, entity_output],
-                          show_progress=True, queue=True)
-
-     demo.launch(share=False)
-
-
- if __name__ == "__main__":
-     main()
-     # trigger
+ # Use a pipeline as a high-level helper
+ from transformers import pipeline
+
+ pipe = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
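
Below is a minimal usage sketch of the new pipeline-based app.py (not part of this commit). It assumes the `pipe` object defined above is called on a local image; the example image path is taken from the repository's examples and is only illustrative.

# Usage sketch (assumption, not committed code): the image-to-text pipeline
# accepts an image path, URL, or PIL.Image and returns a list of dicts,
# each with a "generated_text" key.
from transformers import pipeline
from PIL import Image

pipe = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

image = Image.open("images/two_dogs.jpg").convert("RGB")  # illustrative path
caption = pipe(image)[0]["generated_text"]
print(caption)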