import math
import random
import warnings

import kornia
import matplotlib.pyplot as plt  # noqa: F401 -- not used by the code below; kept for other users of this module
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from . import losses as bblosses

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Side length (pixels) of the square crops fed to the predictor.
_CROP_SIZE = 224


def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    """Estimate a patch-level flow field by cosine-similarity matching of token embeddings.

    Only batch element 0 of ``embedding_tensor`` is used. Its first
    ``frame_size ** 2`` rows are treated as the tokens of the first frame; the
    remaining rows are treated as the *unmasked* tokens of the second frame, in
    the order of the unmasked positions of ``mask_tensor``. Each such token is
    matched to its most cosine-similar first-frame token, and the displacement
    (in token-grid units) from the matched first-frame location to the
    second-frame location is written into the output.

    Args:
        embedding_tensor: Tensor indexed as ``[batch, token, feature]``.
        mask_tensor: Mask whose flattened entries from index ``frame_size ** 2``
            on describe the second frame; falsy entries are "unmasked". The
            number of unmasked entries should equal the number of second-frame
            rows of ``embedding_tensor`` (the subtractions below need matching
            lengths).
        frame_size: Side length of one frame's token grid.

    Returns:
        float32 tensor of shape ``(2, frame_size, frame_size)`` on the device of
        ``embedding_tensor``; channel 0 is the x displacement, channel 1 the y
        displacement. Masked positions are left at 0.
    """
    tokens_per_frame = frame_size ** 2

    # Grid positions of the second frame that are visible (falsy mask entries).
    second_frame_mask = mask_tensor.view(-1)[tokens_per_frame:]
    second_frame_unmask_indices = torch.where(~second_frame_mask.bool())[0]

    first_frame_embeddings = embedding_tensor[0, :tokens_per_frame, :]
    second_frame_embeddings = embedding_tensor[0, tokens_per_frame:, :]

    # Cosine similarity of every second-frame token against every first-frame token.
    dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
    norms = (torch.norm(second_frame_embeddings, dim=1)[:, None]
             * torch.norm(first_frame_embeddings, dim=1)[None, :])
    cos_sim_matrix = dot_product / norms
    first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)

    # Flat token indices -> (row, column) grid coordinates.
    second_frame_y = second_frame_unmask_indices // frame_size
    second_frame_x = second_frame_unmask_indices % frame_size
    first_frame_y = first_frame_most_similar_indices // frame_size
    first_frame_x = first_frame_most_similar_indices % frame_size

    optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
    optical_flow[0, second_frame_y, second_frame_x] = (second_frame_x - first_frame_x).float()
    optical_flow[1, second_frame_y, second_frame_x] = (second_frame_y - first_frame_y).float()
    return optical_flow


def _crop_starts(size, crop_size=_CROP_SIZE):
    """Return start offsets of the fewest evenly spaced crops covering ``[0, size)``."""
    if size <= crop_size:
        return [0]
    num_crops = math.ceil(size / crop_size)
    step = (size - crop_size) // (num_crops - 1)
    starts = [i * step for i in range(num_crops)]
    # Integer division can leave the last crop short of the border (e.g. size=449);
    # pin it to the edge so every pixel is covered.
    starts[-1] = size - crop_size
    return starts


def get_minimal_224_crops_new_batched(video_tensor, N):
    """Tile a video with the fewest 224x224 crops that cover it, plus optional random crops.

    Args:
        video_tensor: Tensor of shape ``(B, T, C, H, W)``. ``H`` and ``W`` must be
            at least 224; smaller sizes make the final reshape fail.
        N: Desired minimum number of crops. If the tiling yields fewer than
            ``N`` crops and both ``H`` and ``W`` exceed 224, extra crops at
            random positions (``random.randint``) are appended.

    Returns:
        ``(crops, positions)``: a list of ``(B, T, C, 224, 224)`` tensors and the
        matching list of ``(start_h, start_w)`` offsets. Tiles come first in
        row-major order, followed by any random crops.
    """
    B, T, C, H, W = video_tensor.shape
    size = _CROP_SIZE

    cropped_tensors = []
    crop_positions = []
    for start_h in _crop_starts(H, size):
        for start_w in _crop_starts(W, size):
            crop = video_tensor[:, :, :, start_h:min(start_h + size, H), start_w:min(start_w + size, W)]
            cropped_tensors.append(crop)
            crop_positions.append((start_h, start_w))

    num_tiles = len(cropped_tensors)
    if N > num_tiles and H > size and W > size:
        for _ in range(N - num_tiles):
            start_h = random.randint(0, H - size)
            start_w = random.randint(0, W - size)
            cropped_tensors.append(video_tensor[:, :, :, start_h:start_h + size, start_w:start_w + size])
            crop_positions.append((start_h, start_w))

    cropped_tensors = [crop.reshape(B, T, C, size, size) for crop in cropped_tensors]
    return cropped_tensors, crop_positions


