import numpy as np
from physion_evaluator.feature_extract_interface import PhysionFeatureExtractor
from physion_evaluator.utils import DataAugmentationForVideoMAE
# NOTE(review): F is picked up through torch.functional here; confirm this
# keeps resolving on the torch versions you target before touching it.
from torch.functional import F
from cwm.eval.Flow.flow_utils import get_occ_masks
from cwm.model.model_factory import model_factory
import torch

# Value passed as the 4th argument of ``model.get_keypoints`` and used as the
# keypoint count everywhere below (presumably the number of keypoints the
# model returns per call -- confirm against ``get_keypoints``).
_NUM_KEYPOINTS = 10

# Side length, in pixels, of the square optical-flow patch cut out at each
# keypoint location.
_FLOW_PATCH = 8


def load_predictor(model_func_, load_path_, **kwargs):
    """Build a frozen model and load its weights from a checkpoint.

    Args:
        model_func_: callable invoked as ``model_func_(**kwargs)`` to create
            the model.
        load_path_: checkpoint path handed to ``torch.load``; the state dict
            is read from the checkpoint's ``'model'`` entry.
        **kwargs: forwarded unchanged to ``model_func_``.

    Returns:
        The model in eval mode with gradients disabled. The checkpoint path
        is recorded on its ``_predictor_load_path`` attribute.
    """
    predictor = model_func_(**kwargs).eval().requires_grad_(False)
    # Deserialize onto the CPU; this function does not move the model.
    checkpoint = torch.load(load_path_, map_location=torch.device("cpu"))
    load_result = predictor.load_state_dict(checkpoint['model'])
    predictor._predictor_load_path = load_path_
    print(load_result, load_path_)
    return predictor


def _last_frame_mask(patches_per_frame, num_frames):
    """Return a flat 0/1 numpy mask that flags only the last frame's patches."""
    mask = np.zeros([patches_per_frame * num_frames])
    mask[patches_per_frame * (num_frames - 1):] = 1
    return mask


def _batched_bool_mask(mask, videos):
    """Turn ``mask`` into a bool tensor on ``videos.device``, one row per batch item."""
    row = torch.tensor(mask).to(videos.device).unsqueeze(0).bool()
    return torch.cat([row] * videos.shape[0])


class CWM(PhysionFeatureExtractor):
    """Feature extractor that runs the CWM model over sliding frame windows."""

    def __init__(self, model_name, aggregate_embeddings=False):
        """Load the named model (CUDA, fp16) and precompute the patch mask.

        Args:
            model_name: name handed to ``model_factory.load_model``.
            aggregate_embeddings: when True, each window's features are
                averaged over dim 1 (kept as a singleton dimension).
        """
        super().__init__()
        self.model = model_factory.load_model(model_name).cuda().half()
        self.num_frames = self.model.num_frames
        self.timestamps = np.arange(self.num_frames)
        # Number of patches in one 224x224 frame (assumes square patches).
        self.ps = (224 // self.model.patch_size[1]) ** 2
        self.bool_masked_pos = _last_frame_mask(self.ps, self.num_frames)
        self.aggregate_embeddings = aggregate_embeddings

    def transform(self):
        """Return ``(augmentation, 150, 4)``.

        The two integers sit where the sibling extractors return their
        ``frame_gap`` and ``num_frames_dataset``.
        """
        augmentation = DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=224,
        )
        return augmentation, 150, 4

    def fwd(self, videos):
        """Run the model on ``videos`` (cast to fp16) with the last-frame mask.

        Returns whatever ``self.model`` returns for ``forward_full=True,
        return_features=True``.
        NOTE(review): the sibling extractors unpack that return value into a
        pair; confirm which form the model actually produces.
        """
        bool_masked_pos = _batched_bool_mask(self.bool_masked_pos, videos)
        return self.model(videos.half(), bool_masked_pos,
                          forward_full=True, return_features=True)

    def extract_features(self, videos, for_flow=False):
        """Extract features window by window and concatenate them on dim 1.

        Args:
            videos: [B, T, C, H, W], T is usually 4 and videos are normalized
                with imagenet norm.
            for_flow: unused; kept for interface compatibility.

        Returns:
            [B, T, D] extracted features (per the original contract); with
            ``aggregate_embeddings`` each window contributes one entry on
            dim 1.
        """
        videos = videos.transpose(1, 2)
        # Repeat the last frame so the final window is full length.
        videos = torch.cat([videos, videos[:, :, -1:]], dim=2)

        all_features = []
        # The stop value 4 matches the usual T; consecutive windows overlap
        # by one frame.
        for start in range(0, 4, self.num_frames - 1):
            clip = videos[:, :, start:start + self.num_frames, :, :]
            feats = self.fwd(clip)
            if self.aggregate_embeddings:
                feats = feats.mean(dim=1, keepdim=True)
            all_features.append(feats)

        return torch.cat(all_features, dim=1)


