import numpy as np
import torch
import torch.nn.functional as F

from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
from physion_evaluator.utils import DataAugmentationForVideoMAE

from cwm.eval.Flow.flow_utils import get_occ_masks
from cwm.model.model_factory import model_factory

def load_predictor(
        model_func_,
        load_path_,
        **kwargs):
    """Instantiate a predictor with `model_func_(**kwargs)`, load the 'model'
    weights from the checkpoint at `load_path_`, and return it in eval mode
    with gradients disabled."""
    predictor = model_func_(**kwargs).eval().requires_grad_(False)

    did_load = predictor.load_state_dict(
        torch.load(load_path_, map_location=torch.device("cpu"))['model'])
    predictor._predictor_load_path = load_path_
    print(did_load, load_path_)
    return predictor
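
# Usage sketch for `load_predictor` (kept as a comment so importing this module
# stays side-effect free). The constructor and checkpoint path below are
# hypothetical placeholders, not defined in this repository:
#
#   predictor = load_predictor(
#       model_func_=my_model_constructor,       # hypothetical model builder
#       load_path_="/path/to/checkpoint.pth",   # hypothetical checkpoint file
#   )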


class CWM(PhysionFeatureExtractor):
    def __init__(self, model_name, aggregate_embeddings=False):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        self.num_frames = self.model.num_frames

        self.timestamps = np.arange(self.num_frames)

        # number of patches per 224x224 frame
        ps = (224 // self.model.patch_size[1]) ** 2

        # mask out every patch of the last frame; earlier frames remain visible
        self.bool_masked_pos = np.zeros([ps * self.num_frames])
        self.bool_masked_pos[ps * (self.num_frames - 1):] = 1

        self.ps = ps

        self.aggregate_embeddings = aggregate_embeddings

    def transform(self):
        # returns (video transform, frame gap, number of frames to sample per clip)
        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=224,
        ), 150, 4

    def fwd(self, videos):
        # broadcast the last-frame mask across the batch and encode the clip
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                               return_features=True)
        return x_encoded

    def extract_features(self, videos, for_flow=False):
        '''
        videos: [B, T, C, H, W]; T is usually 4 and the videos are normalized with the ImageNet statistics
        returns: [B, N, D] features concatenated across the temporal windows
                 (one mean-pooled embedding per window if aggregate_embeddings is True)
        '''

        videos = videos.transpose(1, 2)  # [B, C, T, H, W]

        all_features = []

        # repeat the last frame of the video so the sliding windows cover the final frame
        videos = torch.cat([videos, videos[:, :, -1:]], dim=2)

        # slide a num_frames window over the clip with a one-frame overlap
        for x in range(0, 4, self.num_frames - 1):
            vid = videos[:, :, x:x + self.num_frames, :, :]
            all_features.append(self.fwd(vid))
            if self.aggregate_embeddings:
                # mean-pool the tokens of this window into a single embedding
                feats = all_features[-1].mean(dim=1, keepdim=True)
                all_features[-1] = feats
                # feats = feats.view(feats.shape[0], -1, self.model.num_patches_per_frame, feats.shape[-1])
                # feats = feats.mean(dim=2)
                # all_features[-1] = feats

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded
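
    # Shape sketch (assuming the 3-frame 'vitb_8x8patch_3frames' backbone and the
    # 4-frame clips requested by `transform`): the clip is padded to 5 frames by
    # repeating the last frame and encoded as two overlapping 3-frame windows,
    # frames [0, 1, 2] and [2, 3, 4]; their token features are concatenated along
    # dim 1. Hypothetical call, not executed on import:
    #
    #   videos = torch.randn(2, 4, 3, 224, 224, device='cuda')  # [B, T, C, H, W]
    #   feats = CWM('vitb_8x8patch_3frames').extract_features(videos)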


class CWM_Keypoints(PhysionFeatureExtractor):
    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # overlapping frame triplets fed to the keypoint extractor
        self.frames = [[0, 1, 2], [1, 2, 3]]

        self.num_frames = self.model.num_frames

        # number of patches per 224x224 frame
        self.ps = (224 // self.model.patch_size[1]) ** 2

        # mask out every patch of the last frame; earlier frames remain visible
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 150

        self.num_frames_dataset = 4

        self.res = 224


    def transform(self):

        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W]; T is usually 4 and the videos are normalized with the ImageNet statistics
        returns: [B, D] keypoint features concatenated across the frame triplets
        '''

        videos = videos.transpose(1, 2)

        all_features = []

        for x, arr in enumerate(self.frames):

            # select the frame triplet for this window
            vid = videos[:, :, arr, :, :].half()
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            # extract features from frame0 and frame1, together with features
            # at the keypoint regions of frame2 (10 keypoints per triplet)
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            # reshape the features to [batch_size, num_features]
            k_feat = k_feat.view(k_feat.shape[0], -1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_KeypointsFlow(PhysionFeatureExtractor):
    def __init__(self, model_name):
        super().__init__()

        self.model = model_factory.load_model(model_name).cuda().half()

        # frame index triplets used for keypoint extraction; the last triplet
        # repeats its final frame
        self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]

        self.num_frames = self.model.num_frames

        self.timestamps = np.arange(self.num_frames)

        # number of patches per 224x224 frame
        self.ps = (224 // self.model.patch_size[1]) ** 2

        # mask out every patch of the last frame; earlier frames remain visible
        self.bool_masked_pos = np.zeros([self.ps * self.num_frames])
        self.bool_masked_pos[self.ps * (self.num_frames - 1):] = 1

        self.frame_gap = 50

        self.num_frames_dataset = 9

        # videos are loaded at 512x512 and later resized for flow (1024) and keypoints (224)
        self.res = 512

    def transform(self):

        return DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        ), self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        bool_masked_pos = torch.tensor(self.bool_masked_pos).to(videos.device).unsqueeze(0).bool()
        bool_masked_pos = torch.cat([bool_masked_pos] * videos.shape[0])
        _, x_encoded = self.model(videos.half(), bool_masked_pos, forward_full=True,
                                  return_features=True)
        return x_encoded

    def get_forward_flow(self, videos):

        # index of the frame (300 ms into the clip) at which flow is computed
        fid = 6

        forward_flow = self.model.get_flow(videos[:, :, fid], videos[:, :, fid + 1], conditioning_img=videos[:, :, fid + 2], mode='cosine')

        backward_flow = self.model.get_flow(videos[:, :, fid + 1], videos[:, :, fid], conditioning_img=videos[:, :, fid - 1], mode='cosine')

        # zero out flow in regions flagged as occluded by the forward and backward flows
        occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]

        forward_flow = forward_flow * occlusion_mask

        forward_flow = torch.stack([forward_flow, forward_flow, forward_flow], dim=1)

        forward_flow = forward_flow.to(videos.device)

        # resize the stacked 2-channel flow field to 224x224 (nearest-neighbour)
        forward_flow = F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')

        return forward_flow

    def extract_features(self, videos, segments=None):
        '''
        videos: [B, T, C, H, W]; the videos are normalized with the ImageNet statistics
        returns: [B, D] keypoint and flow features concatenated across the frame triplets
        Note:
        For efficiency, the optical flow is computed and added for a single frame (300 ms), as we found this to be
        sufficient for capturing temporal dynamics in our experiments. This approach can be extended to multiple
        frames if needed, depending on the complexity of the task.
        '''


        # resize to 224x224 for keypoints and features
        videos_downsampled = F.interpolate(videos.flatten(0, 1), size=(224, 224), mode='bilinear', align_corners=False)
        videos_downsampled = videos_downsampled.view(videos.shape[0], videos.shape[1], videos.shape[2], 224, 224)

        # resize to 1024x1024 for computing flow at a higher resolution
        videos_ = F.interpolate(videos.flatten(0, 1), size=(1024, 1024), mode='bilinear', align_corners=False)
        videos = videos_.view(videos.shape[0], videos.shape[1], videos.shape[2], 1024, 1024)

        videos = videos.transpose(1, 2).half()
        videos_downsampled = videos_downsampled.transpose(1, 2).half()

        # get the forward flow for the frame at 300 ms
        forward_flow = self.get_forward_flow(videos)

        # verify that the forward flow contains no NaNs
        assert not torch.isnan(forward_flow).any(), "Forward flow is nan"

        all_features = []

        for x, arr in enumerate(self.frames):

            # use the downsampled (224x224) videos for keypoints
            vid = videos_downsampled[:, :, arr, :, :]
            frame0 = vid[:, :, 0]
            frame1 = vid[:, :, 1]
            frame2 = vid[:, :, 2]

            # extract features from frame0 and frame1, together with features
            # at the keypoint regions of frame2 (10 keypoints per triplet)
            mask, choices, err_array, k_feat, keypoint_recon = self.model.get_keypoints(frame0, frame1, frame2, 10, 1)

            # for the last triplet, keep only the features at the keypoint regions of frame2
            if (x == 2):
                k_feat = k_feat[:, -10:, :]

            # reshape the features to [batch_size, num_features]
            k_feat = k_feat.view(k_feat.shape[0], -1)

            # convert keypoint patch coordinates to pixel coordinates
            choices_image_resolution = choices * self.model.patch_size[1]

            # For the first triplet (x == 0), whose target frame sits at 300 ms,
            # append optical-flow patches sampled at the detected keypoint locations
            if x == 0:
                # take the 2-channel flow field (index 2 selects one of the three
                # identical stacked copies), shape [B, 2, 224, 224]
                flow_keyp = forward_flow[:, 2]

                # Initialize a result tensor to store the flow patches
                # Tensor shape: [batch_size, 8x8 patch (flattened to 64) * 2 channels, 10 keypoints]
                flow = torch.zeros(vid.shape[0], 8 * 8 * 2, 10).to(videos.device)

                # Patch size shift (since 8x8 patches are being extracted)
                shift = 8

                # Loop over each element in the batch to process individual video frames
                for b in range(flow_keyp.size(0)):
                    # Extract the x and y coordinates of the keypoint locations for this batch element
                    x_indices = choices_image_resolution[b, :, 0]
                    y_indices = choices_image_resolution[b, :, 1]

                    # For each keypoint (10 total keypoints in this case)
                    for ind in range(10):
                        # Extract the 8x8 patch of optical flow at each keypoint's (x, y) location
                        # Flatten the patch and assign it to the corresponding slice in the result tensor
                        flow[b, :, ind] = flow_keyp[b, :, y_indices[ind]:y_indices[ind] + shift,
                                          x_indices[ind]:x_indices[ind] + shift].flatten()

                # Reshape the flow tensor for easier concatenation (flatten across all patches)
                flow = flow.view(flow.shape[0], -1)

                # Concatenate the extracted optical flow features with the existing feature tensor (k_feat)
                k_feat = torch.cat([k_feat, flow], dim=1)

            all_features.append(k_feat)

        x_encoded = torch.cat(all_features, dim=1)

        return x_encoded


class CWM_base_8x8_3frame(CWM):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


class CWM_base_8x8_3frame_mean_embed(CWM):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)


# CWM* (keypoints only) 74.7
class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


# CWM* (keypoints + Flow) 75.4
class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')
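

if __name__ == "__main__":
    # Smoke-test sketch: run the base extractor on a random clip. This assumes a
    # CUDA device is available and that `model_factory.load_model` can fetch the
    # pretrained 'vitb_8x8patch_3frames' weights; shapes follow the docstrings above.
    extractor = CWM_base_8x8_3frame()
    dummy_videos = torch.randn(1, 4, 3, 224, 224, device="cuda")  # [B, T, C, H, W]
    with torch.no_grad():
        features = extractor.extract_features(dummy_videos)
    print("extracted feature shape:", tuple(features.shape))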