Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from functools import partial | |
import torch | |
import torch.nn as nn | |
from timm.models.layers import trunc_normal_ as __call_trunc_normal_ | |
from cwm.data.masking_generator import RotatedTableMaskingGenerator | |
from cwm.model.model_utils import Block, _cfg, PatchEmbed, get_sinusoid_encoding_table | |
import cwm.eval.Flow.masking_flow as masking | |
import cwm.utils as utils | |
# from external.raft_interface import RAFTInterface | |
import cwm.eval.Flow.generator as generator | |
import matplotlib.pyplot as plt | |
import torch.nn.functional as F | |
from cwm.eval.Flow import flow_utils | |
import cwm.model.keypoint_utils as keypoint_utils | |
def trunc_normal_(tensor, mean=0., std=1.):
    """Initialise ``tensor`` in place from a truncated normal distribution.

    Thin wrapper around timm's ``trunc_normal_`` whose truncation bounds are
    the absolute values ``[-std, std]`` (not ``mean +/- std``). Returns None.
    """
    lower, upper = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower, b=upper)
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Bilinearly resize a per-frame square positional-embedding grid to ``h x w``.

    Args:
        pos_embed: table indexed as ``[1, N, C]`` where ``N = n_frames * s * s``
            for some square side ``s``.
        n_frames: number of temporal slices stored in the table.
        h, w: target grid height and width.

    Returns:
        ``pos_embed`` itself when it already holds ``n_frames * h * w`` tokens,
        otherwise a new ``[1, n_frames * h * w, C]`` tensor.
    """
    total_tokens = pos_embed.shape[1]
    if total_tokens == n_frames * h * w:
        return pos_embed

    # The source grid is assumed square per frame.
    side = int((total_tokens / n_frames) ** 0.5)
    frames = pos_embed.view(1, n_frames, side, side, -1)
    frames = frames.flatten(0, 1)           # [n_frames, side, side, C]
    frames = frames.permute(0, 3, 1, 2)     # [n_frames, C, side, side]

    resized = F.interpolate(frames, size=(h, w), mode='bilinear')
    resized = resized.permute(0, 2, 3, 1)   # [n_frames, h, w, C]
    return resized.flatten(0, 2).unsqueeze(0)
class PretrainVisionTransformerEncoder(nn.Module):
    """ViT encoder for masked video pre-training.

    The input clip is embedded into ``tubelet_size x patch_size`` tokens by
    ``PatchEmbed``, positional embeddings are added, only the visible tokens
    (``~mask``) are kept, and those go through ``depth`` transformer blocks,
    a final norm and an optional linear head.
    """

    def __init__(self, img_size=224, patch_size=(16, 16), in_chans=3, num_classes=0, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, tubelet_size=2,
                 num_frames=16, block_func=Block, k_bias=False, use_learnable_pos_emb=False, block_kwargs=None):
        super().__init__()
        # None instead of a shared mutable {} default.
        block_kwargs = {} if block_kwargs is None else block_kwargs
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_size = (tubelet_size,) + patch_size
        self.pt, self.ph, self.pw = self.patch_size
        # Patch-grid size of a single temporal slice.
        self.h = int(img_size / self.ph)
        self.w = int(img_size / self.pw)
        self.hw = self.h * self.w
        self.dims = [self.h, self.w]
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            tubelet_size=tubelet_size,
            num_frames=num_frames
        )
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches
        self.num_frames = num_frames
        if use_learnable_pos_emb:
            self.use_learnable_pos_emb = True
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)
        else:
            # Fixed sine-cosine positional embeddings (not registered as a buffer here).
            self.use_learnable_pos_emb = False
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            block_func(
                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, **block_kwargs, k_bias=k_bias,
                xla_flash=True)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Number of transformer blocks."""
        return len(self.blocks)

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the output head; ``global_pool`` is accepted but unused."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _get_pos_embed(self):
        return self.pos_embed

    def forward_block(self, x, idx):
        """Run only transformer block ``idx`` on ``x``."""
        return self.blocks[idx](x)

    def forward_features(self, x, mask, move_pos=None, static_pos=None, movement=None, res=1):
        """Embed ``x`` and encode its visible tokens.

        Args:
            x: input clip; dim 2 is frames and the last two dims are spatial
                (presumably ``[B, C, T, H, W]``).
            mask: boolean ``[B, N]`` over all tokens; True = masked (dropped),
                ``~mask`` = visible.
            move_pos, static_pos: optional ``[B, P, 2]`` patch-grid coordinates.
                When ``move_pos`` is given, ``static_pos`` and ``movement`` must be too.
                Extra tokens are appended: first-frame content at ``move_pos``
                with last-frame positional embedding at ``move_pos + movement``,
                and first-frame content at ``static_pos`` with last-frame
                positional embedding at ``static_pos``.
            movement: ``[B, P, 2]`` displacement applied to ``move_pos``.
            res: resolution multiplier; when != 1 the positional table is
                interpolated to the input's patch grid.

        Returns:
            Normalized token features ``[B, N_vis (+ 2P), C]``.
        """
        T = x.shape[2]
        H, W = x.shape[-2:]
        x = embed = self.patch_embed(x)
        pos_embed = self._get_pos_embed().type_as(x).to(x.device).clone()
        if not self.use_learnable_pos_emb:
            pos_embed = pos_embed.detach()
        if res != 1:
            # Interpolate the already cast/moved/detached table, not raw
            # self.pos_embed, so dtype, device and gradient handling match the
            # res == 1 path. Count temporal tokens (T // tubelet), not raw frames,
            # and take the grid from the actual input size.
            pos_embed = interpolate_pos_encoding(pos_embed, T // self.pt, H // self.ph, W // self.pw)
        x = x + pos_embed
        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible
        if move_pos is not None:
            h, w = self.h, self.w
            first_frame_emb = embed[:, :self.hw].view(B, h, w, C)  # [B, h, w, C]
            last_frame_pos_emb = pos_embed[:, -self.hw:].view(1, h, w, C).expand(B, -1, -1, -1)  # [B, h, w, C]
            denominator = torch.tensor([self.h, self.w]).view(1, 1, 2).to(x.device)
            new_pos = move_pos + movement  # [B, P, 2]
            # Normalize grid coordinates to [-1, 1] for sampling.
            move_pos = move_pos / denominator * 2 - 1
            new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1  # handle special case where new_pos is out of bounds
            static_pos = static_pos / denominator * 2 - 1
            moving_emb = utils.sample_embedding(first_frame_emb, move_pos, mode='nearest')  # [B, P, C]
            moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest')  # [B, P, C]
            static_emb = utils.sample_embedding(first_frame_emb, static_pos, mode='nearest')  # [B, P, C]
            static_pos_emb = utils.sample_embedding(last_frame_pos_emb, static_pos, mode='nearest')  # [B, P, C]
            x_vis = torch.cat([x_vis, moving_emb + moving_pos_emb, static_emb + static_pos_emb], dim=1)
        for blk in self.blocks:
            x_vis = blk(x_vis)
        x_vis = self.norm(x_vis)
        return x_vis

    def _set_inputs(self, *args, **kwargs):
        """Hook called with the raw inputs at the start of forward; no-op here."""
        pass

    def forward(self, x, mask, move_pos=None, static_pos=None, movement=None, res=1):
        """Encode visible tokens and apply the head; see ``forward_features``."""
        self._set_inputs(x, mask)
        x = self.forward_features(x, mask, move_pos, static_pos, movement, res=res)
        x = self.head(x)
        return x
class PretrainVisionTransformerDecoder(nn.Module):
    """ViT decoder: transformer blocks, final norm and a linear prediction head.

    ``forward`` can restrict the normalized/projected output to the trailing
    ``return_token_num`` tokens (the positions of the appended mask tokens).
    """

    def __init__(self, patch_size=(16, 16), num_classes=768, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, block_func=Block, block_kwargs={},
                 k_bias=False
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = embed_dim  # kept for consistency with other models
        self.embed_dim = embed_dim
        self.patch_size = patch_size

        # Stochastic depth: drop-path probability grows linearly over the blocks.
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        layers = []
        for drop_path in drop_path_schedule:
            layers.append(block_func(
                dim=embed_dim, in_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path,
                norm_layer=norm_layer, init_values=init_values, **block_kwargs, k_bias=k_bias,
                xla_flash=True))
        self.blocks = nn.ModuleList(layers)

        self.norm = norm_layer(embed_dim)
        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Number of transformer blocks."""
        return len(self.blocks)

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the output head; ``global_pool`` is accepted but unused."""
        self.num_classes = num_classes
        if num_classes > 0:
            self.head = nn.Linear(self.embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward_block(self, x, idx):
        """Run only transformer block ``idx`` on ``x``."""
        return self.blocks[idx](x)

    def get_last_tokens(self, x, return_token_num):
        """Project already-decoded tokens through norm + head.

        ``return_token_num > 0`` keeps the trailing tokens, ``== 0`` returns an
        empty slice along dim 1, and a negative value returns every token.
        """
        if return_token_num > 0:
            return self.head(self.norm(x[:, -return_token_num:]))
        projected = self.head(self.norm(x))
        return projected[:, x.size(1):] if return_token_num == 0 else projected

    def forward(self, x, return_token_num):
        """Decode ``x``; if ``return_token_num > 0`` only the trailing tokens are projected."""
        for layer in self.blocks:
            x = layer(x)
        if return_token_num > 0:
            # only the mask tokens are used to predict pixels
            x = x[:, -return_token_num:]
        return self.head(self.norm(x))
class PretrainVisionTransformer(nn.Module):
    """Masked video autoencoder: ViT encoder over visible tokens + ViT decoder.

    ``forward`` performs masked reconstruction. The other public methods
    build counterfactual predictions, flow and keypoints on top of it.
    """
    default_input_kwargs = {'unnormalize': True}

    def __init__(self,
                 img_size=224,
                 patch_size=(16, 16),
                 encoder_func=PretrainVisionTransformerEncoder,
                 encoder_in_chans=3,
                 encoder_num_classes=0,
                 encoder_embed_dim=768,
                 encoder_depth=12,
                 encoder_num_heads=12,
                 encoder_block_func=Block,
                 encoder_block_kwargs=None,
                 decoder_num_classes=None,
                 # For pretraining this parameter isn't relevant but must be set according to tube&patch size
                 decoder_embed_dim=512,
                 decoder_depth=8,
                 decoder_num_heads=8,
                 decoder_block_func=Block,
                 decoder_block_kwargs=None,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 k_bias=False,
                 qk_scale=None,
                 num_frames=16,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=0.,
                 tubelet_size=2,
                 use_flash_attention=False,
                 use_learnable_pos_emb=False,
                 **kwargs
                 ):
        super().__init__()
        # Build fresh dicts: calling .update() on the arguments mutated the
        # caller's dict (or the shared {} default) in place.
        encoder_block_kwargs = dict(encoder_block_kwargs or {}, flash_attention=use_flash_attention)
        decoder_block_kwargs = dict(decoder_block_kwargs or {}, flash_attention=use_flash_attention)
        self.tubelet_size = tubelet_size
        if decoder_num_classes is None:
            # 3 output values per pixel of each tubelet x patch.
            num_classes = 3 * tubelet_size * (patch_size[0] * patch_size[1])
        else:
            num_classes = decoder_num_classes
        self.encoder = encoder_func(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=encoder_in_chans,
            num_classes=encoder_num_classes,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
            num_frames=num_frames,
            block_func=encoder_block_func,
            block_kwargs=encoder_block_kwargs,
            use_learnable_pos_emb=use_learnable_pos_emb,
            k_bias=k_bias,
            **kwargs)
        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size,
            num_classes=num_classes,
            embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            block_func=decoder_block_func,
            k_bias=k_bias,
            block_kwargs=decoder_block_kwargs)
        self.encoder_to_decoder = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=k_bias)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        trunc_normal_(self.mask_token, std=.02)
        if use_learnable_pos_emb:
            self.use_learnable_pos_emb = True
            # NOTE(review): 2-D [N, C]; forward relies on expand() to add the batch dim.
            self.pos_embed = nn.Parameter(torch.zeros(self.encoder.num_patches, decoder_embed_dim))
            trunc_normal_(self.pos_embed, std=.02)
        else:
            self.use_learnable_pos_emb = False
            self.pos_embed = get_sinusoid_encoding_table(self.encoder.num_patches, decoder_embed_dim)
        self.num_frames = num_frames
        self.num_patches = self.encoder.num_patches
        if self.num_frames is not None:
            self.num_patches_per_frame = self.num_patches // self.num_frames
        else:
            self.num_patches_per_frame = self.num_patches
        self.patch_size = self.encoder.patch_size
        if isinstance(img_size, int):
            self.image_size = (img_size, img_size)
        else:
            assert hasattr(img_size, '__len__'), img_size
            self.image_size = img_size
        # NOTE: get_segment uses self.flow_interface, which is not created here
        # (RAFTInterface construction is disabled).

    def mask_size(self):
        """Token-mask grid shape: (temporal tokens, grid height, grid width)."""
        return (self.num_frames // self.patch_size[0],
                self.image_size[-2] // self.patch_size[-2],
                self.image_size[-1] // self.patch_size[-1])

    def _init_weights(self, m):
        """Xavier-uniform Linear / identity LayerNorm init (not applied by __init__)."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Number of encoder transformer blocks.

        This class has no ``blocks`` attribute of its own, so the previous
        ``len(self.blocks)`` raised AttributeError.
        """
        return len(self.encoder.blocks)

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token', 'mask_token'}

    def adjust_input_resolution(self, H, W):
        """Reconfigure the model for H x W frames by resizing positional embeddings.

        Updates the encoder's patch-grid bookkeeping (h, w, hw, dims) and
        interpolates the encoder and decoder positional tables. It does nothing
        if the model is already set up for H x W.
        """
        # image_size is a tuple after __init__ and a list after a previous call;
        # compare as tuples so the no-op check works in both cases.
        if tuple(self.image_size) == (H, W):
            return
        patch_size = self.encoder.patch_size[-2:]
        self.image_size = [H, W]
        self.encoder.h = int(H / self.encoder.ph)
        self.encoder.w = int(W / self.encoder.pw)
        self.encoder.hw = self.encoder.h * self.encoder.w
        self.encoder.dims = [self.encoder.h, self.encoder.w]
        dims = [int(s / p) for s, p in zip(self.image_size, patch_size)]
        # Temporal token slices in the positional tables. This was hard-coded
        # to 3, which broke the 2-frame models.
        n_temporal = self.num_frames // self.tubelet_size
        self.encoder.pos_embed = utils.interpolate_pos_encoding(self.encoder.pos_embed, n_temporal, dims[0], dims[1])
        self.pos_embed = utils.interpolate_pos_encoding(self.pos_embed, n_temporal, dims[0], dims[1])

    def forward(self, x, mask, forward_full=False, return_features=False, res=1, *args, **kwargs):
        """Masked reconstruction.

        Args:
            x: input clip (dim 2 = frames, last two dims = H, W).
            mask: boolean ``[B, N]`` token mask; True = masked.
            forward_full: decode every token of the last temporal slice instead
                of only the masked ones.
            return_features: return the encoder's visible-token features.
            res: resolution multiplier forwarded to the encoder.

        Returns:
            Decoder predictions for the masked tokens (or for all last-slice
            tokens when ``forward_full``), or encoder features.
        """
        _, _, _, H, W = x.shape
        self.device = x.device
        # Tokens per temporal slice, from the real input grid (the previous
        # formula squared the width and so assumed square frames).
        num_patches_per_frame = (H // self.encoder.patch_size[1]) * (W // self.encoder.patch_size[2])
        x_vis = self.encoder(x, mask, res=res, *args, **kwargs)
        if return_features:
            return x_vis
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, N, C = x_vis.shape
        # add pos embedding
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone()
        if not self.use_learnable_pos_emb:
            expand_pos_embed = expand_pos_embed.detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_vis = x_vis + pos_emd_vis
        if forward_full:
            # Context = every temporal slice except the last one.
            nctx = num_patches_per_frame * (self.num_frames - 1)
            x_full = torch.cat([x_vis, self.mask_token + expand_pos_embed[:, nctx:]], dim=1)  # [B, N, C_d]
            return self.decoder(x_full, num_patches_per_frame)
        x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1)  # [B, N, C_d]
        return self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]

    def get_counterfactual(self, x, move_patches):
        '''
        Predict the last frame after copying patches from the second-to-last frame.

        :param x: input tensor [1, C, T, H, W]: support only batch size 1 for now.
            Not modified; the patch copy is applied to a clone.
        :param move_patches: torch tensor [N, 4] sized array where each row contains patch motion [x1, y1, x2, y2]
            in pixel coordinates; x is the row (H) axis and y the column (W) axis
        :return: predicted last frame, reshaped by utils.unpatchify_cwm
        '''
        B, _, _, H, W = x.shape
        enc = self.encoder
        # Work on a copy so the caller's tensor is not overwritten.
        x = x.clone()
        # All frames but the last are visible.
        mask = torch.ones(B, enc.hw * enc.num_frames).to(x.device).bool()
        mask[:, :enc.hw * (enc.num_frames - 1)] = False
        # Pixel -> patch-grid coordinates: rows scale with H, columns with W.
        pixel_size = torch.tensor([H, W, H, W], device=move_patches.device)
        grid_size = torch.tensor([enc.h, enc.w, enc.h, enc.w], device=move_patches.device)
        move_patches = (move_patches / pixel_size * grid_size).to(torch.int64)
        last_frame_offset = (enc.num_frames - 1) * (enc.h * enc.w)
        for x1, y1, x2, y2 in move_patches:
            # Reveal the destination patch in the last frame...
            idx2 = x2 * enc.w + y2 + last_frame_offset
            mask[:, idx2] = False
            im_x1 = x1 * enc.ph
            im_y1 = y1 * enc.pw
            im_x2 = x2 * enc.ph
            im_y2 = y2 * enc.pw
            # ...and fill it with the source patch from the previous frame.
            x[:, :, -1, im_x2:im_x2 + enc.ph, im_y2:im_y2 + enc.pw] = \
                x[:, :, -2, im_x1:im_x1 + enc.ph, im_y1:im_y1 + enc.pw]
        prediction = self.forward(x, mask, forward_full=True)
        prediction = utils.unpatchify_cwm(
            prediction,
            patch_size=enc.patch_size[-1],
        )  # reshape the output to an image
        return prediction

    def get_directional_counterfactual(self, x, mask=None, move_pos=None, static_pos=None, movement=None,
                                       max_movement=None):
        """Predict the masked last frame given moved/static probe patches.

        ``move_pos``/``static_pos``/``movement`` are ``[B, P, 2]`` patch-grid
        coordinates (see the encoder's ``forward_features``). When ``movement``
        is None it is drawn with ``torch.randint(-max_movement, max_movement)``.
        The decoder output is reshaped by ``utils.unpatchify_cwm`` using the
        last-frame part of ``mask``.
        """
        B, _, T, _, _ = x.shape
        if mask is None:  # default mask: all visible but the last frame
            mask = torch.ones(B, self.encoder.hw * self.encoder.num_frames).to(x.device).bool()
            mask[:, :self.encoder.hw * (self.encoder.num_frames - 1)] = False
        if movement is None:  # generate random motion if movement is not specified
            assert max_movement is not None and move_pos is not None
            movement = torch.randint(-max_movement, max_movement, move_pos.shape).to(x.device)  # [B, num_samples, 2]
        x_vis = self.encoder(x, mask, move_pos=move_pos, static_pos=static_pos, movement=movement)  # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, N, C = x_vis.shape
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        if move_pos is not None:
            h, w = self.encoder.h, self.encoder.w
            last_frame_pos_emb = expand_pos_embed[:, -(h * w):].view(B, h, w, C)  # [B, h, w, C]
            # compute new locations of the moved patches, normalize positions to range [-1, 1]
            new_pos = move_pos + movement  # [B, P, 2]
            denominator = torch.tensor([h, w]).view(1, 1, 2).to(x.device)
            new_pos = (new_pos / denominator).clamp(0, 1) * 2 - 1
            static_pos = static_pos / denominator * 2 - 1
            # sample the position embeddings of the moved and static patches
            moving_pos_emb = utils.sample_embedding(last_frame_pos_emb, new_pos, mode='nearest')  # [B, P, C]
            static_pos_emb = utils.sample_embedding(last_frame_pos_emb, static_pos, mode='nearest')  # [B, P, C]
            # concatenate with the position embeddings to the visible patches
            pos_emd_vis = torch.cat([pos_emd_vis, moving_pos_emb, static_pos_emb], dim=1)
        pos_emd_mask = expand_pos_embed[mask].reshape(B, -1, C)
        x_vis = x_vis + pos_emd_vis
        x_full = torch.cat([x_vis, self.mask_token + pos_emd_mask], dim=1)  # [B, N, C_d]
        x = self.decoder(x_full, pos_emd_mask.shape[1])  # [B, N_mask, 3 * 16 * 16]
        prediction = utils.unpatchify_cwm(
            x,
            patch_size=self.encoder.patch_size[-1],
            mask=mask[:, -self.encoder.hw:]
        )  # reshape the output to an image
        return prediction

    def get_segment(self, x, mask=None, sampling_dist=None, num_segments=1, num_iters=4, num_samples=4,
                    max_movement=4, vis=True):
        """Experimental search for moving/static probe locations (returns None).

        NOTE(review): as written this cannot run end to end:
        ``self.flow_interface`` is not assigned anywhere in this class,
        ``sampling_dist`` is negated with no None check (so the default fails),
        and the result is discarded. ``mask`` and ``num_segments`` are unused.
        The sampling code only appears consistent for batch size 1.
        """
        B, C, T, H, W = x.shape
        N = num_samples
        patch_size = self.encoder.patch_size[-1]
        self.adjust_input_resolution(H, W)
        ## Step 1: sample initial moving and static locations from the distribution
        init_move_pos, init_static_pos, init_flow_mag, max_score = None, None, None, 0
        # Sample multiple positions for each segment and select the one with consistent outputs
        for _ in range(N):
            # sample one moving position per example in the batch
            move_pos = utils.sample_positions_from_dist(size=[1, 1], dist=sampling_dist).repeat(N, 1, 1)  # [BN, 1, 2]
            # each move position has N static positions and movement directions
            static_pos = utils.sample_positions_from_dist(size=[B * N, 1], dist=-sampling_dist)  # [BN, 1, 2]
            ## compute initial flow maps
            _x = x.repeat(N, 1, 1, 1, 1)  # [BN, C, T, H, W]
            pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos,
                                                       max_movement=max_movement)
            flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True)
            flow_mag = flow_mag.view(B, N, H, W)
            # std over the N samples, then mean over pixels -> [B]
            scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1)
            print('scores', scores, flow_mag.shape)
            if scores.mean(-1) > max_score:
                init_move_pos, init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores
            # visualize samples
            if vis:
                fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples))
                for i in range(num_samples):
                    move = move_pos[i, 0].cpu()
                    static = static_pos[i, 0].cpu()
                    flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0))
                    axs[i].imshow(flow_rgb)
                    axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20)
                    axs[i].set_axis_off()
                    axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20)
                fig.subplots_adjust(wspace=0.01, hspace=0.01)  # Change these values to adjust space
                plt.show()
                plt.close()
        ## Step 2: iteratively add more moving and static locations to refine the segment
        # NOTE(review): prev_* are never updated inside the loop below, so every
        # iteration samples from the Step-1 result; confirm whether intended.
        prev_flow_mag = init_flow_mag
        prev_move_pos = init_move_pos
        prev_static_pos = init_static_pos
        npos_per_iter = 1
        for it in range(num_iters):
            print('Iteration', it)
            sampling_dist = F.interpolate(prev_flow_mag, size=self.encoder.dims, mode='bilinear').mean(1)
            # sample one moving position per example in the batch
            move_pos = utils.sample_positions_from_dist(size=[1, npos_per_iter], dist=sampling_dist).repeat(N, 1, 1)
            move_pos = torch.cat([prev_move_pos, move_pos], dim=1)
            # each move position has N static positions and movement directions
            static_pos = utils.sample_positions_from_dist(size=[B * N, npos_per_iter],
                                                          dist=-sampling_dist)  # [BN, 1, 2]
            static_pos = torch.cat([prev_static_pos, static_pos], dim=1)
            pred = self.get_directional_counterfactual(_x, move_pos=move_pos, static_pos=static_pos,
                                                       max_movement=max_movement)
            flow, flow_mag = self.flow_interface(_x[:, :, 0].float(), pred.clamp(0, 1).float(), return_magnitude=True)
            flow_mag = flow_mag.view(B, N, H, W)
            scores = flow_mag.flatten(2, 3).std(dim=1).mean(-1)  # [B]
            print('scores', scores, flow_mag.shape)
            if scores.mean(-1) > max_score:
                init_move_pos, init_static_pos, init_flow_mag, max_score = move_pos, static_pos, flow_mag, scores
            # visualize samples
            if vis:
                fig, axs = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2 * num_samples))
                for i in range(num_samples):
                    flow_rgb = utils.flow_to_rgb(flow[i].cpu().permute(1, 2, 0))
                    axs[i].imshow(flow_rgb)
                    axs[i].set_axis_off()
                    for k in range(move_pos.shape[1]):
                        move = move_pos[i, k].cpu()
                        static = static_pos[i, k].cpu()
                        axs[i].scatter(move[1] * patch_size, move[0] * patch_size, color='green', s=20)
                        axs[i].scatter(static[1] * patch_size, static[0] * patch_size, color='red', s=20)
                fig.subplots_adjust(wspace=0.01, hspace=0.01)  # Change these values to adjust space
                plt.show()
                plt.close()
        ## Step 3: iterate to add more moving and static motions
        return None

    def get_flow(self, img1, img2, conditioning_img=None, mode='jacobian', perturbation_patch_size=8,
                 aggregation_patch_size=8, mask_ratio=0.0, num_scales=1, num_mask_samples=1):
        '''
        Extract forward flow from img1 to img2 using this model as the predictor.

        :param img1: input image 1 [B, C, H, W]
        :param img2: input image 2 [B, C, H, W]
        :param mode: 'jacobian' uses generator.DerivativeFlowGenerator; any other value uses
            flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed
        :param mask_ratio: what frame2 mask ratio to use when extracting flow
        :return: forward flow [B, 2, H, W]
        '''
        # NOTE(review): the mask grid assumes 224-pixel frames.
        frame_size = 224 // self.patch_size[-1]
        if mode == 'jacobian':
            DFG = generator.DerivativeFlowGenerator(
                predictor=self,
                perturbation_patch_size=perturbation_patch_size,
                aggregation_patch_size=aggregation_patch_size,
                agg_power=None,
                agg_channel_func=lambda x: F.relu(x.sum(-3, True)),
                num_samples=5,
                average_jacobian=False,
                leave_one_out_sampling=False,
                imagenet_normalize_inputs=False,
                temporal_dim=2,
                confidence_thresh=None
            ).to(img1.device)
            maskgen_uniform = masking.PytorchMaskGeneratorWrapper(
                mask_generator=masking.RotatedTableMaskingGenerator,
                input_size=(self.num_frames, frame_size, frame_size),
                mask_ratio=mask_ratio
            ).to(img1.device)
            jac_fwd, forward_flow = flow_utils.extract_jacobians_and_flows(img1, img2,
                                                                           DFG,
                                                                           maskgen_uniform()[None])
        else:
            mask_generator = RotatedTableMaskingGenerator(
                input_size=(self.num_frames, frame_size, frame_size),
                mask_ratio=mask_ratio,
                tube_length=1,
                batch_size=1,
                mask_type='rotated_table'
            )
            forward_flow, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self, mask_generator, img1, img2, conditioning_img=conditioning_img, num_scales=num_scales,
                N_mask_samples=num_mask_samples)
        return forward_flow

    def get_keypoints(self, img1, img2, img3=None, num_keypoints=10, samples_per_keypoint=1):
        '''
        Extract keypoints via keypoint_utils.get_keypoints_batch.

        Frames are stacked along dim 2: [img1, img2] for 2-frame models, otherwise
        [img1, img1, img2] (or [img1, img2, img3] when img3 is given).

        :param img1: input image 1 [B, C, H, W] imagenet normalized
        :param img2: input image 2 [B, C, H, W] imagenet normalized
        :param num_keypoints: number of keypoints to extract
        :param samples_per_keypoint: number of samples per keypoint
        :return: (mask, choices, err_array, feat, keypoint_recon) from get_keypoints_batch
        '''
        if self.num_frames == 2:
            x = torch.stack([img1, img2], dim=2)
        elif img3 is None:
            x = torch.stack([img1, img1, img2], dim=2)
        else:
            x = torch.stack([img1, img2, img3], dim=2)
        mask, choices, err_array, feat, keypoint_recon = keypoint_utils.get_keypoints_batch(
            self, x, samples_per_keypoint, num_keypoints)
        return mask, choices, err_array, feat, keypoint_recon
