# NOTE(review): the lines below were Hugging Face Spaces page-banner residue
# ("Spaces: Running on CPU Upgrade") captured by the scrape — not part of the
# module. Converted to a comment so the file stays valid Python.
import torch | |
import torch.nn as nn | |
import cwm.model.model_pretrain as vmae_tranformers | |
from . import flow_utils | |
from . import losses as bblosses | |
# Normal resolution
def l2_norm(x):
    """Return the Euclidean (L2) norm of `x` along dim -3, keeping that dim.

    For a flow tensor laid out as (..., C, H, W) this reduces the channel
    axis, yielding a (..., 1, H, W) magnitude map.
    """
    return torch.sqrt(torch.sum(x * x, dim=-3, keepdim=True))
# x.shape  (leftover debug note)
def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Estimate occlusion masks via forward/backward flow consistency.

    A pixel is considered non-occluded when the cycle error (forward flow
    plus the backward flow warped into the same frame) stays below a
    magnitude-dependent threshold. Masks are 1.0 where flow is trusted,
    0.0 where it is judged occluded.

    flow_fwd / flow_bck: flow fields of matching shape.
    occ_thresh: slope of the magnitude-dependent threshold.
    returns: (occ_mask_fwd, occ_mask_bck) float masks.
    """
    # Warp each flow field with the other to close the cycle.
    warped_bck, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    warped_fwd, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)

    # If the two flows were perfectly consistent these sums would be zero.
    cycle_err_fwd = flow_fwd + warped_bck
    cycle_err_bck = flow_bck + warped_fwd

    # Threshold grows with the squared magnitudes of the flows involved.
    mag_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(warped_bck) ** 2
    mag_bck = l2_norm(flow_bck) ** 2 + l2_norm(warped_fwd) ** 2
    thresh_fwd = occ_thresh * mag_fwd + 0.5
    thresh_bck = occ_thresh * mag_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(cycle_err_fwd) ** 2 > thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(cycle_err_bck) ** 2 > thresh_bck).float()
    return occ_mask_fwd, occ_mask_bck
class ExtractFlow(nn.Module):
    """Abstract base for two-frame optical-flow extractors.

    Subclasses override :meth:`forward` to map a pair of frames to a flow
    map; this base class defines only the interface.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img1, img2):
        '''
        img1: first frame
        img2: second frame
        returns: flow map (h, w, 2)
        '''
        # Interface only — subclasses provide the implementation.
from cwm.data.masking_generator import RotatedTableMaskingGenerator | |
class CWM(ExtractFlow):
    """Optical-flow extractor backed by a counterfactual world model (VMAE).

    Loads a pretrained VMAE backbone onto the GPU and derives flow between
    two frames via masked-prediction probing, masking out occluded pixels
    with a forward/backward consistency check.
    """

    def __init__(self, model_name, patch_size, weights_path,
                 num_scales=2, min_scale=224, n_mask_samples=1):
        """
        model_name: attribute name of the architecture in
            ``cwm.model.model_pretrain`` (looked up via ``getattr``).
        patch_size: spatial patch size of the backbone.
        weights_path: checkpoint file; its ``'model'`` entry holds the
            state dict.
        num_scales / min_scale / n_mask_samples: flow-extraction
            hyperparameters. Defaults reproduce the previous hard-coded
            values, so existing callers are unaffected.
        """
        super().__init__()
        self.patch_size = patch_size
        self.num_scales = num_scales
        self.min_scale = min_scale
        self.n_mask_samples = n_mask_samples

        model_ctor = getattr(vmae_tranformers, model_name)
        predictor = model_ctor().cuda().eval().requires_grad_(False)
        # strict=False: the checkpoint may not cover every module key.
        did_load = predictor.load_state_dict(
            torch.load(weights_path)['model'], strict=False)
        print(did_load, weights_path)
        self.predictor = predictor

        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(predictor.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table'
        )

    def _flow(self, src, tgt):
        """Flow from ``src`` to ``tgt`` (each [3, H, W]); returns [1, 2, H, W]."""
        flow, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
            self.predictor,
            self.mask_generator,
            src[None],
            tgt[None],
            num_scales=self.num_scales,
            min_scale=self.min_scale,
            N_mask_samples=self.n_mask_samples)
        return flow

    def forward(self, img1, img2):
        '''
        img1: [3, 1024, 1024] first frame
        img2: [3, 1024, 1024] second frame
        both images are imagenet normalized
        returns: occlusion-masked forward flow (batch dim stripped)
        '''
        with torch.no_grad():
            FF = self._flow(img1, img2)  # forward flow
            BF = self._flow(img2, img1)  # backward flow
            # Zero out pixels the fwd/bck consistency check marks occluded.
            occ_mask = get_occ_masks(FF, BF)[0]
            FF = FF * occ_mask
            FF = FF[0]
        return FF
class CWM_8x8(CWM):
    """CWM flow extractor using the 8x8-patch, 3-frame VMAE backbone."""

    # Cluster-local default checkpoint; override via ``weights_path``.
    DEFAULT_WEIGHTS = ('/ccn2/u/honglinc/cwm_checkpoints/'
                       'ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth')

    def __init__(self, weights_path=DEFAULT_WEIGHTS):
        """weights_path: checkpoint to load; defaults to the original hard-coded path."""
        super().__init__('vitb_8x8patch_3frames', 8, weights_path)