class CWM_Keypoints(PhysionFeatureExtractor):
    """Feature extractor built from CWM keypoint features on frame triplets."""

    def __init__(self, model_name):
        """Load the named model (CUDA, fp16) and set the sampling config."""
        super().__init__()
        self.model = model_factory.load_model(model_name).cuda().half()
        # Frame-index triplets (frame0, frame1, frame2) fed to get_keypoints.
        self.frames = [[0, 1, 2], [1, 2, 3]]
        self.num_frames = self.model.num_frames
        # Number of patches in one 224x224 frame (assumes square patches).
        self.ps = (224 // self.model.patch_size[1]) ** 2
        self.bool_masked_pos = _last_frame_mask(self.ps, self.num_frames)
        self.frame_gap = 150
        self.num_frames_dataset = 4
        self.res = 224

    def transform(self):
        """Return ``(augmentation, frame_gap, num_frames_dataset)``."""
        augmentation = DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        )
        return augmentation, self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        """Run the model with the last-frame mask; return the second output."""
        bool_masked_pos = _batched_bool_mask(self.bool_masked_pos, videos)
        _, x_encoded = self.model(videos.half(), bool_masked_pos,
                                  forward_full=True, return_features=True)
        return x_encoded

    def extract_features(self, videos, segments=None):
        """Concatenate flattened keypoint features of every frame triplet.

        Args:
            videos: [B, T, C, H, W], T is usually 4 and videos are normalized
                with imagenet norm.
            segments: unused; kept for interface compatibility.

        Returns:
            A 2-D tensor: each triplet's keypoint features are flattened to
            one row per batch item, then concatenated on dim 1.
        """
        videos = videos.transpose(1, 2)

        all_features = []
        for frame_ids in self.frames:
            vid = videos[:, :, frame_ids, :, :].half()
            frame0, frame1, frame2 = vid[:, :, 0], vid[:, :, 1], vid[:, :, 2]

            # Features from frame0 and frame1, plus features at the keypoint
            # regions of frame2. The meaning of the trailing ``1`` is not
            # visible in this file.
            _, _, _, k_feat, _ = self.model.get_keypoints(
                frame0, frame1, frame2, _NUM_KEYPOINTS, 1)

            # Flatten to one feature row per batch item.
            all_features.append(k_feat.view(k_feat.shape[0], -1))

        return torch.cat(all_features, dim=1)