def create_weighted_mask_batched(h, w):
    """Return an ``(h, w)`` float32 tent-shaped blending mask.

    Along each axis the weight is ``min(t, 1 - t)`` for ``t`` in
    ``linspace(0, 1, n)``, i.e. 0 at both ends and largest in the middle; the
    2-D mask is the outer product of the two axes, so it is exactly 0 on the
    border.
    """
    y_mask = np.linspace(0, 1, h)
    y_mask = np.minimum(y_mask, 1 - y_mask)
    x_mask = np.linspace(0, 1, w)
    x_mask = np.minimum(x_mask, 1 - x_mask)
    return torch.from_numpy(np.outer(y_mask, x_mask)).float()


def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
    """Stitch overlapping 224x224 crops back into a full-size tensor.

    Crops are blended with ``create_weighted_mask_batched`` so that pixels near
    a crop's centre dominate over pixels near its edges. That mask is exactly 0
    on every crop border, so pixels covered only by crop borders (at least the
    outermost rows/columns of the output) get zero total weight; for those
    pixels the plain, unweighted mean of the covering crops is used instead of
    returning 0.

    Args:
        cropped_tensors: Sequence of ``(B, T, C, 224, 224)`` tensors on one device.
        crop_positions: ``(start_h, start_w)`` of each crop, in the same order.
        original_shape: ``(B, T, C, H, W)`` of the output.

    Returns:
        float32 tensor of shape ``original_shape`` on the crops' device. Pixels
        covered by no crop are 0.
    """
    B, T, C, H, W = original_shape
    device = cropped_tensors[0].device
    size = _CROP_SIZE

    weighted_sum = torch.zeros((B, T, C, H, W), device=device)
    plain_sum = torch.zeros((B, T, C, H, W), device=device)
    # The blend weight is identical for every (B, T, C) slice, so keep it 2-D.
    weight_total = torch.zeros((H, W), device=device)
    coverage = torch.zeros((H, W), device=device)

    weighted_mask = create_weighted_mask_batched(size, size).to(device)

    for crop, (start_h, start_w) in zip(cropped_tensors, crop_positions):
        rows = slice(start_h, start_h + size)
        cols = slice(start_w, start_w + size)
        weighted_sum[..., rows, cols] += crop * weighted_mask
        plain_sum[..., rows, cols] += crop
        weight_total[rows, cols] += weighted_mask
        coverage[rows, cols] += 1

    # Small epsilon avoids 0/0 where no crop contributes any weight.
    epsilon = 1e-8
    reconstructed_video = weighted_sum / (weight_total + epsilon)

    zero_weight = weight_total == 0
    if zero_weight.any():
        fallback = plain_sum / coverage.clamp(min=1)
        reconstructed_video = torch.where(zero_weight, fallback, reconstructed_video)
    return reconstructed_video


def l2_norm(x):
    """Euclidean norm over dim -3 (the channel dim of ``(..., C, H, W)``), keeping that dim."""
    return x.square().sum(-3, True).sqrt()


def resize(x, a):
    """Bilinearly rescale the last two dims of ``x`` by factor ``a`` (sizes truncated to int)."""
    return F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear', align_corners=False)


def upsample(x, H, W):
    """Bilinearly resize the last two dims of ``x`` to ``(H, W)``."""
    return F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)


