shenyunhang committed on
Commit ea705ee · verified · 1 Parent(s): c4f92fb

Create app.py

Files changed (1)
  1. app.py +1002 -0
app.py ADDED
@@ -0,0 +1,1002 @@
##############################################################
# copy from cognitron_vl/constants.py
##############################################################
import logging
logger = logging.getLogger(__name__)


if True:
    IMG_TAG_TOKEN = "<image>"
    VID_TAG_TOKEN = "<video>"
    AUD_TAG_TOKEN = "<audio>"

    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    VID_CONTEXT_TOKEN = '<VID_CONTEXT>'
    VID_START_TOKEN = '<vid>'
    VID_END_TOKEN = '</vid>'

    PATCH_CONTEXT_TOKEN = '<PATCH_CONTEXT>'
    PATCH_START_TOKEN = '<patch>'
    PATCH_END_TOKEN = '</patch>'

    AUD_START_TOKEN = '<|begin_of_audio|>'
    AUD_END_TOKEN = '<|end_of_audio|>'

    QUAD_START_TOKEN = '<quad>'
    QUAD_END_TOKEN = '</quad>'
    REF_START_TOKEN = '<ref>'
    REF_END_TOKEN = '</ref>'
    BOX_START_TOKEN = '<box>'
    BOX_END_TOKEN = '</box>'


if False:
    IMG_TAG_TOKEN = "<|image|>"
    VID_TAG_TOKEN = "<|video|>"
    AUD_TAG_TOKEN = "<|audio|>"

    IMG_CONTEXT_TOKEN = '<|context_of_image|>'
    IMG_START_TOKEN = '<|begin_of_image|>'
    IMG_END_TOKEN = '<|end_of_image|>'

    VID_CONTEXT_TOKEN = '<|context_of_video|>'
    VID_START_TOKEN = '<|begin_of_video|>'
    VID_END_TOKEN = '<|end_of_video|>'

    PATCH_CONTEXT_TOKEN = '<|context_of_patch|>'
    PATCH_START_TOKEN = '<|begin_of_patch|>'
    PATCH_END_TOKEN = '<|end_of_patch|>'

    AUD_START_TOKEN = '<|begin_of_audio|>'
    AUD_END_TOKEN = '<|end_of_audio|>'

    QUAD_START_TOKEN = '<|begin_of_quad|>'
    QUAD_END_TOKEN = '<|end_of_quad|>'
    REF_START_TOKEN = '<|begin_of_ref|>'
    REF_END_TOKEN = '<|end_of_ref|>'
    BOX_START_TOKEN = '<|begin_of_box|>'
    BOX_END_TOKEN = '<|end_of_box|>'

logger.info(f"IMG_TAG_TOKEN {IMG_TAG_TOKEN}")
logger.info(f"VID_TAG_TOKEN {VID_TAG_TOKEN}")
logger.info(f"AUD_TAG_TOKEN {AUD_TAG_TOKEN}")
logger.info(f"IMG_CONTEXT_TOKEN {IMG_CONTEXT_TOKEN}")
logger.info(f"IMG_START_TOKEN {IMG_START_TOKEN}")
logger.info(f"IMG_END_TOKEN {IMG_END_TOKEN}")
logger.info(f"VID_CONTEXT_TOKEN {VID_CONTEXT_TOKEN}")
logger.info(f"VID_START_TOKEN {VID_START_TOKEN}")
logger.info(f"VID_END_TOKEN {VID_END_TOKEN}")
logger.info(f"PATCH_CONTEXT_TOKEN {PATCH_CONTEXT_TOKEN}")
logger.info(f"PATCH_START_TOKEN {PATCH_START_TOKEN}")
logger.info(f"PATCH_END_TOKEN {PATCH_END_TOKEN}")
logger.info(f"AUD_START_TOKEN {AUD_START_TOKEN}")
logger.info(f"AUD_END_TOKEN {AUD_END_TOKEN}")

# IMAGENET_MEAN = (0.485, 0.456, 0.406)
# IMAGENET_STD = (0.229, 0.224, 0.225)

# CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
# CLIP_STD = (0.2686295, 0.2613025, 0.2757711)

# SIGLIP_MEAN = (0.5, 0.5, 0.5)
# SIGLIP_STD = (0.5, 0.5, 0.5)


IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = IMG_CONTEXT_TOKEN
DEFAULT_IMAGE_PATCH_TOKEN = PATCH_CONTEXT_TOKEN
DEFAULT_IM_START_TOKEN = IMG_START_TOKEN
DEFAULT_IM_END_TOKEN = IMG_END_TOKEN

##############################################################

##############################################################
# copy from cognitron_vl/data/processor/image_processor.py
##############################################################
import math
import os

import cv2
import natsort
import numpy as np
import torch
from PIL import Image

import decord
# from cognitron_vl.constants import (
#     IMAGENET_DEFAULT_MEAN,
#     IMAGENET_DEFAULT_STD,
#     IMAGENET_STANDARD_MEAN,
#     IMAGENET_STANDARD_STD,
#     OPENAI_CLIP_MEAN,
#     OPENAI_CLIP_STD,
# )


class ImageProcessor:
    def __init__(
        self,
        process_type,
        image_size=448,
        normalize_type="imagenet",
        min_patch_grid=1,
        max_patch_grid=6,
    ):
        self.process_type = process_type
        self.image_size = image_size

        if normalize_type == "imagenet":
            MEAN, STD = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        elif normalize_type == "clip":
            MEAN, STD = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
        elif normalize_type == "siglip":
            MEAN, STD = IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
        else:
            raise NotImplementedError
        self.mean = MEAN
        self.std = STD

        self.patch_size = image_size
        self.min_patch_grid = min_patch_grid
        self.max_patch_grid = max_patch_grid

        if self.process_type == "anyres":
            self.grid_pinpoints = [
                (i, j)
                for i in range(min_patch_grid, max_patch_grid + 1)
                for j in range(min_patch_grid, max_patch_grid + 1)
            ]
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.grid_pinpoints
            ]
            print(f"grid_pinpoints {self.grid_pinpoints}")
            print(f"possible_resolutions {self.possible_resolutions}")

        if self.process_type == "dynamic":
            max_num = self.max_patch_grid
            min_num = self.min_patch_grid
            # calculate the existing image aspect ratio
            target_ratios = set(
                (i, j)
                for n in range(min_num, max_num + 1)
                for i in range(1, n + 1)
                for j in range(1, n + 1)
                if i * j <= max_num and i * j >= min_num
            )
            self.target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
            self.possible_resolutions = [
                [dim * self.patch_size for dim in pair] for pair in self.target_ratios
            ]
            print(f"target_ratios {self.target_ratios}")
            print(f"possible_resolutions {self.possible_resolutions}")

    def get_frame_paths(self, frame_root, num_frames=8):
        os.makedirs(frame_root, exist_ok=True)

        self.frame_tmpl = "frame-{}-of-{}.jpg"
        return [
            os.path.join(frame_root, self.frame_tmpl.format(i, num_frames))
            for i in range(1, num_frames + 1)
        ]

    def save_video_frames(self, vid_path, max_fps=1, num_frames=8):
        vid = decord.VideoReader(vid_path, num_threads=1)

        step_size = len(vid) / (num_frames + 1)
        # step_size = max(1, step_size)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)

        # indices = [int(i * step_size) for i in range(1, num_frames + 1)]
        indices = [int(i * step_size) for i in range(0, num_frames)]
        indices = [i for i in indices if i < len(vid)]

        num_frames = len(indices)

        frame_paths = self.get_frame_paths(vid_path + ".saved_frames", num_frames)
        flag = np.all([os.path.exists(p) for p in frame_paths])
        if flag:
            return frame_paths

        images = [vid[i].asnumpy() for i in indices]
        images = [Image.fromarray(arr) for arr in images]

        for im, pth in zip(images, frame_paths):
            # if not os.path.exists(pth):
            #     im.save(pth)
            im.save(pth)
        # print(f"save_video_frames vid_path {vid_path} fps {fps} len(vid) {len(vid)} frame_paths {frame_paths}")
        return frame_paths

    def get_video_frames(self, vid_path, max_fps=1, num_frames=8):
        vid = decord.VideoReader(vid_path, num_threads=1)

        step_size = len(vid) / (num_frames + 1)
        # step_size = max(1, step_size)
        fps = vid.get_avg_fps()
        step_size = max(fps / max_fps, step_size)

        # indices = [int(i * step_size) for i in range(1, num_frames + 1)]
        indices = [int(i * step_size) for i in range(0, num_frames)]
        indices = [i for i in indices if i < len(vid)]

        images = [vid[i].asnumpy() for i in indices]
        images = [Image.fromarray(arr) for arr in images]

        # print(f"save_video_frames vid_path {vid_path} fps {fps} len(vid) {len(vid)} frame_paths {frame_paths}")
        return images

    def process_video(self, video_file_or_dir, max_num_frame=8, max_fps=1):
        if os.path.isdir(video_file_or_dir):
            all_filepath = []
            for root, dirs, files in os.walk(video_file_or_dir):
                for filename in files:
                    if (
                        filename.endswith("png")
                        or filename.endswith("jpeg")
                        or filename.endswith("jpg")
                    ):
                        filepath = os.path.join(root, filename)
                        all_filepath.append(filepath)

            if len(all_filepath) == 0:
                return None

            # all_filepath.sort()
            all_filepath = natsort.natsorted(all_filepath)
            total_frame = len(all_filepath)
            if "ShareGPTVideo" in video_file_or_dir:
                fps = 2
            else:
                fps = 1
            target_frame = int(min(total_frame / fps * max_fps, max_num_frame))
            index = [int(1.0 * total_frame / target_frame) * x for x in range(target_frame)]

            selected_filepath = [all_filepath[x] for x in index]

            img_or_path_list = selected_filepath
            # print(f"process_video {img_or_path_list}")
        elif os.path.isfile(video_file_or_dir):
            # frame_paths = self.save_video_frames(
            #     video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps
            # )
            # img_or_path_list = frame_paths
            img_or_path_list = self.get_video_frames(
                video_file_or_dir, num_frames=max_num_frame, max_fps=max_fps
            )
        else:
            # print(f"FileNotFoundError {video_file_or_dir}")
            raise NotImplementedError

        return self.process_images(img_or_path_list), img_or_path_list

    def process_images(self, img_or_path_list):
        if isinstance(img_or_path_list[0], str):
            images = [Image.open(x).convert("RGB") for x in img_or_path_list]
        elif isinstance(img_or_path_list[0], Image.Image):
            images = [x.convert("RGB") for x in img_or_path_list]
        else:
            images = img_or_path_list

        def expand2square(pil_img, background_color):
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        image_tensor = torch.ones([len(images), 3, self.image_size, self.image_size])

        for i, image in enumerate(images):
            image = expand2square(image, tuple(int(x * 255) for x in self.mean))

            image = image.resize(
                (self.image_size, self.image_size), resample=Image.Resampling.BICUBIC
            )

            image = np.array(image, dtype=np.float32)
            image = image * 1.0 / 255.0

            mean = np.array(self.mean, dtype=image.dtype)
            std = np.array(self.std, dtype=image.dtype)
            image = (image - mean) / std

            image = torch.tensor(image, dtype=torch.float32)
            image = image.permute(2, 0, 1)

            image_tensor[i] = image

        return image_tensor

    def process_images_with_subpatch(self, img_or_path):
        if self.process_type == "anyres":
            return self.process_anyres(img_or_path)
        if self.process_type == "dynamic":
            return self.process_dynamic(img_or_path)

        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path

        return self.process_images([image])

    def process_anyres(self, img_or_path):
        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path

        best_resolution = select_best_resolution(image.size, self.possible_resolutions)
        image_padded = resize_and_pad_image(image, best_resolution)
        patches = divide_to_patches(image_padded, self.patch_size)

        if best_resolution == (self.patch_size, self.patch_size):
            image_patches = [image]
        else:
            image_patches = [image] + patches

        image_patches = self.process_images(image_patches)

        # print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}")

        return image_patches, best_resolution

    def process_dynamic(self, img_or_path):
        if isinstance(img_or_path, str):
            image = Image.open(img_or_path).convert("RGB")
        elif isinstance(img_or_path, Image.Image):
            image = img_or_path.convert("RGB")
        else:
            image = img_or_path

        image_patches, best_resolution = dynamic_preprocess(
            image,
            min_num=self.min_patch_grid,
            max_num=self.max_patch_grid,
            image_size=self.patch_size,
            use_thumbnail=True,
        )

        image_patches = self.process_images(image_patches)

        # print(f"image {image.size} best_resolution {best_resolution} image_padded {image_padded.size} patches {len(patches)} image_patches {image_patches.size()}")

        return image_patches, best_resolution

def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Calculate the downscaled size to keep the aspect ratio
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(
            original_height * scale
        )

        # Calculate effective and wasted resolutions
        effective_resolution = min(
            downscaled_width * downscaled_height, original_width * original_height
        )
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
            effective_resolution == max_effective_resolution
            and wasted_resolution < min_wasted_resolution
        ):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    # Determine which dimension (width or height) to fill
    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        # Width will be filled completely
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        # Height will be filled completely
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    # Create a new image with the target size and paste the resized image onto it
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float("inf")
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size
    )

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size,
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        # processed_images.append(thumbnail_img)
        processed_images = [
            thumbnail_img,
        ] + processed_images
    return processed_images, (target_width, target_height)


##############################################################

##############################################################
# modify from long_vita_megatron/tasks/inference/module.py
##############################################################
# Expand each <image>/<video> tag token in the prompt into the model's visual token
# layout (an <img>/<vid> span filled with context tokens, plus per-patch spans for
# tiled high-resolution images), and return the expanded token ids together with
# the preprocessed pixel tensors and the indices where their features are placed.
def get_external_inputs(tokens, image_list=None, image_path_list=None, video_path_list=None):
    print(f"get_external_inputs tokens {tokens.size()}")
    tokens = tokens.tolist()

    image_token_length = 256
    max_num_frame = 4096
    max_fps = 1

    # from cognitron_vl.constants import IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN, VID_START_TOKEN, VID_END_TOKEN, VID_CONTEXT_TOKEN, PATCH_START_TOKEN, PATCH_END_TOKEN, PATCH_CONTEXT_TOKEN, IMG_TAG_TOKEN, VID_TAG_TOKEN
    image_tag = "<image>"
    video_tag = "<video>"

    IMG_CONTEXT_ID = tokenizer(IMG_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    IMG_START_ID = tokenizer(IMG_START_TOKEN, add_special_tokens=False).input_ids
    IMG_END_ID = tokenizer(IMG_END_TOKEN, add_special_tokens=False).input_ids

    VID_CONTEXT_ID = tokenizer(VID_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    VID_START_ID = tokenizer(VID_START_TOKEN, add_special_tokens=False).input_ids
    VID_END_ID = tokenizer(VID_END_TOKEN, add_special_tokens=False).input_ids

    PATCH_CONTEXT_ID = tokenizer(PATCH_CONTEXT_TOKEN, add_special_tokens=False).input_ids
    PATCH_START_ID = tokenizer(PATCH_START_TOKEN, add_special_tokens=False).input_ids
    PATCH_END_ID = tokenizer(PATCH_END_TOKEN, add_special_tokens=False).input_ids

    IMG_TAG_ID = tokenizer(IMG_TAG_TOKEN, add_special_tokens=False).input_ids
    VID_TAG_ID = tokenizer(VID_TAG_TOKEN, add_special_tokens=False).input_ids

    assert len(IMG_CONTEXT_ID) == 1
    assert len(IMG_START_ID) == 1
    assert len(IMG_END_ID) == 1

    assert len(VID_CONTEXT_ID) == 1
    assert len(VID_START_ID) == 1
    assert len(VID_END_ID) == 1

    assert len(PATCH_CONTEXT_ID) == 1
    assert len(PATCH_START_ID) == 1
    assert len(PATCH_END_ID) == 1

    IMG_CONTEXT_ID = IMG_CONTEXT_ID[0]
    IMG_START_ID = IMG_START_ID[0]
    IMG_END_ID = IMG_END_ID[0]

    VID_CONTEXT_ID = VID_CONTEXT_ID[0]
    VID_START_ID = VID_START_ID[0]
    VID_END_ID = VID_END_ID[0]

    PATCH_CONTEXT_ID = PATCH_CONTEXT_ID[0]
    PATCH_START_ID = PATCH_START_ID[0]
    PATCH_END_ID = PATCH_END_ID[0]

    IMG_TAG_ID = IMG_TAG_ID[0]
    VID_TAG_ID = VID_TAG_ID[0]

    nl_tokens = tokenizer("\n", add_special_tokens=False).input_ids

    image_indices = []
    images = []

    # ----------------------------------------------------------------
    # image
    for batch_idx, input_ids in enumerate(tokens):
        # img_positions = [i for i, x in enumerate(input_ids) if x == IMG_CONTEXT_ID]
        img_positions = [i for i, x in enumerate(input_ids) if x == IMG_TAG_ID]
        if len(img_positions) == 0:
            continue
        if image_path_list is not None:
            assert len(img_positions) == len(image_path_list), f"{img_positions} {image_path_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}"
        if image_list is not None:
            assert len(img_positions) == len(image_list), f"{img_positions} {image_list} {IMG_CONTEXT_TOKEN} {IMG_CONTEXT_ID} {tokens}"

        new_input_ids = []
        st = 0
        for img_idx, img_pos in enumerate(img_positions):
            if image_path_list is not None:
                image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(image_path_list[img_idx])
            if image_list is not None:
                image_patches, (best_width, best_height) = image_processor.process_images_with_subpatch(image_list[img_idx])
            images.append(image_patches)
            print(f"get_external_inputs best_width {best_width} best_height {best_height}")

            new_input_ids += input_ids[st:img_pos]

            new_input_ids += [IMG_START_ID]

            image_indice_b = torch.zeros(
                1, image_token_length, dtype=torch.int64
            )  # This will change in collate_fn
            image_indice_s = (
                torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                .unsqueeze(0)
                .repeat(1, 1)
            )
            image_indice_b_s = torch.stack(
                [image_indice_b, image_indice_s], dim=0
            )  # 2, num_image, image_length
            image_indices.append(image_indice_b_s)

            new_input_ids += [IMG_CONTEXT_ID] * image_token_length

            new_input_ids += [IMG_END_ID]

            if len(image_patches) > 1:
                for i in range(0, best_height, image_processor.patch_size):
                    new_input_ids += nl_tokens

                    for j in range(0, best_width, image_processor.patch_size):
                        new_input_ids += [PATCH_START_ID]

                        image_indice_b = torch.zeros(
                            1, image_token_length, dtype=torch.int64
                        )  # This will change in collate_fn
                        image_indice_s = (
                            torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                            .unsqueeze(0)
                            .repeat(1, 1)
                        )
                        image_indice_b_s = torch.stack(
                            [image_indice_b, image_indice_s], dim=0
                        )  # 2, num_image, image_length
                        image_indices.append(image_indice_b_s)

                        new_input_ids += [PATCH_CONTEXT_ID] * image_token_length

                        new_input_ids += [PATCH_END_ID]
                        # print(f"get_external_dict i {i} j {j} new_input_ids {len(new_input_ids)}")

            st = img_pos + 1

        new_input_ids += input_ids[st:]

        input_ids = new_input_ids
        tokens[batch_idx] = input_ids

    # ----------------------------------------------------------------
    # video
    for batch_idx, input_ids in enumerate(tokens):
        # vid_positions = [i for i, x in enumerate(input_ids) if x == VID_CONTEXT_ID]
        vid_positions = [i for i, x in enumerate(input_ids) if x == VID_TAG_ID]
        if len(vid_positions) == 0:
            continue
        if video_path_list is not None:
            assert len(vid_positions) == len(video_path_list), f"{vid_positions} {video_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"
        if image_path_list is not None:
            assert len(vid_positions) == len(image_path_list), f"{vid_positions} {image_path_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"
        if image_list is not None:
            assert len(vid_positions) == len(image_list), f"{vid_positions} {image_list} {VID_CONTEXT_TOKEN} {VID_CONTEXT_ID} {tokens}"

        new_input_ids = []
        st = 0
        for vid_idx, vid_pos in enumerate(vid_positions):
            if video_path_list is not None:
                video_frames, _ = image_processor.process_video(video_path_list[vid_idx], max_num_frame, max_fps)
            if image_path_list is not None:
                video_frames = image_processor.process_images([image_path_list[vid_idx]])
            if image_list is not None:
                video_frames = image_processor.process_images([image_list[vid_idx]])

            images.append(video_frames)

            new_input_ids += input_ids[st:vid_pos]

            for _ in video_frames:
                new_input_ids += [VID_START_ID]

                image_indice_b = torch.zeros(
                    1, image_token_length, dtype=torch.int64
                )  # This will change in collate_fn
                image_indice_s = (
                    torch.arange(len(new_input_ids), len(new_input_ids) + image_token_length)
                    .unsqueeze(0)
                    .repeat(1, 1)
                )
                image_indice_b_s = torch.stack(
                    [image_indice_b, image_indice_s], dim=0
                )  # 2, num_image, image_length
                image_indices.append(image_indice_b_s)

                new_input_ids += [VID_CONTEXT_ID] * image_token_length

                new_input_ids += [VID_END_ID]

            st = vid_pos + 1

        new_input_ids += input_ids[st:]

        input_ids = new_input_ids
        tokens[batch_idx] = input_ids

    if len(images) > 0:
        images = torch.cat(images, dim=0)
        image_indices = torch.cat(image_indices, dim=1)

        image_indices = image_indices.contiguous().to(torch.cuda.current_device())
        if True:
            images = torch.tensor(images, dtype=torch.bfloat16).contiguous().to(torch.cuda.current_device())
        else:
            images = torch.tensor(images, dtype=torch.float16).contiguous().to(torch.cuda.current_device())

        print(f"get_external_inputs images {images.size()}")
        print(f"get_external_inputs image_indices {image_indices.size()}")

    else:
        images = None
        image_indices = None

    print(f"get_external_inputs images {images}")
    print(f"get_external_inputs image_indices {image_indices}")

    tokens = torch.tensor(tokens, dtype=torch.long, device='cuda')

    print(f"get_external_inputs tokens {tokens.size()}")

    return tokens, images, image_indices

##############################################################

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch

import importlib.util
if importlib.util.find_spec("torch_npu") is not None:
    print("Loading torch_npu")
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    # torch.npu.set_compile_mode(jit_compile=True)


import sys
import os
import natsort

torch.manual_seed(1234)

model_name_or_path = "VITA-MLLM/Long-VITA-128K_HF"

device_map = "auto"
# device_map = "npu:0"
# torch_dtype = torch.float16
torch_dtype = torch.bfloat16
# torch_dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True
)
print("tokenizer", tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch_dtype,
    attn_implementation="flash_attention_2",
).eval()
# print("model", model)

model.generation_config = GenerationConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

model.generation_config.max_new_tokens = 1024
model.generation_config.chat_format = "chatml"
model.generation_config.max_window_size = 1310720
model.generation_config.do_sample = False
model.generation_config.use_cache = True
model.generation_config.pad_token_id = tokenizer.pad_token_id

# from cognitron_vl.data.processor.image_processor import ImageProcessor
image_processor = ImageProcessor(
    process_type="dynamic",
    image_size=448,
    normalize_type="imagenet",
    min_patch_grid=1,
    max_patch_grid=12,
)

import gradio as gr
import spaces

@spaces.GPU(duration=120)
def inference_model(messages, image_path_list, video_path_list):
    default_system_message = [
        {
            "role": "system",
            "content": "You are a helpful AI assistant.",
        }
    ]
    messages = default_system_message + messages

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # .to("cuda")
    print("input", tokenizer.decode(inputs[0], skip_special_tokens=False), flush=True)

    inputs, images, image_indices = get_external_inputs(inputs, image_path_list=image_path_list, video_path_list=video_path_list)
    # inputs = inputs.to("cuda")
    # images = images.to("cuda")
    # image_indices = image_indices.to("cuda")

    outputs = model.generate(inputs=inputs, images=images, image_indices=image_indices)

    # output = tokenizer.decode(outputs[0], skip_special_tokens=False)
    output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
    print(f"output {output}", flush=True)

    return output


import time
import filetype


font_size = "2.5em"
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
    <span style="display: inline-block; vertical-align: middle;">{model_name_or_path.split('/')[-1]}</span>
</p>
<center>
<font size=3>
    <b>Long-VITA</b> has been fully open-sourced on <a href='https://huggingface.co/VITA-MLLM'>😊 Huggingface</a> and <a href='https://github.com/VITA-MLLM/Long-VITA'>🌟 GitHub</a>. If you find Long-VITA useful, a like❤️ or a star🌟 would be appreciated.
</font>
</center>
"""

def add_message(history, message):
    for x in message["files"]:
        history.append({"role": "user", "content": {"path": x}})
    if message["text"] is not None:
        history.append({"role": "user", "content": message["text"]})
    return history, gr.MultimodalTextbox(value=None, interactive=False)


def bot(history: list):
    print("#" * 100)
    messages = []
    image_path_list = []
    video_path_list = []
    for message in history:
        # print(f"message {message}")
        role = message["role"]
        content = message["content"]
        if isinstance(content, str):
            if len(messages) == 0 or messages[-1]["role"] != role:
                messages.append(
                    {
                        "role": role,
                        "content": "",
                    }
                )
            messages[-1]["content"] = messages[-1]["content"] + content

        else:
            for filepath in content:
                if filetype.is_image(filepath):
                    # print(f"{filepath} is a valid image...")
                    if len(messages) == 0 or messages[-1]["role"] != role:
                        messages.append(
                            {
                                "role": role,
                                "content": "",
                            }
                        )
                    messages[-1]["content"] = "<image>" + messages[-1]["content"]
                    image_path_list.append(filepath)

                elif filetype.is_video(filepath):
                    # print(f"{filepath} is a valid video...")
                    if len(messages) == 0 or messages[-1]["role"] != role:
                        messages.append(
                            {
                                "role": role,
                                "content": "",
                            }
                        )
                    messages[-1]["content"] = "<video>" + messages[-1]["content"]
                    video_path_list.append(filepath)

    print(f"messages {messages}")
    print(f"image_path_list {image_path_list}")
    print(f"video_path_list {video_path_list}")

    if len(image_path_list) == 0:
        image_path_list = None
    if len(video_path_list) == 0:
        video_path_list = None

    output = inference_model(messages, image_path_list, video_path_list)

    history.append({"role": "assistant", "content": output})

    return history


with gr.Blocks(title=model_name_or_path.split('/')[-1] + "🔥🚀🔥", theme=gr.themes.Ocean()) as demo:
    gr.HTML(html)
    with gr.Row():
        chatbot = gr.Chatbot(type="messages", elem_id="chatbot", bubble_full_width=False, height=800)

    with gr.Row():
        chat_input = gr.MultimodalTextbox(
            interactive=True,
            file_count="multiple",
            file_types=['image', 'video'],
            placeholder="Enter message or upload file...",
            show_label=False,
            # sources=["microphone", "upload"],
            sources=["upload"],
        )

    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    )
    bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
    bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

demo.launch(
    server_port=8501,
    server_name="0.0.0.0",
)
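
For reference, a minimal sketch of exercising the inference path above without the Gradio UI. It follows the same message convention that bot() uses (one "<image>" tag prepended per attached file, passed alongside the matching path list); it assumes the model, tokenizer, and image_processor defined above are already loaded on a GPU, and the image filename is a hypothetical placeholder, not a file shipped with this Space.

# Usage sketch (not part of app.py): call the inference path directly.
# "example.jpg" is a hypothetical local image; replace it with a real file.
demo_messages = [
    {"role": "user", "content": "<image>Describe this image in detail."}
]
demo_image_paths = ["example.jpg"]

answer = inference_model(demo_messages, demo_image_paths, None)
print(answer)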