class CWM_KeypointsFlow(PhysionFeatureExtractor):
    """Keypoint features plus optical-flow patches sampled at the keypoints."""

    def __init__(self, model_name):
        """Load the named model (CUDA, fp16) and set the sampling config."""
        super().__init__()
        self.model = model_factory.load_model(model_name).cuda().half()
        # Frame-index triplets (frame0, frame1, frame2) fed to get_keypoints.
        # NOTE(review): index 9 needs at least 10 frames on the time axis,
        # while num_frames_dataset is 9 -- confirm what the loader yields.
        self.frames = [[0, 3, 6], [3, 6, 9], [6, 9, 9]]
        self.num_frames = self.model.num_frames
        self.timestamps = np.arange(self.num_frames)
        # Number of patches in one 224x224 frame (assumes square patches).
        self.ps = (224 // self.model.patch_size[1]) ** 2
        self.bool_masked_pos = _last_frame_mask(self.ps, self.num_frames)
        self.frame_gap = 50
        self.num_frames_dataset = 9
        self.res = 512

    def transform(self):
        """Return ``(augmentation, frame_gap, num_frames_dataset)``."""
        augmentation = DataAugmentationForVideoMAE(
            imagenet_normalize=True,
            rescale_size=self.res,
        )
        return augmentation, self.frame_gap, self.num_frames_dataset

    def fwd(self, videos):
        """Run the model with the last-frame mask; return the second output."""
        bool_masked_pos = _batched_bool_mask(self.bool_masked_pos, videos)
        _, x_encoded = self.model(videos.half(), bool_masked_pos,
                                  forward_full=True, return_features=True)
        return x_encoded

    def get_forward_flow(self, videos):
        """Compute occlusion-masked forward flow at frame index 6.

        Args:
            videos: tensor whose dim 2 is the frame axis; frames 5 through 8
                are read.

        Returns:
            The masked forward flow, replicated three times on a new dim 1
            and resized with nearest-neighbour interpolation to
            ``(2, 224, 224)`` over its trailing three dimensions.
        """
        fid = 6
        forward_flow = self.model.get_flow(
            videos[:, :, fid], videos[:, :, fid + 1],
            conditioning_img=videos[:, :, fid + 2], mode='cosine')
        backward_flow = self.model.get_flow(
            videos[:, :, fid + 1], videos[:, :, fid],
            conditioning_img=videos[:, :, fid - 1], mode='cosine')

        # Multiply the forward flow by the first mask from get_occ_masks.
        occlusion_mask = get_occ_masks(forward_flow, backward_flow)[0]
        forward_flow = forward_flow * occlusion_mask

        # Three identical copies on a new dim 1 (presumably giving
        # [B, 3, 2, H, W] -- confirm against get_flow's output shape).
        forward_flow = torch.stack([forward_flow] * 3, dim=1)
        forward_flow = forward_flow.to(videos.device)
        return F.interpolate(forward_flow, size=(2, 224, 224), mode='nearest')

    def _flow_patches_at_keypoints(self, flow_keyp, keypoint_xy):
        """Cut a flow patch at every keypoint and flatten per batch item.

        Args:
            flow_keyp: flow map indexed as [batch, channel, y, x]; the buffer
                size assumes 2 channels.
            keypoint_xy: keypoint pixel coordinates indexed as
                [batch, keypoint, (x, y)].

        Returns:
            Tensor of shape [batch, 8 * 8 * 2 * 10].
        """
        batch = flow_keyp.size(0)
        patches = torch.zeros(
            batch, _FLOW_PATCH * _FLOW_PATCH * 2, _NUM_KEYPOINTS,
            device=flow_keyp.device)

        for b in range(batch):
            xs = keypoint_xy[b, :, 0]
            ys = keypoint_xy[b, :, 1]
            for k in range(_NUM_KEYPOINTS):
                patch = flow_keyp[b, :,
                                  ys[k]:ys[k] + _FLOW_PATCH,
                                  xs[k]:xs[k] + _FLOW_PATCH]
                patches[b, :, k] = patch.flatten()

        # Flatten across all patches for easier concatenation.
        return patches.view(batch, -1)

    def extract_features(self, videos, segments=None):
        """Extract keypoint features and append flow patches for one frame.

        Args:
            videos: [B, T, C, H, W], normalized with imagenet norm. Frame
                indices up to 9 are read (see ``self.frames``).
            segments: unused; kept for interface compatibility.

        Returns:
            A 2-D tensor: per-triplet flattened keypoint features, with the
            flattened flow patches appended after the first triplet's
            features, all concatenated on dim 1.

        Note: For efficiency, the optical flow is computed and added for a
        single frame (300ms) as we found this to be sufficient for capturing
        temporal dynamics in our experiments. This approach can be extended
        to multiple frames if needed, depending on the complexity of the
        task.
        """
        batch, n_frames, channels = videos.shape[:3]
        flat = videos.flatten(0, 1)

        # 224x224 copy used for keypoints and features.
        lowres = F.interpolate(flat, size=(224, 224), mode='bilinear',
                               align_corners=False)
        lowres = lowres.view(batch, n_frames, channels, 224, 224)

        # 1024x1024 copy used to compute flow at higher resolution.
        hires = F.interpolate(flat, size=(1024, 1024), mode='bilinear',
                              align_corners=False)
        hires = hires.view(batch, n_frames, channels, 1024, 1024)

        hires = hires.transpose(1, 2).half()
        lowres = lowres.transpose(1, 2).half()

        # Forward flow for the frame at 300ms.
        forward_flow = self.get_forward_flow(hires)
        assert not torch.isnan(forward_flow).any(), "Forward flow is nan"

        all_features = []
        last_group = len(self.frames) - 1
        for group_idx, frame_ids in enumerate(self.frames):
            vid = lowres[:, :, frame_ids, :, :]
            frame0, frame1, frame2 = vid[:, :, 0], vid[:, :, 1], vid[:, :, 2]

            # Features from frame0 and frame1, plus features at the keypoint
            # regions of frame2. The meaning of the trailing ``1`` is not
            # visible in this file.
            _, choices, _, k_feat, _ = self.model.get_keypoints(
                frame0, frame1, frame2, _NUM_KEYPOINTS, 1)

            # For the last triplet keep only the features at the keypoint
            # regions of frame2.
            if group_idx == last_group:
                k_feat = k_feat[:, -_NUM_KEYPOINTS:, :]

            # Flatten to one feature row per batch item.
            k_feat = k_feat.view(k_feat.shape[0], -1)

            # Only the first triplet gets optical-flow patches, taken at its
            # detected keypoint locations.
            if group_idx == 0:
                # Scale the keypoint coordinates by the patch size to get
                # pixel positions.
                keypoint_xy = choices * self.model.patch_size[1]
                # Index 2 on dim 1 selects the last of the three identical
                # flow copies stacked in get_forward_flow.
                flow = self._flow_patches_at_keypoints(
                    forward_flow[:, 2], keypoint_xy)
                k_feat = torch.cat([k_feat, flow], dim=1)

            all_features.append(k_feat)

        return torch.cat(all_features, dim=1)


class CWM_base_8x8_3frame(CWM):
    """``CWM`` bound to the 'vitb_8x8patch_3frames' model."""

    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


class CWM_base_8x8_3frame_mean_embed(CWM):
    """``CWM`` on 'vitb_8x8patch_3frames' with mean-aggregated embeddings."""

    def __init__(self):
        super().__init__('vitb_8x8patch_3frames', aggregate_embeddings=True)


# CWM* (keypoints only) 74.7
class CWM_base_8x8_3frame_keypoints(CWM_Keypoints):
    """``CWM_Keypoints`` bound to the 'vitb_8x8patch_3frames' model."""

    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')


# CWM* (keypoints + Flow) 75.4
class CWM_base_8x8_3frame_keypoints_flow(CWM_KeypointsFlow):
    """``CWM_KeypointsFlow`` bound to the 'vitb_8x8patch_3frames' model."""

    def __init__(self):
        super().__init__('vitb_8x8patch_3frames')