rahulvenkk committed
Commit 8e8833a
Parent: 19af009

big hard cwm

app.py CHANGED
@@ -26,7 +26,7 @@ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 # Load CWM 3-frame model (automatically download pre-trained checkpoint)
-model = model_factory.load_model('vitb_8x8patch_3frames')#.to(device)
+model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token')#.to(device)
 
 model.requires_grad_(False)
 model.eval()
@@ -89,7 +89,7 @@ import os
 # preloaded_images = load_preuploaded_images()
 #
 # print("Preloaded images:", preloaded_images)
-@spaces.GPU(duration=110)
+# @spaces.GPU(duration=110)
 def get_c(x, points):
     x = utils.imagenet_normalize(x)#.to(device)
     with torch.no_grad():
@@ -107,6 +107,7 @@ with gr.Blocks() as demo:
         with gr.Column():
             # Input image
             original_image = gr.State(value=None)  # store original image without arrows
+            original_image_high_res = gr.State(value=None)  # store original image without arrows
             input_image = gr.Image(type="numpy", label="Upload Image")
 
             # Annotate arrows
@@ -125,12 +126,12 @@ with gr.Blocks() as demo:
             output_image = gr.Image(type='numpy')
 
    # Store the original image and resize to square size once uploaded
-    def resize_to_square(img, size=448):
+    def resize_to_square(img, size=512):
        print("Resizing image to square")
        img = Image.fromarray(img)
        transform = transforms.Compose([
            transforms.Resize((size, size)),
-            transforms.CenterCrop(size)
+            # transforms.CenterCrop(size)
        ])
        img = transform(img)  # .transpose(1, 2, 0)
 
@@ -142,13 +143,13 @@ with gr.Blocks() as demo:
        img = np.array(Image.open(img_path))
        # print(f"Image uploaded with shape: {input.shape}")
        resized_img = resize_to_square(img)
-        return resized_img, resized_img, []
+        return resized_img, resized_img, img, []
 
 
    def store_img(img):
        resized_img = resize_to_square(img)  # Resize the uploaded image to a square
        print(f"Image uploaded with shape: {resized_img.shape}")
-        return resized_img, resized_img, []
+        return resized_img, resized_img, img, []
 
 
    with gr.Row():
@@ -165,9 +166,9 @@ with gr.Blocks() as demo:
        # #     run_on_click=True,
        # #     label="Select an example image to test"
        # )
-        gallery.select(load_img, outputs=[input_image, original_image, selected_points])
+        gallery.select(load_img, outputs=[input_image, original_image, original_image_high_res, selected_points])
 
-        input_image.upload(store_img, [input_image], [input_image, original_image, selected_points])
+        input_image.upload(store_img, [input_image], [input_image, original_image, original_image_high_res, selected_points])
 
    # Get points and draw arrows or zero-length vectors based on the toggle
    def get_point(img, sel_pix, zero_length, evt: gr.SelectData):
@@ -251,7 +252,7 @@ with gr.Blocks() as demo:
    def run_model_on_points(points, input_image, original_image):
        H = input_image.shape[0]
        W = input_image.shape[1]
-        factor = 224/H
+        factor = 256/H
        # Example: pretend the model processes points and returns a simple transformation on the image
        points = torch.from_numpy(np.array(points).reshape(-1, 4)) * factor
 
@@ -260,8 +261,8 @@ with gr.Blocks() as demo:
        img = Image.fromarray(original_image)
 
        transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.CenterCrop(224)
+            transforms.Resize((256, 256)),
+            # transforms.CenterCrop(256)
        ])
        img = np.array(transform(img))
 
@@ -272,7 +273,7 @@ with gr.Blocks() as demo:
        img = img[None]
 
        # reshape image to [B, C, T, H, W], C = 3, T = 3 (3-frame model), H = W = 224
-        x = img[:, :, None].expand(-1, -1, 3, -1, -1)#.to(torch.float16)
+        x = img[:, :, None].expand(-1, -1, 2, -1, -1)#.to(torch.float16)
 
        # Imagenet-normalize the inputs (standardization)
 
@@ -290,9 +291,9 @@ with gr.Blocks() as demo:
            return counterfactual
 
    # Run model when the button is clicked
-    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image], [output_image])
+    run_model_button.click(run_model_on_points, [selected_points, input_image, original_image_high_res], [output_image])
 
 
 
 # Launch the app
-demo.queue().launch()
+demo.queue().launch(inbrowser=True, share=True)
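
For orientation, here is a minimal sketch (not from the diff itself) of what the updated app.py pipeline now does with the new checkpoint: load the 2-frame model from the factory, resize an image to the new 256-pixel working resolution, repeat it along the time axis to form a 2-frame clip, ImageNet-normalize it, and request a counterfactual for a set of drag points. The image path and point values are made up, and the input-scaling convention of cwm.utils.imagenet_normalize is an assumption.

# Hedged sketch, not part of the commit: end-to-end use of the new 2-frame checkpoint.
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from cwm.model import model_factory
import cwm.utils as utils

model = model_factory.load_model('vitb_8x8patch_2frames_encoder_mask_token').eval()
model.requires_grad_(False)

img = Image.open('example.jpg')                           # hypothetical input image
img = transforms.Resize((256, 256))(img)                  # matches the new 256-px pipeline (no center crop)
x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0  # assumes imagenet_normalize expects [0, 1] inputs
x = x[None, :, None].expand(-1, -1, 2, -1, -1)            # [B, C, T=2, H, W]: repeat the still frame twice
x = utils.imagenet_normalize(x)

# one drag arrow (x1, y1, x2, y2) in 256-px coordinates; values are illustrative
points = torch.tensor([[64.0, 64.0, 96.0, 64.0]])

with torch.no_grad():
    counterfactual = model.get_counterfactual(x, points)  # predicted second frame after the drag
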
cwm/model/model_factory.py CHANGED
@@ -26,6 +26,11 @@ _model_catalogue ={
         "init_fn": model_pretrain.vitb_8x8patch_2frames,
     },
 
+    "vitb_8x8patch_2frames_encoder_mask_token": {
+        "path": "cwm/2frame_cwm_mask_token.pth",
+        "init_fn": model_pretrain.vitb_8x8patch_2frames_encoder_mask_token,
+    },
+
 }
 
 
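
The new catalogue entry pairs a checkpoint path with an init function. load_model itself is not shown in this diff; presumably it resolves an entry roughly like the sketch below, which is an assumption about its behavior (checkpoint-key handling in particular may differ).

# Assumed resolution of a catalogue entry; the real load_model may differ.
import torch
from cwm.model.model_factory import _model_catalogue

entry = _model_catalogue["vitb_8x8patch_2frames_encoder_mask_token"]
model = entry["init_fn"]()                             # builds the 2-frame, 8x8-patch, 256-px model
state = torch.load(entry["path"], map_location="cpu")  # cwm/2frame_cwm_mask_token.pth
model.load_state_dict(state.get("model", state))       # the "model" wrapper key is a guess
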
cwm/model/model_pretrain.py CHANGED
@@ -818,6 +818,23 @@ def vitb_4x4patch_2frames(**kwargs):
         **kwargs)
     return model
 
+from cwm.model.modeling_pretrain_cleaned_soft import pretrain_vit_base_256_scaffold
+
+def vitb_8x8patch_2frames_encoder_mask_token(
+        use_flash_attention=False, **kwargs):
+    model = pretrain_vit_base_256_scaffold(
+        patch_size=(8, 8),
+        num_frames=2,
+        tubelet_size=1,
+        use_flash_attention=use_flash_attention,
+        interp_noise=False,
+        legacy=False,
+        xla_flash=False,
+        learn_pos_embed=True,
+        **kwargs)
+    return model
+
+
 # def base_8x8patch_2frames_1tube(**kwargs):
 #     model = pretrain_videomae_base_224_scaffold(
 #         patch_size=(8, 8),
cwm/model/modeling_pretrain_cleaned_soft.py ADDED
@@ -0,0 +1,723 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_ as __call_trunc_normal_
+from einops import rearrange
+from cwm.model.model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
+
+from torch import Tensor
+import cwm.utils as utils
+
+def trunc_normal_(tensor, mean=0., std=1.):
+    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
+
+
+def interpolate_pos_encoding(pos_embed, n_frames, h, w):
+    N = pos_embed.shape[1]
+    if N == (h * w * n_frames):
+        return pos_embed
+    old_h = old_w = int((N / n_frames) ** 0.5)
+    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)
+
+    patch_pos_embed = F.interpolate(
+        patch_pos_embed,
+        size=(h, w),
+        mode='bicubic',
+    )
+    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
+
+PRINT_PADDING = False
+
+class PretrainVisionTransformerEncoder(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self, img_size=224, patch_size=(16, 16), in_chans=3, num_classes=0, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2,
+                 use_learnable_pos_emb=False, num_frames=16, embed_per_frame=False, clumping_factor=None, block_func=Block, k_bias=False, interp_noise=False, block_kwargs={}, legacy=False, xla_flash=False, learn_pos_embed=False):
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.patch_size = (tubelet_size,) + patch_size
+        self.pt, self.ph, self.pw = self.patch_size
+        self.h = int(img_size / self.ph)
+        self.w = int(img_size / self.pw)
+        self.hw = self.h * self.w
+
+        self.clumping_factor = clumping_factor
+        self.interp_noise = interp_noise
+
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+
+        if self.clumping_factor is not None:  # Clump the context frame for memory efficiency
+            self.clumping_embed = nn.Conv3d(in_channels=embed_dim, out_channels=embed_dim,
+                                            kernel_size=(tubelet_size, clumping_factor, clumping_factor),
+                                            stride=(tubelet_size, clumping_factor, clumping_factor))
+
+        self._embed_per_frame = embed_per_frame
+        if not self._embed_per_frame:
+            self.patch_embed = PatchEmbed(
+                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,tubelet_size=tubelet_size,num_frames=num_frames)
+            num_patches = self.patch_embed.num_patches
+        elif self._embed_per_frame:
+            assert (num_frames % tubelet_size) == 0
+            num_embeddings = (num_frames // tubelet_size)
+            self.patch_embed = nn.ModuleList([
+                PatchEmbed(
+                    img_size=img_size, patch_size=patch_size,
+                    in_chans=in_chans, embed_dim=embed_dim,
+                    tubelet_size=tubelet_size, num_frames=tubelet_size)
+                for _ in range(num_embeddings)])
+            num_patches = self.patch_embed[0].num_patches * num_embeddings
+
+        self.num_patches = num_patches
+        self.num_frames = num_frames
+        print("NUM PATCHES IN ENCODER", self.num_patches)
+
+        self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
+
+        if learn_pos_embed:
+            self.pos_embed = nn.Parameter(self.pos_embed)
+
+        self.learn_pos_embed = learn_pos_embed
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            block_func(
+                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, **block_kwargs, k_bias=k_bias, legacy=legacy, xla_flash=xla_flash)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        if use_learnable_pos_emb:
+            trunc_normal_(self.pos_embed, std=.02)
+
+        self.apply(self._init_weights)
+
+    def _set_pos_embed(self, dim=None):
+        if dim is None:
+            dim = self.embed_dim
+        if self.pos_embed is None:
+            self.pos_embed = get_sinusoid_encoding_table(
+                self.num_patches, dim)
+
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _get_pos_embed(self):
+        return self.pos_embed
+
+    def forward_block(self, x, idx):
+        return self.blocks[idx](x)
+
+    def interpolate_tensor_with_mask_token(self,
+                                           x: Tensor, mask: Tensor, mask_token: Tensor, invert: bool = True
+                                           ) -> Tensor:
+        """
+        Where mask == (0 if invert else 1), return x
+        where mask == (1 if invert else 0), return mask_token
+        Linearly interpolate between these using value of mask.
+        """
+        # mask_token = mask_token
+        # breakpoint()
+        B, N, C = x.shape
+        assert mask.shape[1] == N, (
+            f"Number of tokens in mask ({mask.shape[1]}) does not match "
+            f"number of tokens in input ({N})"
+        )
+
+        assert mask_token.shape[-1] == C, (
+            f"Dimensionality of mask token ({mask_token.shape[-1]}) does not match "
+            f"dimensionality of tokens in input ({C})"
+        )
+
+        # convert mask to interpolation weights in range [0., 1.]
+        mask = mask.to(x).clip(min=0.0, max=1.0)
+        mask = (1.0 - mask) if invert else mask
+        mask = mask.unsqueeze(-1)  # [B, N, 1]
+
+        # expand mask token
+        mask_token = mask_token.view(1, 1, C).expand(B, N, -1)
+
+        # interpolate
+        start = mask_token
+        end = x
+
+        return start + mask * (end - start)
+
+    def interpolate_tensor_with_noise(self,
+                                      x: Tensor, mask: Tensor, invert: bool = True
+                                      ) -> Tensor:
+        """
+        Where mask == (0 if invert else 1), return x
+        where mask == (1 if invert else 0), return mask_token
+        Linearly interpolate between these using value of mask.
+        """
+        # mask_token = mask_token
+        # breakpoint()
+        B, N, C = x.shape
+        assert mask.shape[1] == N, (
+            f"Number of tokens in mask ({mask.shape[1]}) does not match "
+            f"number of tokens in input ({N})"
+        )
+
+        # convert mask to interpolation weights in range [0., 1.]
+        mask = mask.to(x).clip(min=0.0, max=1.0)
+        mask = (1.0 - mask) if invert else mask
+        mask = mask.unsqueeze(-1)  # [B, N, 1]
+
+        # ImageNet mean and std
+        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+
+        # Generate a 3x8x8 patch of random numbers from a normal distribution
+        # with the same mean and std as ImageNet images
+        rand_vec = torch.randn(B, N, 3, self.patch_size[-2], self.patch_size[-1]) * std + mean
+
+        rand_vec = rand_vec.to(x.device).to(x.dtype).view(B, N, -1)
+        # interpolate
+        start = rand_vec
+        end = x
+
+        return start + mask * (end - start)
+
+    def tokenize(self, x, mask=None):
+
+        if not self._embed_per_frame:
+            x = self.patch_embed(x)
+        elif self._embed_per_frame:
+            x = torch.cat([
+                self.patch_embed[i](
+                    x[:,:,(i*self.pt):((i+1)*self.pt)])
+                for i in range(len(self.patch_embed))], 1)
+
+        pos_embed = self._get_pos_embed().type_as(x).to(x.device).clone()
+        if not self._learnable_pos_embed:
+            pos_embed = pos_embed.detach()
+        x = x + pos_embed
+        return (x, mask)
+
+    def tokenize_and_mask(self, x, mask):
+
+        x, mask = self.tokenize(x, mask)
+        B, _, C = x.shape
+        # breakpoint()
+        x_vis = x[~mask].reshape(B, -1, C)
+        return x_vis
+
+    def tokenize_and_mask_variable_size(self, x, mask):
+
+        x, mask = self.tokenize(x, mask)
+        B, _, C = x.shape
+        all_batches = []
+        max_len = 0
+        all_len = []
+        for i in range(B):
+            x_vis = x[i, ~mask[i]]
+            if x_vis.shape[0] > max_len:
+                max_len = x_vis.shape[0]
+            all_batches.append(x_vis)
+            all_len.append(x_vis.shape[0])
+
+        #pad all batches to max_len in a single line
+        x_vis = torch.stack([F.pad(batch, (0,0,0,max_len-batch.shape[0]), mode='constant', value=0) for batch in all_batches])
+
+        return x_vis, all_len
+
+    def forward_features(self, x, mask, move_patches, static_patches, delta, mask_token, res=1, return_feat_layer=None):
+        _, _, T, H, W = x.shape
+
+        if self.interp_noise:
+            #patchify x with patch size[0], patch size[1]
+            p0 = self.patch_size[-2]
+            p1 = self.patch_size[-1]
+            x = rearrange(x, 'b c t (h p0) (w p1) -> b (t h w) (p0 p1 c)', p0=p0, p1=p1, h=H//p0, w=W//p1)  # x: [B, N, C]
+
+            x = self.interpolate_tensor_with_noise(x, mask, invert=True)
+            x = rearrange(x, 'b n (p c) -> b n p c', c=3)
+            # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch.
+            x = rearrange(x,
+                          'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)',
+                          p0=1,
+                          p1=self.patch_size[-2],
+                          p2=self.patch_size[-1],
+                          h=H//self.patch_size[-2],
+                          w=W//self.patch_size[-1])
+
+        x = embed = self.patch_embed(x)
+
+        if res != 1:
+
+            p0 = self.patch_size[-2]
+            p1 = self.patch_size[-1]
+            pos_embed = interpolate_pos_encoding(self.pos_embed, T, int(256 // p0 * res), int(256 // p1 * res))
+        else:
+
+            pos_embed = self._get_pos_embed()
+
+        pos_embed = pos_embed.type_as(x)  # .to(x.device).clone()
+
+        if not self.learn_pos_embed:
+            pos_embed = pos_embed.to(x.device).clone().detach()
+
+        x = x + pos_embed
+        B, _, C = x.shape
+        # x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible
+        if not self.interp_noise:
+            x_vis = self.interpolate_tensor_with_mask_token(x, mask, mask_token, invert=True)
+        else:
+            x_vis = x
+
+        if move_patches is not None:
+
+            assert B == 1, "Only support batch size 1 for now"
+            for (px, py) in move_patches:
+                idx = px * self.w + py
+                dx, dy = delta
+                nx, ny = px + dx, py + dy
+                new_idx = nx * self.w + ny + (self.patch_embed.num_frames - 1) * (self.h * self.w)
+
+                emb = embed[:, idx]
+                pos_emb = pos_embed[:, new_idx]
+                emb = emb + pos_emb
+                x_vis = torch.cat([x_vis, emb[None]], 1)
+
+        if static_patches is not None:
+            for (px, py) in static_patches:
+                idx = px * self.w + py
+                new_idx = px * self.w + py + (self.patch_embed.num_frames - 1) * (self.h * self.w)
+                emb = embed[:, idx]
+                pos_emb = pos_embed[:, new_idx]
+                emb = emb + pos_emb
+                x_vis = torch.cat([x_vis, emb[None]], 1)
+
+        for blk_idx, blk in enumerate(self.blocks):
+            x_vis = blk(x_vis)
+            if blk_idx == return_feat_layer:
+                return x_vis
+
+        x_vis = self.norm(x_vis)
+        return x_vis
+
+    def _set_inputs(self, *args, **kwargs):
+        pass
+
+    def forward(self, x, mask, mask_token, return_feat_layer=None, timestamps=None, move_patches=None, static_patches=None, delta=None, res=1):
+        self._set_inputs(x, mask)
+        # pass input through the encoder
+        x = self.forward_features(x, mask, move_patches, static_patches, delta, mask_token, return_feat_layer=return_feat_layer, res=res)
+        # if return_feat_layer is not None and is lesser than the number of blocks it means that we are returning the
+        # features of an intermediate block layer. in this case we do not want to apply the head layer
+        if return_feat_layer is not None and return_feat_layer < len(self.blocks):
+            return x
+        # if we are passing through the entire encoder transformer we apply the head layer
+        x = self.head(x)
+        return x
+
+class PretrainVisionTransformerDecoder(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    def __init__(self, patch_size=(16, 16), num_classes=768, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
+                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, block_func=Block, block_kwargs={}, k_bias=False, legacy=True, xla_flash=False
+                 ):
+        super().__init__()
+
+
+        self.num_classes = num_classes
+
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.patch_size = patch_size
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.ModuleList([
+            block_func(
+                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+                init_values=init_values, **block_kwargs, k_bias=k_bias, legacy=legacy, xla_flash=xla_flash)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token'}
+
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_block(self, x, idx):
+        return self.blocks[idx](x)
+
+    def get_last_tokens(self, x, return_token_num):
+        if return_token_num > 0:
+            return self.head(self.norm(x[:,-return_token_num:]))
+        elif return_token_num == 0:
+            return self.head(self.norm(x))[:,x.size(1):]
+        else:
+            return self.head(self.norm(x))
+
+    def forward(self, x, return_token_num, return_feat_layer=None):
+
+        # pass input through the decoder
+        for blk_idx, blk in enumerate(self.blocks):
+            x = blk(x)
+            # if we are returning the features of an intermediate block
+            # do so and skip the remaining computation
+            if blk_idx == return_feat_layer:
+                return x
+
+        if return_token_num > 0:
+            x = self.head(self.norm(x[:, -return_token_num:]))  # only return the mask tokens predict pixels
+        else:
+            x = self.head(self.norm(x))
+
+        return x
+
+class PretrainVisionTransformer(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+    """
+    default_input_kwargs = {'unnormalize': True}
+    def __init__(self,
+                 img_size=224,
+                 patch_size=(16, 16),
+                 main_input=None,
+                 main_input_kwargs=default_input_kwargs,
+                 encoder_func=PretrainVisionTransformerEncoder,
+                 encoder_in_chans=3,
+                 encoder_num_classes=0,
+                 encoder_embed_dim=768,
+                 encoder_depth=12,
+                 encoder_num_heads=12,
+                 encoder_block_func=Block,
+                 encoder_block_kwargs={},
+                 decoder_num_classes=None,  # For pretraining this parameter isn't relevant but must be set according to tube&patch size
+                 decoder_embed_dim=512,
+                 decoder_depth=8,
+                 decoder_num_heads=8,
+                 decoder_block_func=Block,
+                 decoder_block_kwargs={},
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 k_bias=False,
+                 qk_scale=None,
+                 num_frames=16,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.,
+                 norm_layer=nn.LayerNorm,
+                 init_values=0.,
+                 spacetime_separable_pos_embed=False,
+                 tubelet_size=2,
+                 num_classes=0,  # avoid the error from create_fn in timm
+                 in_chans=0,  # avoid the error from create_fn in timm
+                 embed_per_frame=False,
+                 flow_model_ckpt=None,
+                 flow_frames=None,
+                 random_input=False,
+                 use_flash_attention=False,
+                 additional_decoder_for_transition=False,
+                 additional_decoder_for_x3_hat=False,
+                 clumping_factor=None,
+                 return_detectron_format=False,
+                 out_feature='out_feature',
+                 interp_noise=False,
+                 legacy=True,
+                 xla_flash=False,
+                 learn_pos_embed=False,
+                 **kwargs
+                 ):
+        super().__init__()
+
+        encoder_block_kwargs.update({'flash_attention': use_flash_attention})
+        decoder_block_kwargs.update({'flash_attention': use_flash_attention})
+
+        self.clumping_factor = clumping_factor
+
+        self.interp_noise = interp_noise
+
+        self.learn_pos_embed = learn_pos_embed
+
+        if self.clumping_factor is not None:
+            print('Clumping factor = %d' % self.clumping_factor)
+            self.clumping_embed = nn.Conv3d(in_channels=decoder_embed_dim, out_channels=decoder_embed_dim,
+                                            kernel_size=(1, clumping_factor, clumping_factor),
+                                            stride=(1, clumping_factor, clumping_factor))
+            self.clumping_embed.apply(self._init_weights)
+
+            self.up = nn.ConvTranspose2d(decoder_embed_dim, decoder_embed_dim, kernel_size=2, stride=2)
+            self.up.apply(self._init_weights)
+
+        self.encoder = encoder_func(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=encoder_in_chans,
+            num_classes=encoder_num_classes,
+            embed_dim=encoder_embed_dim,
+            depth=encoder_depth,
+            num_heads=encoder_num_heads,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            drop_rate=drop_rate,
+            attn_drop_rate=attn_drop_rate,
+            drop_path_rate=drop_path_rate,
+            norm_layer=norm_layer,
+            init_values=init_values,
+            tubelet_size=tubelet_size,
+            num_frames=num_frames,
+            embed_per_frame=embed_per_frame,
+            block_func=encoder_block_func,
+            block_kwargs=encoder_block_kwargs,
+            clumping_factor=clumping_factor,
+            k_bias=k_bias,
+            interp_noise = interp_noise,
+            legacy=legacy,
+            xla_flash=xla_flash,
+            learn_pos_embed=learn_pos_embed,
+            **kwargs)
+
+        if not return_detectron_format:
+            self.decoder = PretrainVisionTransformerDecoder(
+                patch_size=patch_size,
+                num_classes= 3*tubelet_size*(patch_size[0]*patch_size[1]) if decoder_num_classes is None else decoder_num_classes,
+                embed_dim=decoder_embed_dim,
+                depth=decoder_depth,
+                num_heads=decoder_num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop_rate=drop_rate,
+                attn_drop_rate=attn_drop_rate,
+                drop_path_rate=drop_path_rate,
+                norm_layer=norm_layer,
+                init_values=init_values,
+                block_func=decoder_block_func,
+                k_bias=k_bias, xla_flash=xla_flash,
+                block_kwargs=decoder_block_kwargs, legacy=legacy)
+
+        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=k_bias)
+
+        if not self.interp_noise:
+            self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder_embed_dim))
+            trunc_normal_(self.mask_token, std=.02)
+        else:
+            self.mask_token = None
+
+        self.timestamps = None
+        self.encoder.timestamps = None
+
+        if self.learn_pos_embed:
+            self.pos_embed = nn.Parameter(get_sinusoid_encoding_table(self.encoder.num_patches, decoder_embed_dim))
+        else:
+            self.pos_embed = get_sinusoid_encoding_table(self.encoder.num_patches, decoder_embed_dim)
+
+        self.num_frames = num_frames
+        self.num_patches = self.encoder.num_patches
+        if self.num_frames is not None:
+            self.num_patches_per_frame = self.num_patches // self.num_frames
+        else:
+            self.num_patches_per_frame = self.num_patches
+        self.patch_size = self.encoder.patch_size
+        if isinstance(img_size, int):
+            self.image_size = (img_size, img_size)
+        else:
+            assert hasattr(img_size, '__len__'), img_size
+            self.image_size = img_size
+
+        self.return_detectron_format = return_detectron_format
+
+    @property
+    def mask_size(self):
+        return (self.num_frames // self.patch_size[0],
+                self.image_size[-2] // self.patch_size[-2],
+                self.image_size[-1] // self.patch_size[-1])
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def get_num_layers(self):
+        return len(self.blocks)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token', 'mask_token'}
+
+
+
+    def unpatchify(self, x, mask):
+        # Define the input tensor
+        B, N, C = x.shape  # batch size
+        h, w = self.mask_size[-2:]
+        patch_size = self.patch_size[-2:]
+
+        recon = torch.zeros(B, h*w, C).to(x)
+        recon[mask[:, -h*w:]] = x.flatten(0, 1)
+
+        rec_imgs = rearrange(recon, 'b n (p c) -> b n p c', c=3)
+        # Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch.
+        rec_imgs = rearrange(rec_imgs,
+                             'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)',
+                             p0=1,
+                             p1=patch_size[0],
+                             p2=patch_size[1],
+                             h=h,
+                             w=w)
+
+        # MEAN = torch.from_numpy(np.array((0.485, 0.456, 0.406))[None, :, None, None, None]).cuda().half()
+        # STD = torch.from_numpy(np.array((0.229, 0.224, 0.225))[None, :, None, None, None]).cuda().half()
+        #
+        # rec_imgs = (rec_imgs - MEAN) / STD
+
+        return rec_imgs
+
+
+    def forward(self, x, mask, timestamps=None, return_feat_layer=None, res=1, *args, get_encoder_out=False, **kwargs):
+
+        _, _, T, _, _ = x.shape
+
+        self.device = x.device
+
+        enc_out = self.encoder(x, mask, self.mask_token, timestamps=timestamps, return_feat_layer=return_feat_layer, res=res, *args, **kwargs)  # [B, N_vis, C_e]
+
+        x_vis = self.encoder_to_decoder(enc_out)
+
+        # check if we are returning the features of an intermediate block layer
+        if return_feat_layer is not None:
+            # if the returned layer is one of the encoder layers (the first N_enc layers) we return the features
+            # if the return feat layer is exactly N_enc then we are returning the layer after the entire encoder block
+            # in both cases this manifests as returning x_vis, since self.encoder will return either the final block embedding
+            # or the final head embedding depending on the return_feat_layer
+            # in either case we subtract the number of encoder blocks + 1 (for the intermediate embedding layer)
+            # from the return_feat_layer to get the correct index for the decoder block
+            return_feat_layer = return_feat_layer - len(self.encoder.blocks) - 1
+            if return_feat_layer < 0:
+                return x_vis
+
+        # add pos embedding
+        if res != 1:
+            p0 = self.patch_size[-2]
+            p1 = self.patch_size[-1]
+            pos_embed = interpolate_pos_encoding(self.pos_embed, T, int(256 // p0 * res), int(256 // p1 * res))
+        else:
+            pos_embed = self.pos_embed
+        dec_pos_embed = pos_embed.expand(x_vis.size(0), -1, -1).type_as(x)
+
+        if not self.learn_pos_embed:
+            dec_pos_embed = dec_pos_embed.to(x.device).clone().detach()
+
+        x_vis = x_vis + dec_pos_embed
+
+        # pass input through the decoder, this will automatically return an intermediate layer if return_feat_layer is set
+        x_all = self.decoder(x_vis, 0, return_feat_layer=return_feat_layer)
+
+        if get_encoder_out:
+            return x_all, enc_out
+
+        return x_all
+
+    def get_counterfactual(self, x, move_patches):
+        '''
+        :param x: input tensor [1, C, T, H, W]: support only batch size 1 for now
+        :param move_patches: torch tensor [N, 4] sized array where each row contains patch motion [x1, y1, x2, y2] in pixel coordinates
+        :return:
+        '''
+        B, _, T, H, H = x.shape
+
+        mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool()
+        mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False
+
+        move_patches = (move_patches / H) * self.encoder.h
+        move_patches = move_patches.to(torch.int64)
+
+        for x1, y1, x2, y2 in move_patches:
+            idx2 = x2 * self.encoder.w + y2 + (self.encoder.num_frames - 1) * (self.encoder.h * self.encoder.w)
+            mask[:, idx2] = False
+            im_x1 = x1 * self.encoder.ph
+            im_y1 = y1 * self.encoder.pw
+            im_x2 = x2 * self.encoder.ph
+            im_y2 = y2 * self.encoder.pw
+            x[:, :, -1, im_x2:im_x2 + self.encoder.ph, im_y2:im_y2 + self.encoder.pw] = x[:, :, -2,
+                                                                                          im_x1:im_x1 + self.encoder.ph,
+                                                                                          im_y1:im_y1 + self.encoder.pw]
+
+        prediction = self.forward(x, mask)[:, -self.encoder.hw:]
+
+        prediction = utils.unpatchify_cwm(
+            prediction,
+            patch_size=self.encoder.patch_size[-1],
+        )  # reshape the output to an image
+
+        return prediction
+
+
+def pretrain_vit_base_256_scaffold(**kwargs):
+    model = PretrainVisionTransformer(
+        img_size=256,
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_num_classes=0,
+        decoder_embed_dim=768,
+        decoder_num_heads=12,
+        decoder_depth=12,
+        mlp_ratio=4,
+
+        qkv_bias=True,
+        k_bias=True,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs)
+    model.default_cfg = _cfg()
+    return model
+
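
The "soft" in this file's name shows up in interpolate_tensor_with_mask_token: rather than dropping masked tokens, the encoder blends each token with the learned mask token using the (possibly fractional) mask value. A tiny standalone illustration of that blend, mirroring the invert=True path with made-up numbers:

# Numerical illustration of the soft mask-token blend used by the encoder above
# (standalone; mirrors interpolate_tensor_with_mask_token with invert=True).
import torch

x = torch.tensor([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]])   # [B=1, N=3, C=2] patch tokens
mask_token = torch.zeros(1, 1, 2)                           # learned token, here all zeros for clarity
mask = torch.tensor([[0.0, 1.0, 0.5]])                      # 0 = visible, 1 = masked, 0.5 = half-blended

w = (1.0 - mask).unsqueeze(-1)                              # invert=True: weight on the real token
out = mask_token.expand(1, 3, 2) + w * (x - mask_token.expand(1, 3, 2))
print(out)  # [[1., 1.], [0., 0.], [1.5, 1.5]] -> visible kept, masked replaced, soft value interpolated
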