# rahulvenkk
# big hard cwm
# 8e8833a
from functools import partial
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_
from cwm.data.masking_generator import RotatedTableMaskingGenerator
from cwm.model.model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table
import cwm.eval.Flow.masking_flow as masking
import cwm.utils as utils
# from external.raft_interface import RAFTInterface
import cwm.eval.Flow.generator as generator
import matplotlib.pyplot as plt
import torch.nn.functional as F
from cwm.eval.Flow import flow_utils
import cwm.model.keypoint_utils as keypoint_utils
def trunc_normal_(tensor, mean=0., std=1.):
    """Fill ``tensor`` in place from a normal(mean, std) truncated to one std.

    The truncation window is centred on ``mean``: samples lie in
    ``[mean - std, mean + std]``. The wrapped initializer takes absolute
    bounds ``a``/``b``, so passing ``a=-std, b=std`` (as before) only worked
    for ``mean == 0``; with any other mean most of the mass fell outside the
    window. For the default ``mean=0`` the behaviour is unchanged.

    Args:
        tensor: tensor initialised in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation, also the half-width of the truncation window.
    """
    __call_trunc_normal_(tensor, mean=mean, std=std, a=mean - std, b=mean + std)
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Resize per-frame square grids of positional embeddings to ``h x w``.

    Args:
        pos_embed: tensor viewed as ``(1, n_frames * side * side, C)``; the
            per-frame grid is assumed square, with ``side`` derived from the
            token count.
        n_frames: number of temporal slots stored in ``pos_embed``.
        h: target grid height (in patches).
        w: target grid width (in patches).

    Returns:
        ``pos_embed`` itself when it already holds ``n_frames * h * w`` tokens,
        otherwise a ``(1, n_frames * h * w, C)`` tensor obtained by bilinear
        interpolation of every frame's grid.
    """
    num_tokens = pos_embed.shape[1]
    if num_tokens == h * w * n_frames:
        return pos_embed  # already at the requested size

    side = int((num_tokens / n_frames) ** 0.5)
    # one channels-first image per frame: (n_frames, C, side, side)
    grids = pos_embed.view(1, n_frames, side, side, -1)
    grids = grids.flatten(0, 1)
    grids = grids.permute(0, 3, 1, 2)
    resized = F.interpolate(
        grids,
        size=(h, w),
        mode='bilinear',
    )
    # back to token layout: (1, n_frames * h * w, C)
    tokens = resized.permute(0, 2, 3, 1)
    return tokens.flatten(0, 2).unsqueeze(0)
class PretrainVisionTransformerEncoder(nn.Module):
    """ Vision Transformer encoder over (tubelet x patch x patch) video tokens.

    The clip is embedded with ``PatchEmbed``, positional embeddings are added,
    only the tokens marked visible by ``mask`` are kept (optionally extended
    with re-positioned "moving"/"static" tokens, see ``forward_features``), and
    the result is run through the transformer blocks, a final norm and ``head``.
    """

    def __init__(self, img_size=224, patch_size=(16, 16), in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2,
                 num_frames=16, block_func=Block, k_bias=False, use_learnable_pos_emb=False, block_kwargs={}):
        # NOTE: ``block_kwargs={}`` is a shared default; this class only unpacks it, never mutates it.
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # (tubelet, patch_h, patch_w)
        self.patch_size = (tubelet_size,) + patch_size
        self.pt, self.ph, self.pw = self.patch_size
        # spatial patch grid of a single frame
        self.h = int(img_size / self.ph)
        self.w = int(img_size / self.pw)
        self.hw = self.h * self.w
        self.dims = [self.h, self.w]
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            tubelet_size=tubelet_size,
            num_frames=num_frames
        )
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches
        self.num_frames = num_frames
        if use_learnable_pos_emb:
            self.use_learnable_pos_emb = True
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)
        else:
            # sine-cosine positional embeddings; not wrapped in nn.Parameter here,
            # forward_features moves them to the input's device/dtype on each call
            self.use_learnable_pos_emb = False
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            block_func(
                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, **block_kwargs, k_bias=k_bias,
                xla_flash=True)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace ``head`` with a fresh projection to ``num_classes`` (identity if <= 0)."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _get_pos_embed(self):
        """Hook returning the positional-embedding table used by ``forward_features``."""
        return self.pos_embed

    def forward_block(self, x, idx):
        """Run a single transformer block ``idx`` on ``x``."""
        return self.blocks[idx](x)

    def forward_features(self, x, mask, move_pos=None, static_pos=None, movement=None, res=1):
        """Embed ``x``, keep visible tokens (plus optional moved/static tokens) and encode them.

        Args:
            x: input clip for ``patch_embed``; time is read from dim 2
                (presumably ``(B, C, T, H, W)`` — confirm with PatchEmbed).
            mask: boolean token mask where True means masked (``x[~mask]`` keeps
                the visible tokens). Every sample must have the same number of
                visible tokens, because the result is reshaped to ``(B, -1, C)``.
            move_pos: optional ``[B, P, 2]`` patch-grid positions whose first-frame
                content is re-inserted at ``move_pos + movement`` in the last frame.
                The first coordinate is normalised by ``h``, the second by ``w``.
            static_pos: ``[B, P', 2]`` positions whose first-frame content is
                re-inserted at the same place in the last frame (required with move_pos).
            movement: ``[B, P, 2]`` offsets added to ``move_pos`` (required with move_pos).
            res: resolution scale; when != 1 the positional grid is resized to
                ``int(224 // patch * res)`` patches per side.

        Returns:
            Normalised encoder tokens ``[B, N_vis (+ P + P'), C]``.
        """
        T = x.shape[2]
        x = embed = self.patch_embed(x)
        pos_embed = self._get_pos_embed().type_as(x).to(x.device).clone()
        if not self.use_learnable_pos_emb:
            pos_embed = pos_embed.detach()
        if res != 1:
            p0 = self.patch_size[-2]
            p1 = self.patch_size[-1]
            # Interpolate the already device/dtype-converted table (using the raw
            # self.pos_embed here left a non-learnable table on the CPU). The
            # table has T // tubelet_size temporal slots, not T.
            # NOTE(review): the target grid assumes a 224-pixel base resolution.
            pos_embed = interpolate_pos_encoding(pos_embed, T // self.pt, int(224 // p0 * res), int(224 // p1 * res))
        x = x + pos_embed
        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible

        if move_pos is not None:
            h, w = self.h, self.w
            # patch embeddings (no positional term) of the first temporal slot, on the grid
            first_frame_emb = embed[:, :self.hw].view(B, h, w, C)  # [B, h, w, C]
            # positional embeddings of the last temporal slot, on the grid
            last_frame_pos_emb = pos_embed[:, -self.hw:].view(1, h, w, C).expand(B, -1, -1, -1)  # [B, h, w, C]
            denominator = torch.tensor([self.h, self.w]).view(1, 1, 2).to(x.device)
            new_pos = move_pos + movement  # [B, P, 2]
            # normalise grid positions to [-1, 1] for sampling
            move_pos = move_pos / denominator * 2 - 1
            new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1  # handle special case where new_pos is out of bounds
            static_pos = static_pos / denominator * 2 - 1
            # content comes from the first frame, position from the last frame:
            # moving tokens land at move_pos + movement, static tokens stay in place
            moving_emb = utils.sample_embedding(first_frame_emb, move_pos, mode='nearest')  # [B, P, C]
            moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest')  # [B, P, C]
            static_emb = utils.sample_embedding(first_frame_emb, static_pos, mode='nearest')  # [B, P, C]
            static_pos_emb = utils.sample_embedding(last_frame_pos_emb, static_pos, mode='nearest')  # [B, P, C]
            x_vis = torch.cat([x_vis, moving_emb + moving_pos_emb, static_emb + static_pos_emb], dim=1)

        for blk in self.blocks:
            x_vis = blk(x_vis)
        x_vis = self.norm(x_vis)
        return x_vis

    def _set_inputs(self, *args, **kwargs):
        """Hook called with ``(x, mask)`` at the start of ``forward``; no-op here."""
        pass

    def forward(self, x, mask, move_pos=None, static_pos=None, movement=None, res=1):
        """Encode ``x`` (see ``forward_features``) and apply ``head``."""
        self._set_inputs(x, mask)
        x = self.forward_features(x, mask, move_pos, static_pos, movement, res=res)
        x = self.head(x)
        return x
class PretrainVisionTransformerDecoder(nn.Module):
    """ Transformer decoder producing per-token predictions.

    ``depth`` blocks built by ``block_func`` (drop-path rate ramping linearly
    from 0 to ``drop_path_rate``), followed by ``norm_layer`` and a linear
    ``head`` from ``embed_dim`` to ``num_classes`` (``nn.Identity`` when
    ``num_classes`` is not positive).
    """

    def __init__(self, patch_size=(16, 16), num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, block_func=Block, block_kwargs={},
                 k_bias=False
                 ):
        super().__init__()
        self.num_classes = num_classes
        # num_features mirrors embed_dim for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size  # stored for callers; not used inside this module

        # stochastic depth decay rule: one drop-path rate per block
        drop_path_rates = torch.linspace(0, drop_path_rate, depth).tolist()
        layers = []
        for rate in drop_path_rates:
            layers.append(block_func(
                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=rate,
                norm_layer=norm_layer, init_values=init_values, **block_kwargs, k_bias=k_bias,
                xla_flash=True))
        self.blocks = nn.ModuleList(layers)
        self.norm = norm_layer(embed_dim)
        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace ``head`` with a fresh projection to ``num_classes`` (identity if <= 0)."""
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward_block(self, x, idx):
        """Run a single transformer block ``idx`` on ``x``."""
        return self.blocks[idx](x)

    def get_last_tokens(self, x, return_token_num):
        """Normalise and project tokens without running the blocks.

        ``return_token_num > 0`` keeps only the last ``return_token_num`` tokens,
        ``== 0`` returns an empty token selection, and a negative value keeps
        every token.
        """
        if return_token_num > 0:
            return self.head(self.norm(x[:, -return_token_num:]))
        projected = self.head(self.norm(x))
        if return_token_num == 0:
            return projected[:, x.size(1):]
        return projected

    def forward(self, x, return_token_num):
        """Run all blocks, then project the last ``return_token_num`` tokens (all if <= 0)."""
        for block in self.blocks:
            x = block(x)
        # only the trailing mask tokens are turned into pixel predictions when requested
        tokens = x[:, -return_token_num:] if return_token_num > 0 else x
        return self.head(self.norm(tokens))
class PretrainVisionTransformer(nn.Module):
    """ Masked video autoencoder used as a counterfactual world model.

    ``encoder`` embeds the visible tokens of a clip, ``encoder_to_decoder``
    projects them to the decoder width, a shared ``mask_token`` plus the
    positional embedding fills every masked position, and ``decoder``
    predicts pixel values for the masked tokens.
    """
    default_input_kwargs = {'unnormalize': True}

    def __init__(self,
                 img_size=224,
                 patch_size=(16, 16),
                 encoder_func=PretrainVisionTransformerEncoder,
                 encoder_in_chans=3,
                 encoder_num_classes=0,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 encoder_block_func=Block,
                 encoder_block_kwargs={},
                 decoder_num_classes=None,
                 # For pretraining this parameter isn't relevant but must be set according to tube&patch size
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=8,
                 decoder_block_func=Block,
                 decoder_block_kwargs={},
                 mlp_ratio=4.,
                 qkv_bias=False,
                 k_bias=False,
                 qk_scale=None,
                 num_frames=16,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=0.,
                 tubelet_size=2,
                 use_flash_attention=False,
                 use_learnable_pos_emb=False,
                 **kwargs
                 ):
        super().__init__()
        # Build new dicts rather than calling .update() on the arguments: the
        # defaults are shared ``{}`` literals (mutating them leaks state across
        # instances) and caller-supplied dicts must not be modified.
        encoder_block_kwargs = {**encoder_block_kwargs, 'flash_attention': use_flash_attention}
        decoder_block_kwargs = {**decoder_block_kwargs, 'flash_attention': use_flash_attention}
        self.tubelet_size = tubelet_size
        # default decoder output: 3 values per pixel of a (tubelet x patch_h x patch_w) token
        num_classes = 3 * tubelet_size * (
            patch_size[0] * patch_size[1]) if decoder_num_classes is None else decoder_num_classes
        self.encoder = encoder_func(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=encoder_in_chans,
            num_classes=encoder_num_classes,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
            num_frames=num_frames,
            block_func=encoder_block_func,
            block_kwargs=encoder_block_kwargs,
            use_learnable_pos_emb=use_learnable_pos_emb,
            k_bias=k_bias,
            **kwargs)
        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size,
            num_classes=num_classes,
            embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            block_func=decoder_block_func,
            k_bias=k_bias,
            block_kwargs=decoder_block_kwargs)
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=k_bias)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        trunc_normal_(self.mask_token, std=.02)
        if use_learnable_pos_emb:
            self.use_learnable_pos_emb = True
            # NOTE: 2-D (N, C) here, unlike the encoder's (1, N, C); expand(B, -1, -1) handles both
            self.pos_embed = nn.Parameter(torch.zeros(self.encoder.num_patches, decoder_embed_dim))
            trunc_normal_(self.pos_embed, std=.02)
        else:
            self.use_learnable_pos_emb = False
            self.pos_embed = get_sinusoid_encoding_table(self.encoder.num_patches, decoder_embed_dim)
        self.num_frames = num_frames
        self.num_patches = self.encoder.num_patches
        if self.num_frames is not None:
            self.num_patches_per_frame = self.num_patches // self.num_frames
        else:
            self.num_patches_per_frame = self.num_patches
        self.patch_size = self.encoder.patch_size
        if isinstance(img_size, int):
            self.image_size = (img_size, img_size)
        else:
            assert hasattr(img_size, '__len__'), img_size
            self.image_size = img_size
        # self.flow_interface = RAFTInterface()

    @property
    def mask_size(self):
        """Token-grid size ``(T // tubelet, H // patch_h, W // patch_w)``."""
        return (self.num_frames // self.patch_size[0],
                self.image_size[-2] // self.patch_size[-2],
                self.image_size[-1] // self.patch_size[-1])

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the encoder depth (this wrapper has no ``blocks`` of its own)."""
        return len(self.encoder.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token', 'mask_token'}

    def adjust_input_resolution(self, H, W):
        """Re-target the model to ``H x W`` inputs by resizing the positional tables.

        Updates ``image_size`` and the encoder's patch-grid attributes, then
        replaces the encoder and decoder positional embeddings with versions
        interpolated (via ``utils.interpolate_pos_encoding``) to the new grid.
        Does nothing when the model is already at ``(H, W)``.
        """
        # image_size starts out as a tuple and is stored as a list below; compare
        # element-wise so a call at the current size really is a no-op.
        if tuple(self.image_size) == (H, W):
            return
        patch_size = self.encoder.patch_size[-2:]
        self.image_size = [H, W]
        self.encoder.h = int(H / self.encoder.ph)
        self.encoder.w = int(W / self.encoder.pw)
        self.encoder.hw = self.encoder.h * self.encoder.w
        self.encoder.dims = [self.encoder.h, self.encoder.w]
        dims = [int(s / p) for s, p in zip(self.image_size, patch_size)]
        # temporal slots in the positional tables (was hard-coded to 3, which
        # broke 2-frame models); same convention as ``mask_size``.
        # NOTE(review): assumes utils.interpolate_pos_encoding(pos_embed, n_frames, h, w).
        n_time = self.num_frames // self.patch_size[0]
        self.encoder.pos_embed = utils.interpolate_pos_encoding(self.encoder.pos_embed, n_time, dims[0], dims[1])
        self.pos_embed = utils.interpolate_pos_encoding(self.pos_embed, n_time, dims[0], dims[1])

    def forward(self, x, mask, forward_full=False, return_features=False, res=1, *args, **kwargs):
        """Encode the visible tokens of ``x`` and decode the masked ones.

        Args:
            x: input clip, unpacked as ``(B, C, T, H, W)``.
            mask: boolean token mask, True = masked (the encoder keeps ``~mask``).
                All samples must mask the same number of tokens.
            forward_full: decode every token after the first
                ``num_frames - 1`` frames' worth of tokens (the whole last frame
                when tubelet_size == 1) instead of only the masked tokens.
            return_features: return the raw encoder output.
            res: resolution scale forwarded to the encoder.
                NOTE(review): the decoder positional table is not resized for res != 1.
            *args, **kwargs: forwarded to the encoder.

        Returns:
            Encoder features when ``return_features``; otherwise decoder predictions.
        """
        _, _, _, H, W = x.shape
        self.device = x.device
        # tokens per frame of this input (correct for non-square frames as well)
        num_patches_per_frame = (H // self.encoder.patch_size[1]) * (W // self.encoder.patch_size[2])
        x_vis = self.encoder(x, mask, res=res, *args, **kwargs)
        if return_features:
            return x_vis
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, _, C = x_vis.shape
        # add pos embedding
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone()
        if not self.use_learnable_pos_emb:
            expand_pos_embed = expand_pos_embed.detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_vis = x_vis + pos_emd_vis
        if forward_full:
            nctx = num_patches_per_frame * (self.num_frames - 1)
            x_full = torch.cat([x_vis, self.mask_token + expand_pos_embed[:, nctx:]], dim=1)  # [B, N, C_d]
            return self.decoder(x_full, num_patches_per_frame)
        x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1)  # [B, N, C_d]
        return self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]

    def get_counterfactual(self, x, move_patches):
        '''
        Predict the last frame after copying patches of the second-to-last frame to new locations.

        :param x: input tensor [1, C, T, H, W]: support only batch size 1 for now.
            Not modified; the patch copies are applied to a clone.
        :param move_patches: torch tensor [N, 4] sized array where each row contains patch motion
            [x1, y1, x2, y2] in pixel coordinates; x is the row (height) axis and y the column
            (width) axis. The patch at (x1, y1) in frame T-2 is pasted at (x2, y2) in frame T-1,
            and that destination token is made visible.
        :return: decoder prediction for the last frame, reshaped by ``utils.unpatchify_cwm``
        '''
        B, _, _, H, W = x.shape
        # the last frame is edited below; never write into the caller's tensor
        x = x.clone()
        mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool()
        mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False
        # pixel -> patch coordinates, per axis (rows by H/h, columns by W/w)
        pixel_sizes = torch.tensor([H, W, H, W], device=move_patches.device)
        grid_sizes = torch.tensor([self.encoder.h, self.encoder.w, self.encoder.h, self.encoder.w],
                                  device=move_patches.device)
        move_patches = (move_patches / pixel_sizes) * grid_sizes
        move_patches = move_patches.to(torch.int64)
        for x1, y1, x2, y2 in move_patches:
            idx2 = x2 * self.encoder.w + y2 + (self.encoder.num_frames - 1) * (self.encoder.h * self.encoder.w)
            mask[:, idx2] = False
            im_x1 = x1 * self.encoder.ph
            im_y1 = y1 * self.encoder.pw
            im_x2 = x2 * self.encoder.ph
            im_y2 = y2 * self.encoder.pw
            x[:, :, -1, im_x2:im_x2 + self.encoder.ph, im_y2:im_y2 + self.encoder.pw] = \
                x[:, :, -2, im_x1:im_x1 + self.encoder.ph, im_y1:im_y1 + self.encoder.pw]
        prediction = self.forward(x, mask, forward_full=True)
        prediction = utils.unpatchify_cwm(
            prediction,
            patch_size=self.encoder.patch_size[-1],
        )  # reshape the output to an image
        return prediction

    def get_directional_counterfactual(self, x, mask=None, move_pos=None, static_pos=None, movement=None, max_movement=None):
        """Predict the last frame with first-frame tokens moved to / pinned at chosen grid positions.

        Args:
            x: input clip ``(B, C, T, H, W)``.
            mask: boolean token mask (True = masked). Default: every token of the
                last frame is masked, all earlier frames are visible.
            move_pos: ``[B, P, 2]`` patch-grid positions to move.
            static_pos: ``[B, P', 2]`` patch-grid positions that stay in place
                (required whenever ``move_pos`` is given).
            movement: ``[B, P, 2]`` offsets; when None, drawn uniformly from the
                integers in ``[-max_movement, max_movement]``.
            max_movement: bound for the random movement (required if movement is None).

        Returns:
            Prediction for the masked last-frame tokens, reshaped by ``utils.unpatchify_cwm``.
        """
        B, _, _, _, _ = x.shape
        if mask is None:  # default mask: all visible but the last frame
            mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool()
            mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False
        if movement is None:  # generate random motion if movement is not specified
            assert max_movement is not None and move_pos is not None
            # randint's upper bound is exclusive; +1 makes the range symmetric
            movement = torch.randint(-max_movement, max_movement + 1, move_pos.shape).to(x.device)  # [B, num_samples, 2]
        x_vis = self.encoder(x, mask, move_pos=move_pos, static_pos=static_pos, movement=movement)  # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, _, C = x_vis.shape
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        if move_pos is not None:
            h, w = self.encoder.h, self.encoder.w
            last_frame_pos_emb = expand_pos_embed[:, -(h * w):].view(B, h, w, C)  # [B, h, w, C]
            # compute new locations of the moved patches, normalize positions to range [-1, 1]
            new_pos = move_pos + movement  # [B, P, 2]
            denominator = torch.tensor([h, w]).view(1, 1, 2).to(x.device)
            new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1
            static_pos = static_pos / denominator * 2 - 1
            # sample the position embeddings of the moved and static patches
            moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest')  # [B, P, C]
            static_pos_emb = utils.sample_embedding(last_frame_pos_emb, static_pos, mode='nearest')  # [B, P, C]
            # append them in the same order the encoder appended the extra tokens
            pos_emd_vis = torch.cat([pos_emd_vis, moving_pos_emb, static_pos_emb], dim=1)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_vis = x_vis + pos_emd_vis
        x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1)  # [B, N, C_d]
        x = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
        prediction = utils.unpatchify_cwm(
            x,
            patch_size=self.encoder.patch_size[-1],
            mask=mask[:, -self.encoder.hw:]
        )  # reshape the output to an image
        return prediction

    @torch.no_grad()
    def get_segment(self, x, mask=None, sampling_dist=None, num_segments=1, num_iters=4, num_samples=4, max_movement=4,
                    vis=True):
        """Experimental: search for move/static point sets whose counterfactual flow is consistent.

        NOTE(review): unfinished. It always returns ``None`` (Step 3 is not
        implemented), ``mask`` and ``num_segments`` are unused, the default
        ``sampling_dist`` computation is commented out, and Step 2 never
        updates ``prev_*``. ``move_pos`` is repeated N times while
        ``static_pos`` is sampled B * N times, so only B == 1 lines up.
        Requires ``self.flow_interface`` (called as
        ``flow_interface(img1, img2, return_magnitude=True)``), which
        ``__init__`` does not create.
        """
        if not hasattr(self, 'flow_interface'):
            # fail before adjust_input_resolution() mutates the model
            raise AttributeError(
                "get_segment requires `self.flow_interface`, which is not created in __init__; "
                "assign a flow estimator to the model before calling it.")
        B, C, T, H, W = x.shape
        N = num_samples
        patch_size = self.encoder.patch_size[-1]
        self.adjust_input_resolution(H, W)

        # ## Step 0: define the sampling distribution for moving and static locations
        # if sampling_dist is None:
        #     sampling_dist = utils.get_dino_predominance(x[:, :, 0], dims=self.encoder.dims, img_size=self.image_size)[0]
        #     sampling_dist = F.interpolate(sampling_dist, self.encoder.dims, mode='bilinear', align_corners=False)
        #     sampling_dist = sampling_dist.squeeze(1) ** 4
        #
        #     if vis:
        #         print('sampling_dist', sampling_dist.shape)
        #         plt.imshow(sampling_dist[0].cpu().numpy())
        #         plt.title(f'Sampling distribution (max:{sampling_dist.max():.3f})')
        #         plt.show()

        ## Step 1: sample initial moving and static locations from the distribution
        init_move_pos, init_static_pos, init_flow_mag, max_score = None, None, None, 0
        # Sample multiple positions for each segment and select the one with consistent outputs
        for _ in range(N):
            # sample one moving position per example in the batch
            move_pos = utils.sample_positions_from_dist(size=[1, 1], dist=sampling_dist).repeat(N, 1, 1)  # [BN, 1, 2]
            # each move position has N static positions and movement directions
            static_pos = utils.sample_positions_from_dist(size=[B * N, 1], dist=-sampling_dist)  # [BN, 1, 2]
            ## compute initial flow maps
            _x = x.repeat(N, 1, 1, 1, 1)  # [BN, C, T, H, W]
            pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos, max_movement=max_movement)
            flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True)
            flow_mag = flow_mag.view(B, N, H, W)
            scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1)  # [B]
            print('scores', scores, flow_mag.shape)
            if scores.mean(-1) > max_score:
                init_move_pos, init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores
            # visualize samples
            if vis:
                fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples))
                for i in range(num_samples):
                    move = move_pos[i, 0].cpu()
                    static = static_pos[i, 0].cpu()
                    flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0))
                    axs[i].imshow(flow_rgb)
                    axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20)
                    axs[i].set_axis_off()
                    axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20)
                fig.subplots_adjust(wspace=0.01, hspace=0.01)  # Change these values to adjust space
                plt.show()
                plt.close()

        ## Step 2: iteratively add more moving and static locations to refine the segment
        prev_flow_mag = init_flow_mag
        prev_move_pos = init_move_pos
        prev_static_pos = init_static_pos
        npos_per_iter = 1
        for it in range(num_iters):
            print('Iteration', it)
            sampling_dist = F.interpolate(prev_flow_mag, size=self.encoder.dims, mode='bilinear').mean(1)
            # sample one moving position per example in the batch
            move_pos = utils.sample_positions_from_dist(size=[1, npos_per_iter], dist=sampling_dist).repeat(N, 1, 1)
            move_pos = torch.cat([prev_move_pos, move_pos], dim=1)
            # each move position has N static positions and movement directions
            static_pos = utils.sample_positions_from_dist(size=[B * N, npos_per_iter],
                                                          dist=-sampling_dist)  # [BN, 1, 2]
            static_pos = torch.cat([prev_static_pos, static_pos], dim=1)
            pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos, max_movement=max_movement)
            flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True)
            flow_mag = flow_mag.view(B, N, H, W)
            scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1)  # [B]
            print('scores', scores, flow_mag.shape)
            if scores.mean(-1) > max_score:
                init_move_pos, init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores
            # visualize samples
            if vis:
                fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples))
                for i in range(num_samples):
                    flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0))
                    axs[i].imshow(flow_rgb)
                    axs[i].set_axis_off()
                    for k in range(move_pos.shape[1]):
                        move = move_pos[i, k].cpu()
                        static = static_pos[i, k].cpu()
                        axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20)
                        axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20)
                fig.subplots_adjust(wspace=0.01, hspace=0.01)  # Change these values to adjust space
                plt.show()
                plt.close()

        ## Step 3: iterate to add more moving and static motions
        return None

    @torch.no_grad()
    def get_flow(self, img1, img2, conditioning_img=None, mode='jacobian', perturbation_patch_size=8, aggregation_patch_size=8, mask_ratio=0.0, num_scales=1, num_mask_samples=1):
        '''
        :param img1: input image 1 [B, C, H, W]
        :param img2: input image 2 [B, C, H, W]
        :param conditioning_img: optional extra image (only used when mode != 'jacobian')
        :param mode: which flow extraction method to use: 'jacobian', or anything else
            (e.g. 'optical_flow') for the masked-prediction optical-flow path
        :param perturbation_patch_size: perturbation patch size for the jacobian generator
        :param aggregation_patch_size: aggregation patch size for the jacobian generator
        :param mask_ratio: what frame2 mask ratio to use when extracting flow
        :param num_scales: number of scales (only used when mode != 'jacobian')
        :param num_mask_samples: number of mask samples (only used when mode != 'jacobian')
        :return: forward flow [B, 2, H, W]
        '''
        # NOTE(review): the mask grid assumes 224x224 inputs regardless of self.image_size
        frame_size = 224 // self.patch_size[-1]
        if mode == 'jacobian':
            DFG = generator.DerivativeFlowGenerator(
                predictor=self,
                perturbation_patch_size=perturbation_patch_size,
                aggregation_patch_size=aggregation_patch_size,
                agg_power=None,
                agg_channel_func=lambda x: F.relu(x.sum(-3, True)),
                num_samples=5,
                average_jacobian=False,
                leave_one_out_sampling=False,
                imagenet_normalize_inputs=False,
                temporal_dim=2,
                confidence_thresh=None
            ).to(img1.device)
            maskgen_uniform = masking.PytorchMaskGeneratorWrapper(
                mask_generator=masking.RotatedTableMaskingGenerator,
                input_size=(self.num_frames, frame_size, frame_size),
                mask_ratio=mask_ratio
            ).to(img1.device)
            _, forward_flow = flow_utils.extract_jacobians_and_flows(img1, img2,
                                                                     DFG,
                                                                     maskgen_uniform()[None])
        else:
            mask_generator = RotatedTableMaskingGenerator(
                input_size=(self.num_frames, frame_size, frame_size),
                mask_ratio=mask_ratio,
                tube_length=1,
                batch_size=1,
                mask_type='rotated_table'
            )
            forward_flow, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(self, mask_generator, img1, img2, conditioning_img=conditioning_img, num_scales=num_scales, N_mask_samples=num_mask_samples)
        return forward_flow

    @torch.no_grad()
    def get_keypoints(self, img1, img2, img3=None, num_keypoints=10, samples_per_keypoint=1):
        '''
        :param img1: input image 1 [B, C, H, W] imagenet normalized
        :param img2: input image 2 [B, C, H, W] imagenet normalized
        :param img3: optional third image for 3-frame models; when omitted, img1 is
            repeated as the first frame (ignored by 2-frame models)
        :param num_keypoints: number of keypoints to extract
        :param samples_per_keypoint: number of samples per keypoint
        :return: the tuple (mask, choices, err_array, feat, keypoint_recon) produced by
            keypoint_utils.get_keypoints_batch
        '''
        if self.num_frames == 2:
            x = torch.stack([img1, img2], dim=2)
        else:
            if img3 is None:
                x = torch.stack([img1, img1, img2], dim=2)
            else:
                x = torch.stack([img1, img2, img3], dim=2)
        mask, choices, err_array, feat, keypoint_recon = keypoint_utils.get_keypoints_batch(self, x, samples_per_keypoint, num_keypoints)
        return mask, choices, err_array, feat, keypoint_recon
def pretrain_vit_base_224_scaffold(img_size=224, **kwargs):
    """Build a ViT-Base world model with a 512-d decoder.

    Encoder: 768-d, 12 blocks, 12 heads, no classification head.
    Decoder: 512-d, 8 blocks, 16 heads. MLP ratio 4, qkv bias on,
    ``k_bias=True`` (forwarded to the blocks and used as the bias flag of
    the encoder-to-decoder projection), LayerNorm with ``eps=1e-6``.
    Remaining keyword arguments (patch size, frame count, ...) go to
    ``PretrainVisionTransformer``; ``default_cfg`` is set to ``_cfg()``.
    """
    base_config = dict(
        img_size=img_size,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model = PretrainVisionTransformer(**base_config, **kwargs)
    model.default_cfg = _cfg()
    return model
def pretrain_videomae_base_224_scaffold(**kwargs):
    """Build a 224-px ViT-Base world model with a small (VideoMAE-sized) decoder.

    Encoder: 768-d, 12 blocks, 12 heads, no classification head.
    Decoder: 384-d, 4 blocks, 6 heads. MLP ratio 4, qkv bias on, ``k_bias``
    left at its default, LayerNorm with ``eps=1e-6``. ``img_size`` is fixed
    to 224, so it must not be passed in ``kwargs``; everything else is
    forwarded to ``PretrainVisionTransformer``.
    """
    base_config = dict(
        img_size=224,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=6,
        decoder_depth=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model = PretrainVisionTransformer(**base_config, **kwargs)
    model.default_cfg = _cfg()
    return model
def vitb_8x8patch_3frames(**kwargs):
    """ViT-Base world model on 3 frames of 8x8 patches (tubelet 1, flash attention on)."""
    return pretrain_vit_base_224_scaffold(
        patch_size=(8, 8), num_frames=3, tubelet_size=1,
        use_flash_attention=True, **kwargs)
def vitb_8x8patch_2frames(**kwargs):
    """ViT-Base world model on 2 frames of 8x8 patches (tubelet 1, flash attention on)."""
    return pretrain_vit_base_224_scaffold(
        patch_size=(8, 8), num_frames=2, tubelet_size=1,
        use_flash_attention=True, **kwargs)
def vitb_8x8patch_2frames_vmae(**kwargs):
    """2-frame, 8x8-patch model using the smaller VideoMAE-style decoder (flash attention on)."""
    return pretrain_videomae_base_224_scaffold(
        patch_size=(8, 8), num_frames=2, tubelet_size=1,
        use_flash_attention=True, **kwargs)
def vitb_4x4patch_2frames(**kwargs):
    """2-frame, 4x4-patch model with the VideoMAE-style decoder.

    Unlike the 8x8 variants, ``use_flash_attention`` is not forced on here,
    so the ``PretrainVisionTransformer`` default applies unless passed in.
    """
    return pretrain_videomae_base_224_scaffold(
        patch_size=(4, 4), num_frames=2, tubelet_size=1, **kwargs)
from cwm.model.modeling_pretrain_cleaned_soft import pretrain_vit_base_256_scaffold
def vitb_8x8patch_2frames_encoder_mask_token(
        use_flash_attention=False, **kwargs):
    """2-frame, 8x8-patch model built by ``pretrain_vit_base_256_scaffold``.

    Fixes ``interp_noise=False``, ``legacy=False``, ``xla_flash=False`` and
    ``learn_pos_embed=True``; their meaning is defined by that scaffold
    (``cwm.model.modeling_pretrain_cleaned_soft``), not in this module.
    """
    fixed_options = dict(
        patch_size=(8, 8),
        num_frames=2,
        tubelet_size=1,
        use_flash_attention=use_flash_attention,
        interp_noise=False,
        legacy=False,
        xla_flash=False,
        learn_pos_embed=True,
    )
    return pretrain_vit_base_256_scaffold(**fixed_options, **kwargs)
# def base_8x8patch_2frames_1tube(**kwargs):
# model = pretrain_videomae_base_224_scaffold(
# patch_size=(8, 8),
# num_frames=2,
# tubelet_size=1,
# **kwargs)
# return model