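'''Optical-flow extraction from a pre-trained CWM (counterfactual world model).

A frozen VideoMAE-style predictor is probed by the utilities in `flow_utils` to
produce two-frame flow, and occluded pixels are suppressed with a
forward-backward consistency check.
'''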
import torch
import torch.nn as nn

import cwm.model.model_pretrain as vmae_transformers
from cwm.data.masking_generator import RotatedTableMaskingGenerator

from . import flow_utils
from . import losses as bblosses


def l2_norm(x):
    '''Per-pixel L2 norm over the channel dimension: [..., C, H, W] -> [..., 1, H, W].'''
    return x.square().sum(-3, True).sqrt()

def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    '''
    Estimate occlusion masks with a forward-backward consistency check.
    flow_fwd, flow_bck: [B, 2, H, W] forward and backward flow.
    Returns (occ_mask_fwd, occ_mask_bck), each [B, 1, H, W]; 1 = visible, 0 = occluded.
    '''
    # Warp the backward flow into frame 1 using the forward flow; where a pixel
    # is visible in both frames, the warped backward flow should be roughly
    # -flow_fwd, so the sum below is close to zero.
    fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    flow_diff_fwd = flow_fwd + fwd_bck_cycle

    bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    flow_diff_bck = flow_bck + bck_fwd_cycle

    # Adaptive threshold that grows with the squared flow magnitudes.
    norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
    norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2

    occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
    occ_thresh_bck = occ_thresh * norm_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()

    return occ_mask_fwd, occ_mask_bck
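
# Usage sketch (illustrative only; `flow_fwd` and `flow_bck` are assumed to be
# [B, 2, H, W] flow tensors such as those produced by CWM.forward below):
#
#   occ_fwd, occ_bck = get_occ_masks(flow_fwd, flow_bck)  # each [B, 1, H, W]
#   flow_fwd = flow_fwd * occ_fwd                         # keep cycle-consistent pixels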


class ExtractFlow(nn.Module):
    '''Base class for two-frame optical-flow extractors.'''

    def forward(self, img1, img2):
        '''
        img1: first frame
        img2: second frame
        returns: flow map of shape (2, h, w)
        '''
        raise NotImplementedError


class CWM(ExtractFlow):
    def __init__(self, model_name, patch_size, weights_path):
        super().__init__()

        self.patch_size = patch_size

        # Build the pre-trained predictor and freeze it for inference.
        model = getattr(vmae_transformers, model_name)
        vmae_8x8_full = model().cuda().eval().requires_grad_(False)

        did_load = vmae_8x8_full.load_state_dict(torch.load(weights_path)['model'], strict=False)
        print(did_load, weights_path)

        self.predictor = vmae_8x8_full

        # Masking pattern handed to the flow-extraction utilities below.
        self.mask_generator = RotatedTableMaskingGenerator(
            input_size=(vmae_8x8_full.num_frames, 28, 28),
            mask_ratio=0.0,
            tube_length=1,
            batch_size=1,
            mask_type='rotated_table'
        )

    def forward(self, img1, img2):
        '''
        img1: first frame, [3, 1024, 1024]
        img2: second frame, [3, 1024, 1024]
        Both images are ImageNet-normalized.
        Returns the forward flow, [2, 1024, 1024], with occluded pixels zeroed.
        '''
        with torch.no_grad():
            # Forward flow (img1 -> img2).
            FF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img1[None], img2[None],
                num_scales=2, min_scale=224, N_mask_samples=1)

            # Backward flow (img2 -> img1), used only to estimate occlusions.
            BF, _ = flow_utils.scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
                self.predictor, self.mask_generator, img2[None], img1[None],
                num_scales=2, min_scale=224, N_mask_samples=1)

            # (An earlier 3-frame variant used
            # flow_utils.get_honglin_3frame_vmae_optical_flow_crop_batched here.)

        # Suppress pixels that fail the forward-backward consistency check.
        occ_mask = get_occ_masks(FF, BF)[0]
        FF = FF * occ_mask

        return FF[0]


class CWM_8x8(CWM):
    '''CWM flow extractor built on the ViT-B, 8x8-patch, 3-frame backbone.'''

    def __init__(self):
        super().__init__('vitb_8x8patch_3frames', 8,
                         '/ccn2/u/honglinc/cwm_checkpoints/ablation_3frame_no_clumping_mr0.90_extra_data_ep400/checkpoint-399.pth')
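

if __name__ == '__main__':
    # Minimal smoke test: a sketch, assuming a CUDA device is available and the
    # checkpoint path hard-coded in CWM_8x8 exists on this machine. The random
    # tensors stand in for ImageNet-normalized frames.
    extractor = CWM_8x8()
    img1 = torch.randn(3, 1024, 1024, device='cuda')
    img2 = torch.randn(3, 1024, 1024, device='cuda')
    flow = extractor(img1, img2)
    print(flow.shape)  # expected: torch.Size([2, 1024, 1024])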