# counterfactual-world-models / cwm/eval/Flow/flow_extraction_classes.py
# Provenance: commit 6dfcb0f ("app.py updated") by rahulvenkk; file size 5.26 kB.
import torch
import torch.nn as nn
import cwm.model.model_pretrain as vmae_tranformers
from . import flow_utils
from . import losses as bblosses
# Normal Resolution
def l2_norm(x):
    """Per-pixel L2 norm over the channel axis (dim -3), keeping that dim."""
    squared_sum = torch.sum(x * x, dim=-3, keepdim=True)
    return torch.sqrt(squared_sum)
# x.shape
def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Estimate occlusion masks from forward-backward flow consistency.

    Each flow field is warped by the other; where the round-trip residual's
    squared norm exceeds an adaptive threshold (scaled by the squared flow
    magnitudes, plus 0.5), the pixel is flagged occluded (mask value 0),
    otherwise visible (mask value 1).

    Returns (occ_mask_fwd, occ_mask_bck) as float tensors of 0s and 1s.
    """
    # Round-trip residuals: for consistent flow, flow + warped-reverse-flow ~ 0.
    cycle_fwd, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    cycle_bck, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    residual_fwd = flow_fwd + cycle_fwd
    residual_bck = flow_bck + cycle_bck

    # Adaptive thresholds grow with the squared magnitudes of both flows.
    thresh_fwd = occ_thresh * (l2_norm(flow_fwd) ** 2 + l2_norm(cycle_fwd) ** 2) + 0.5
    thresh_bck = occ_thresh * (l2_norm(flow_bck) ** 2 + l2_norm(cycle_bck) ** 2) + 0.5

    # 1 where consistent (residual within threshold), 0 where occluded.
    occ_mask_fwd = (l2_norm(residual_fwd) ** 2 <= thresh_fwd).float()
    occ_mask_bck = (l2_norm(residual_bck) ** 2 <= thresh_bck).float()
    return occ_mask_fwd, occ_mask_bck
class ExtractFlow(nn.Module):
    """Abstract base class for two-frame optical-flow extractors.

    Subclasses implement ``forward(img1, img2)`` and return a flow map.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img1, img2):
        """
        img1: first frame
        img2: second frame
        returns: flow map (h, w, 2)
        """
        # Fix: the base method previously fell through and returned None,
        # silently hiding subclasses that forgot to override it.
        raise NotImplementedError('Subclasses of ExtractFlow must implement forward().')
from cwm.data.masking_generator import RotatedTableMaskingGenerator
class CWM(ExtractFlow):
    """Optical-flow extractor built on a pretrained counterfactual world model.

    Loads a VMAE-style predictor by name from ``cwm.model.model_pretrain``,
    restores its weights from a checkpoint, freezes it, and extracts flow
    by masked prediction at two scales with forward-backward occlusion
    filtering.
    """

    def __init__(self, model_name, patch_size, weights_path):
        """
        model_name: attribute name of the model constructor in
            cwm.model.model_pretrain (e.g. 'vitb_8x8patch_3frames')
        patch_size: spatial patch size of the ViT backbone
        weights_path: path to a checkpoint containing a 'model' state dict
        """
        super().__init__()
        self.patch_size = patch_size

        # Build the predictor, freeze it, and put it in eval mode.
        # NOTE(review): CUDA is required here — this fails on CPU-only hosts.
        model_ctor = getattr(vmae_tranformers, model_name)
        predictor = model_ctor().cuda().eval().requires_grad_(False)

        # strict=False tolerates missing/unexpected keys; the load result is
        # printed so mismatches remain visible in logs.
        load_result = predictor.load_state_dict(
            torch.load(weights_path)['model'], strict=False)
        print(load_result, weights_path)
        self.predictor = predictor

        # Rotated-table masking with mask_ratio=0.0: the frames fed to the
        # predictor during flow extraction are fully visible.
        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(predictor.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table',
        )

    def forward(self, img1, img2):
        """
        img1: [3, H, W] first frame (ImageNet-normalized)
        img2: [3, H, W] second frame (ImageNet-normalized)
        returns: flow map [2, H, W] with occluded pixels zeroed out
        """
        with torch.no_grad():
            # Forward (img1 -> img2) and backward (img2 -> img1) flow from the
            # same multi-scale extractor; inputs get a leading batch dim.
            FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img1[None], img2[None],
                num_scales=2, min_scale=224, N_mask_samples=1)
            BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img2[None], img1[None],
                num_scales=2, min_scale=224, N_mask_samples=1)

            # Zero the forward flow wherever the forward-backward consistency
            # check flags the pixel as occluded.
            occ_mask = get_occ_masks(FF, BF)[0]
            FF = FF * occ_mask
            FF = FF[0]
        return FF
class CWM_8x8(CWM):
    """CWM flow extractor using the ViT-B, 8x8-patch, 3-frame pretrained model."""

    # Default checkpoint location (lab cluster path); override via weights_path.
    DEFAULT_WEIGHTS_PATH = (
        '/ccn2/u/honglinc/cwm_checkpoints/'
        'ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth'
    )

    def __init__(self, weights_path=DEFAULT_WEIGHTS_PATH):
        """
        weights_path: checkpoint to load; defaults to the original hard-coded
            cluster path so existing no-argument callers keep working.
        """
        super().__init__('vitb_8x8patch_3frames', 8, weights_path)