def get_occ_masks(flow_fwd, flow_bck, occ_thresh=0.5):
    """Forward-backward consistency masks for a pair of flows.

    Each flow is warped by the other (``bblosses.backward_warp``) and the
    squared norm of the cycle residual is compared with
    ``occ_thresh * (|flow|^2 + |warped other flow|^2) + 0.5``.

    Returns:
        ``(occ_mask_fwd, occ_mask_bck)``: float masks that are 1 where the flows
        are consistent and 0 where the residual exceeds the threshold.
    """
    fwd_bck_cycle, _ = bblosses.backward_warp(img2=flow_bck, flow=flow_fwd)
    flow_diff_fwd = flow_fwd + fwd_bck_cycle

    bck_fwd_cycle, _ = bblosses.backward_warp(img2=flow_fwd, flow=flow_bck)
    flow_diff_bck = flow_bck + bck_fwd_cycle

    norm_fwd = l2_norm(flow_fwd) ** 2 + l2_norm(fwd_bck_cycle) ** 2
    norm_bck = l2_norm(flow_bck) ** 2 + l2_norm(bck_fwd_cycle) ** 2

    occ_thresh_fwd = occ_thresh * norm_fwd + 0.5
    occ_thresh_bck = occ_thresh * norm_bck + 0.5

    occ_mask_fwd = 1 - (l2_norm(flow_diff_fwd) ** 2 > occ_thresh_fwd).float()
    occ_mask_bck = 1 - (l2_norm(flow_diff_bck) ** 2 > occ_thresh_bck).float()
    return occ_mask_fwd, occ_mask_bck


def forward_backward_cycle_consistency(flow_fwd, flow_bck, niters=10):
    """Iteratively nudge a forward/backward flow pair toward cycle consistency.

    On each of ``niters`` iterations the forward flow is corrected by half of
    its cycle residual, then the backward flow is corrected against the
    updated forward flow. Works on detached clones; inputs are not modified.

    Make sure to be using axes-swapped, upsampled flows!

    Returns:
        ``(fwd_flow, bck_flow)``: the corrected, detached flows.
    """
    bck_flow_clone = flow_bck.clone().detach()
    fwd_flow_clone = flow_fwd.clone().detach()

    for _ in range(niters):
        fwd_bck_cycle_orig, _ = bblosses.backward_warp(img2=bck_flow_clone, flow=fwd_flow_clone)
        flow_diff_fwd_orig = fwd_flow_clone + fwd_bck_cycle_orig
        fwd_flow_clone = fwd_flow_clone - flow_diff_fwd_orig / 2

        bck_fwd_cycle_orig, _ = bblosses.backward_warp(img2=fwd_flow_clone, flow=bck_flow_clone)
        flow_diff_bck_orig = bck_flow_clone + bck_fwd_cycle_orig
        bck_flow_clone = bck_flow_clone - flow_diff_bck_orig / 2

    return fwd_flow_clone, bck_flow_clone


def resize_flow_map(flow_map, target_size):
    """Resize a flow field with PIL bilinear interpolation and rescale its vectors.

    Args:
        flow_map: Tensor of shape ``(B, 2, H, W)``; only batch element 0 is
            used. Channel 0 is treated as the x (horizontal) component and
            channel 1 as the y (vertical) component.
        target_size: ``(height, width)`` of the output.

    Returns:
        CPU tensor of shape ``(1, 2, height, width)``. The x component is
        multiplied by ``width / W`` and the y component by ``height / H`` so
        the vectors are expressed in output-pixel units.
    """
    flow = flow_map[0].detach().cpu().numpy().transpose(1, 2, 0)
    src_h, src_w = flow.shape[:2]
    dst_h, dst_w = int(target_size[0]), int(target_size[1])
    pil_size = (dst_w, dst_h)  # PIL's Image.resize takes (width, height)

    channels = []
    for channel, scale in ((0, dst_w / src_w), (1, dst_h / src_h)):
        channel_img = Image.fromarray(flow[:, :, channel])
        resized = np.array(channel_img.resize(pil_size, Image.BILINEAR))
        channels.append(resized * scale)

    flow_map_resized = np.stack(channels, axis=-1)
    return torch.from_numpy(flow_map_resized)[None].permute(0, 3, 1, 2)