def pretrain_vit_base_224_scaffold(img_size=224, **kwargs):
    """Build a PretrainVisionTransformer with a fixed base-size architecture.

    Encoder: 768-dim, depth 12, 12 heads. Decoder: 512-dim, depth 8, 16 heads.
    qkv/k biases are enabled and LayerNorm uses eps=1e-6. Everything else
    (patch size, frames, tubelet size, ...) comes from ``kwargs``; repeating
    a fixed key raises TypeError. Sets ``model.default_cfg = _cfg()``.
    """
    architecture = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        decoder_depth=8,
        mlp_ratio=4,
        qkv_bias=True,
        k_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model = PretrainVisionTransformer(img_size=img_size, **architecture, **kwargs)
    model.default_cfg = _cfg()
    return model
def pretrain_videomae_base_224_scaffold(img_size=224, **kwargs):
    """Build a PretrainVisionTransformer with the VideoMAE-style base architecture.

    Encoder: 768-dim, depth 12, 12 heads. Decoder: 384-dim, depth 4, 6 heads.
    qkv bias is enabled and LayerNorm uses eps=1e-6. ``img_size`` was
    previously hard-coded to 224, so passing it raised TypeError. It is now a
    parameter with the same default, matching ``pretrain_vit_base_224_scaffold``.
    Remaining options come from ``kwargs``. Sets ``model.default_cfg = _cfg()``.
    """
    model = PretrainVisionTransformer(
        img_size=img_size,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_embed_dim=384,
        decoder_num_heads=6,
        decoder_depth=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    return model
def vitb_8x8patch_3frames(**kwargs):
    """Base scaffold with 8x8 patches, 3 frames, tubelet size 1 and use_flash_attention=True."""
    return pretrain_vit_base_224_scaffold(
        patch_size=(8, 8), num_frames=3, tubelet_size=1, use_flash_attention=True, **kwargs)
def vitb_8x8patch_2frames(**kwargs):
    """Base scaffold with 8x8 patches, 2 frames, tubelet size 1 and use_flash_attention=True."""
    return pretrain_vit_base_224_scaffold(
        patch_size=(8, 8), num_frames=2, tubelet_size=1, use_flash_attention=True, **kwargs)
def vitb_8x8patch_2frames_vmae(**kwargs):
    """VideoMAE-style scaffold with 8x8 patches, 2 frames, tubelet size 1 and use_flash_attention=True."""
    return pretrain_videomae_base_224_scaffold(
        patch_size=(8, 8), num_frames=2, tubelet_size=1, use_flash_attention=True, **kwargs)
def vitb_4x4patch_2frames(**kwargs):
    """VideoMAE-style scaffold with 4x4 patches, 2 frames and tubelet size 1.

    use_flash_attention is not set here, so the constructor default applies.
    """
    return pretrain_videomae_base_224_scaffold(
        patch_size=(4, 4), num_frames=2, tubelet_size=1, **kwargs)
from cwm.model.modeling_pretrain_cleaned_soft import pretrain_vit_base_256_scaffold | |
def vitb_8x8patch_2frames_encoder_mask_token(
        use_flash_attention=False, **kwargs):
    """Build the 256-px scaffold from modeling_pretrain_cleaned_soft with 8x8 patches and 2 frames.

    Fixes tubelet size 1, interp_noise=False, legacy=False, xla_flash=False
    and learn_pos_embed=True; repeating any of these in ``kwargs`` raises TypeError.
    """
    fixed_options = dict(
        patch_size=(8, 8),
        num_frames=2,
        tubelet_size=1,
        use_flash_attention=use_flash_attention,
        interp_noise=False,
        legacy=False,
        xla_flash=False,
        learn_pos_embed=True,
    )
    return pretrain_vit_base_256_scaffold(**fixed_options, **kwargs)
# def base_8x8patch_2frames_1tube(**kwargs): | |
# model = pretrain_videomae_base_224_scaffold( | |
# patch_size=(8, 8), | |
# num_frames=2, | |
# tubelet_size=1, | |
# **kwargs) | |
# return model | |