"""CWM-based optical flow extraction.

Wraps a pretrained counterfactual-world-model (VMAE) predictor to estimate
dense optical flow between two frames, with forward/backward-consistency
occlusion masking.
"""

import torch
import torch.nn as nn

import cwm.model.model_pretrain as vmae_transformers
from cwm.data.masking_generator import RotatedTableMaskingGenerator

from . import flow_utils
from . import losses as bblosses


def l2_norm(x):
    """Per-pixel L2 norm over the channel axis (dim -3), keeping that dim."""
    return x.square().sum(-3, True).sqrt()


def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Compute occlusion masks from forward/backward flow consistency.

    A pixel is marked occluded (mask value 0) when the squared norm of its
    cycle residual (flow plus the warped reverse flow) exceeds the relative
    threshold ``occ_thresh * (|flow|^2 + |warped reverse flow|^2) + 0.5``.

    Args:
        flow_fwd: forward flow, shape (..., 2, H, W).
        flow_bck: backward flow, same shape as ``flow_fwd``.
        occ_thresh: relative weight on the magnitude-dependent threshold.

    Returns:
        Tuple ``(occ_mask_fwd, occ_mask_bck)`` of float masks with the same
        spatial shape as the inputs; 1 = consistent/visible, 0 = occluded.
    """
    # Warp the backward flow into frame 1 via the forward flow; where the two
    # flows are consistent the sum (cycle residual) is ~0.
    fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    flow_diff_fwd = flow_fwd + fwd_bck_cycle

    bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    flow_diff_bck = flow_bck + bck_fwd_cycle

    # Threshold scales with local flow magnitudes so large motions are not
    # over-penalized; the +0.5 floor handles near-zero flow.
    norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
    norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2

    occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
    occ_thresh_bck = occ_thresh * norm_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()

    return occ_mask_fwd, occ_mask_bck


class ExtractFlow(nn.Module):
    """Abstract base class for two-frame optical-flow extractors."""

    def __init__(self):
        super().__init__()

    def forward(self, img1, img2):
        """Return a flow map for the frame pair (img1, img2).

        Args:
            img1: first frame.
            img2: second frame.

        Returns:
            Flow map of shape (h, w, 2).

        Raises:
            NotImplementedError: always; subclasses must override. The
                original silently returned None, which masked misuse.
        """
        raise NotImplementedError("subclasses must implement forward()")


class CWM(ExtractFlow):
    """Optical flow extractor backed by a pretrained CWM (VMAE) predictor."""

    def __init__(self, model_name, patch_size, weights_path):
        """
        Args:
            model_name: attribute name of the model constructor inside
                ``cwm.model.model_pretrain`` (e.g. 'vitb_8x8patch_3frames').
            patch_size: predictor patch size in pixels.
            weights_path: checkpoint path; the file must contain a 'model'
                state-dict entry.
        """
        super().__init__()
        self.patch_size = patch_size

        # Build the predictor frozen, in eval mode, on GPU.
        model_ctor = getattr(vmae_transformers, model_name)
        predictor = model_ctor().cuda().eval().requires_grad_(False)

        # strict=False: checkpoint may carry extra/missing keys by design.
        did_load = predictor.load_state_dict(
            torch.load(weights_path)['model'], strict=False)
        print(did_load, weights_path)
        self.predictor = predictor

        # mask_ratio=0.0 — the flow utilities drive the masking themselves.
        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(predictor.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table',
        )

    def forward(self, img1, img2):
        """Estimate occlusion-masked forward flow between two frames.

        Args:
            img1: [3, H, W] first frame (ImageNet-normalized).
            img2: [3, H, W] second frame (ImageNet-normalized).

        Returns:
            [2, H, W] forward flow tensor with occluded pixels zeroed.
        """
        with torch.no_grad():
            # Forward and backward flow, multi-scale (2 scales, min 224px).
            FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img1[None], img2[None],
                num_scales=2, min_scale=224, N_mask_samples=1)
            BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img2[None], img1[None],
                num_scales=2, min_scale=224, N_mask_samples=1)

            # Zero out pixels that fail the forward/backward consistency check.
            occ_mask = get_occ_masks(FF, BF)[0]
            FF = FF * occ_mask

        return FF[0]


class CWM_8x8(CWM):
    """CWM flow extractor using the ViT-B 8x8-patch, 3-frame checkpoint."""

    def __init__(self):
        # NOTE(review): hard-coded cluster path — consider making this
        # configurable or resolving it from an environment variable.
        super().__init__(
            'vitb_8x8patch_3frames',
            8,
            '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth',
        )