def get_vmae_optical_flow_crop_batched_smoothed(generator, mask_generator, img1, img2, neg_back_flow=True,
                                                num_scales=1, min_scale=400, N_mask_samples=100,
                                                mask_ratio=0.8, smoothing_factor=1):
    """Deprecated alias of ``scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed``.

    ``neg_back_flow`` and ``mask_ratio`` are accepted only for signature
    compatibility and are ignored. They must not be forwarded, because the
    replacement function has no such parameters (its masking pattern comes
    entirely from ``mask_generator``).
    """
    warnings.warn('Deprecated. Please use scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed',
                  DeprecationWarning, stacklevel=2)
    return scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
        generator, mask_generator, img1, img2,
        num_scales=num_scales, min_scale=min_scale,
        N_mask_samples=N_mask_samples, smoothing_factor=smoothing_factor)


def average_crops(tensor, D):
    """Smooth ``tensor`` by averaging it with four copies shifted by ``D`` pixels.

    The result is the mean of ``tensor`` and its copies shifted by ``D`` in both
    directions along the last two dimensions. Shifted-in regions are
    zero-filled, so values within ``D`` of the border are damped toward 0.

    Args:
        tensor: Tensor of shape ``(..., H, W)`` (e.g. ``(C, H, W)``).
        D: Non-negative integer shift. ``D == 0`` returns ``tensor`` unchanged
            in value; ``D >= H`` (or ``>= W``) makes the corresponding shifted
            copies all zero.

    Raises:
        ValueError: If ``D`` is negative.
    """
    if D < 0:
        raise ValueError(f"shift D must be non-negative, got {D}")
    H, W = tensor.shape[-2:]

    shifted_up = torch.zeros_like(tensor)
    shifted_down = torch.zeros_like(tensor)
    shifted_left = torch.zeros_like(tensor)
    shifted_right = torch.zeros_like(tensor)

    if D < H:
        shifted_up[..., :H - D, :] = tensor[..., D:, :]
        shifted_down[..., D:, :] = tensor[..., :H - D, :]
    if D < W:
        shifted_left[..., :, :W - D] = tensor[..., :, D:]
        shifted_right[..., :, D:] = tensor[..., :, :W - D]

    return (tensor + shifted_up + shifted_down + shifted_left + shifted_right) / 5.0


def _estimate_patch_flows(predictor, mask_generator, crops, frame_size, num_frames,
                          n_mask_samples, smoothing_factor):
    """Average patch-level flows for a batch of crops over many random masks.

    Masks are drawn until at least ``n_mask_samples`` have been used AND every
    last-frame grid position has been unmasked at least once, so the final
    normalisation never divides by zero. All crops share the same mask for a
    given draw. NOTE(review): if ``mask_generator`` can never unmask some
    position, this loop does not terminate.

    Args:
        crops: Tensor ``(N, T, C, 224, 224)`` already on CUDA.

    Returns:
        CUDA tensor ``(N, 2, frame_size, frame_size)`` of flows in token-grid units.
    """
    tokens_per_frame = frame_size * frame_size
    num_items = crops.shape[0]
    # (N, T, C, H, W) -> (N, C, T, H, W) fp16 encoder input; invariant across mask draws.
    encoder_input = crops.transpose(1, 2).to(torch.float16)
    # Offset of the tokens that compute_optical_flow consumes (last two frames).
    # Presumably relies on the earlier frames being fully unmasked -- TODO confirm.
    flow_token_start = tokens_per_frame * (num_frames - 2)

    flow_sum = torch.zeros(num_items, 2, frame_size, frame_size).cuda()
    mask_counts = torch.zeros(frame_size, frame_size).cuda()

    num_drawn = 0
    while num_drawn < n_mask_samples or (mask_counts == 0).any().item():
        mask = mask_generator().bool().cuda()
        last_frame_visible = ~mask[0, tokens_per_frame * (num_frames - 1):]
        mask_counts += last_frame_visible.reshape(frame_size, frame_size)

        with torch.cuda.amp.autocast(enabled=True):
            encoder_out = predictor.encoder(encoder_input, mask.repeat(num_items, 1))
            encoder_to_decoder = predictor.encoder_to_decoder(encoder_out)
            encoder_to_decoder = encoder_to_decoder[:, flow_token_start:, :]
            flow_mask = mask[:, flow_token_start:]

            sample_flows = torch.cat([
                average_crops(
                    compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size),
                    smoothing_factor,
                ).unsqueeze(0)
                for b in range(num_items)
            ], 0)
            flow_sum += sample_flows
        num_drawn += 1

    return flow_sum / mask_counts


def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(predictor, mask_generator, img1, img2,
                                                              conditioning_img=None, num_scales=1,
                                                              min_scale=400, N_mask_samples=100,
                                                              smoothing_factor=1):
    """Multi-scale, crop-tiled flow estimation from a masked video predictor.

    For each of ``num_scales`` image scales (the coarsest has height about
    ``min_scale``), a short video is built from ``img2`` (repeated, or preceded
    by ``conditioning_img``) followed by ``img1``. The video is tiled into
    224x224 crops and per-crop flows are estimated with ``_estimate_patch_flows``.
    Those flows are upsampled to 224x224, stitched back together, resized to
    the first scale's resolution, converted from token units to pixels at the
    original resolution, and averaged across scales.

    Requires CUDA (tensors are moved with ``.cuda()``). ``predictor`` must
    expose ``patch_size``, ``num_frames``, ``encoder`` and
    ``encoder_to_decoder``. With ``conditioning_img`` the video has 3 frames,
    so ``predictor.num_frames`` presumably has to be 3 -- TODO confirm.

    Args:
        img1, img2: 4-D image tensors ``(B, C, H, W)``.
        N_mask_samples: Minimum number of random masks per scale.
        smoothing_factor: Shift passed to ``average_crops``.

    Returns:
        ``(flow, per_scale)``: ``flow`` is ``(B, 2, H, W)`` with channel 0 = x and
        channel 1 = y, sign-negated; ``per_scale`` lists each scale's negated
        flow with a trailing singleton dim, ``(B, 2, H, W, 1)``.
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    h1, w1 = img2.shape[-2], img2.shape[-1]
    # Per-level shrink factor so the coarsest level's height is about min_scale.
    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    patch_size = predictor.patch_size[-1]
    frame_size = _CROP_SIZE // patch_size
    num_frames = predictor.num_frames

    scale_flows = []
    s_hs, s_ws = [], []
    for aidx in range(num_scales):
        scaled_size = [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)]
        img1_scaled = F.interpolate(img1, scaled_size, mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2, scaled_size, mode='bicubic', align_corners=True)

        h2, w2 = img2_scaled.shape[-2], img2_scaled.shape[-1]
        s_hs.append(h1 / h2)
        s_ws.append(w1 / w2)

        if conditioning_img is not None:
            conditioning_img_scaled = F.interpolate(conditioning_img, scaled_size, mode='bilinear',
                                                    align_corners=False)
            frames = [conditioning_img_scaled, img2_scaled, img1_scaled]
        else:
            frames = [img2_scaled] * (num_frames - 1) + [img1_scaled]
        video = torch.cat([frame.unsqueeze(1) for frame in frames], 1)

        # Works even if the incoming video is already a single 224x224 crop.
        crops, crop_positions = get_minimal_224_crops_new_batched(video, 1)
        crop_batch = torch.cat(crops, 0).cuda()
        patch_flows = _estimate_patch_flows(predictor, mask_generator, crop_batch, frame_size,
                                            num_frames, N_mask_samples, smoothing_factor)

        # Split back per crop and upsample to crop resolution; values stay in
        # token-grid units until the per-scale rescaling below.
        crop_flows = [
            F.interpolate(flow, [_CROP_SIZE, _CROP_SIZE], mode='bicubic', align_corners=True).unsqueeze(1).cpu()
            for flow in patch_flows.split(B, 0)
        ]
        joined = reconstruct_video_new_2_batched(
            crop_flows, crop_positions, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
        scale_flows.append(joined)

    # Bring every scale to the first scale's resolution and convert token units
    # to pixels at the original resolution (x by patch_size * s_w, y by patch_size * s_h).
    target_size = [int(scale_flows[0].shape[-2]), int(scale_flows[0].shape[-1])]
    all_fwd_flows_e2d_new = []
    for flow, s_h, s_w in zip(scale_flows, s_hs, s_ws):
        resized = F.interpolate(flow, target_size, mode='bicubic', align_corners=True)
        scaled = torch.zeros_like(resized)
        scaled[:, 0, :, :] = resized[:, 0, :, :] * patch_size * s_w
        scaled[:, 1, :, :] = resized[:, 1, :, :] * patch_size * s_h
        all_fwd_flows_e2d_new.append(scaled.unsqueeze(-1))

    return_flow = -torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
    all_fwd_flows_e2d_new = [-flow for flow in all_fwd_flows_e2d_new]
    return return_flow, all_fwd_flows_e2d_new


def extract_jacobians_and_flows(img1, img2, flow_generator, mask, target_mask=None):
    """Run ``flow_generator`` on the clip ``[img2, img1]`` and upsample its flow.

    Only batch element 0 of the returned flows is used. Its two channels are
    swapped, then each spatial axis is nearest-neighbour upsampled
    (``repeat_interleave``) by the integer ratio of image size to flow size
    for that axis.

    Returns:
        ``(jacobians, flows)``: ``jacobians`` as returned by ``flow_generator``;
        ``flows`` of shape ``(1, 2, h', w')``.
    """
    image_h, image_w = img1.shape[-2:]
    video = torch.cat([img2.unsqueeze(1), img1.unsqueeze(1)], 1)
    jacobians, flows, _ = flow_generator(video, mask, target_mask)

    # Swap the two flow channels of the first batch element -> (2, h, w).
    flows = torch.cat([flows[0, 1].unsqueeze(0), flows[0, 0].unsqueeze(0)])

    # Each axis gets its own factor (the two differ for non-square inputs).
    factor_h = image_h // flows.shape[-2]
    factor_w = image_w // flows.shape[-1]
    flows = flows.unsqueeze(0).repeat_interleave(factor_w, -1).repeat_interleave(factor_h, -2)
    return jacobians, flows


class FlowToRgb:
    """Convert a 2-channel flow tensor to an RGB visualisation via HSV.

    Hue encodes direction (``atan2`` of the flow), saturation is fixed at 1 and
    value is the flow magnitude divided by ``max_speed`` (not clipped).
    """

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        # Magnitude that maps to HSV value 1.0.
        self.max_speed = max_speed
        # Channels are (h, w): x is channel 1, y is the negated channel 0.
        self.from_image_coordinates = from_image_coordinates
        # Channels are (x, y) with y negated; takes precedence over from_image_coordinates.
        self.from_sampling_grid = from_sampling_grid

    def __call__(self, flow):
        """Map ``flow`` of shape ``(..., 2, H, W)`` to RGB ``(..., 3, H, W)``."""
        assert flow.size(-3) == 2, flow.shape
        first, second = torch.split(flow, [1, 1], dim=-3)
        if self.from_sampling_grid:
            flow_x, flow_y = first, -second
        elif not self.from_image_coordinates:
            flow_x, flow_y = first, second
        else:
            flow_x, flow_y = second, -first

        angle = torch.atan2(flow_y, flow_x)  # radians in [-pi, pi]
        speed = torch.sqrt(flow_x ** 2 + flow_y ** 2) / self.max_speed

        # kornia's hsv_to_rgb documents hue in [0, 2*pi]; wrap negative angles.
        hue = torch.remainder(angle, 2 * np.pi)
        sat = torch.ones_like(hue)
        hsv = torch.cat([hue, sat, speed], -3)
        return kornia.color.hsv_to_rgb(hsv)

    def make_colorwheel(self):
        """
        Generates a color wheel for optical flow visualization as presented in:
            Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)

        Returns a ``(55, 3)`` float array of RGB values in [0, 255].
        """
        RY = 15
        YG = 6
        GC = 4
        CB = 11
        BM = 13
        MR = 6

        ncols = RY + YG + GC + CB + BM + MR
        colorwheel = np.zeros((ncols, 3))
        col = 0

        # RY
        colorwheel[0:RY, 0] = 255
        colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
        col += RY
        # YG
        colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
        colorwheel[col:col + YG, 1] = 255
        col += YG
        # GC
        colorwheel[col:col + GC, 1] = 255
        colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
        col += GC
        # CB
        colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(0, CB) / CB)
        colorwheel[col:col + CB, 2] = 255
        col += CB
        # BM
        colorwheel[col:col + BM, 2] = 255
        colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
        col += BM
        # MR
        colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(0, MR) / MR)
        colorwheel[col:col + MR, 0] = 255
        return colorwheel