import math
import random

import numpy as np
import torch
import torch.nn.functional as F


def create_weighted_mask_batched(h, w):
    """Return an (h, w) float32 tensor of separable "tent" blending weights.

    Each axis is ``min(t, 1 - t)`` for ``t = linspace(0, 1, n)``; the 2-D mask
    is the outer product of the two axes. The weight peaks in the middle of the
    window and is exactly zero on every border row and column.
    """
    y_mask = np.linspace(0, 1, h)
    y_mask = np.minimum(y_mask, 1 - y_mask)
    x_mask = np.linspace(0, 1, w)
    x_mask = np.minimum(x_mask, 1 - x_mask)
    return torch.from_numpy(np.outer(y_mask, x_mask)).float()


def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
    """Stitch overlapping spatial crops back into a full (B, T, C, H, W) tensor.

    Each crop is blended with :func:`create_weighted_mask_batched` weights so
    overlapping windows fade into each other. Because those weights are zero on
    each crop's border, some covered pixels (the outer ring of the canvas and
    seams between crops that abut without overlapping) receive zero total
    weight; those pixels fall back to the plain average of the crops covering
    them instead of collapsing to 0. Pixels covered by no crop stay 0.

    Args:
        cropped_tensors: non-empty sequence of crops, each broadcast-compatible
            with (B, T, C, h, w). The window size (h, w) is read from the first
            crop.
        crop_positions: ``(start_h, start_w)`` of each crop, in the same order.
        original_shape: ``(B, T, C, H, W)`` of the output.

    Returns:
        Float32 tensor of shape ``original_shape`` on the first crop's device.

    Raises:
        ValueError: if ``cropped_tensors`` is empty.
    """
    if len(cropped_tensors) == 0:
        raise ValueError("reconstruct_video_new_2_batched needs at least one crop")
    B, T, C, H, W = original_shape
    device = cropped_tensors[0].device
    crop_h, crop_w = cropped_tensors[0].shape[-2:]

    weighted_mask = create_weighted_mask_batched(crop_h, crop_w).to(device)
    weighted_mask = weighted_mask[None, None, None, :, :]  # broadcast over (B, T, C)

    weighted_sum = torch.zeros((B, T, C, H, W), device=device)
    plain_sum = torch.zeros((B, T, C, H, W), device=device)
    # The mask is the same for every (B, T, C) slice, so the per-pixel weight
    # and coverage totals only need a single broadcastable slice.
    weight_total = torch.zeros((1, 1, 1, H, W), device=device)
    coverage = torch.zeros((1, 1, 1, H, W), device=device)

    for idx, crop in enumerate(cropped_tensors):
        start_h, start_w = crop_positions[idx]
        window = (..., slice(start_h, start_h + crop_h), slice(start_w, start_w + crop_w))
        weighted_sum[window] += crop * weighted_mask
        plain_sum[window] += crop
        weight_total[window] += weighted_mask
        coverage[window] += 1

    # Small epsilon keeps the weighted division finite; it is only selected
    # where the total weight is positive.
    epsilon = 1e-8
    blended = weighted_sum / (weight_total + epsilon)
    fallback = plain_sum / coverage.clamp_min(1)
    return torch.where(weight_total > 0, blended, fallback)


def resize(x, a):
    """Bilinearly rescale the last two dims of ``x`` by factor ``a`` (sizes truncated with ``int``)."""
    return F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear', align_corners=False)


def upsample(x, H, W):
    """Bilinearly resize the last two dims of ``x`` to ``(int(H), int(W))``."""
    return F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)


def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    """Estimate a token-level flow field by nearest-neighbour embedding matching.

    For every visible token of the second frame, the most cosine-similar token
    of the first frame is found, and the displacement
    (second-frame position - first-frame position) is written into a
    ``frame_size x frame_size`` grid.

    Args:
        embedding_tensor: tensor indexed as ``[0, token, channel]``. Tokens
            ``[:frame_size**2]`` are the first frame; the remaining rows are the
            second frame's visible tokens. Their count must equal the number of
            visible second-frame positions in ``mask_tensor``, because the two
            are paired element-wise.
        mask_tensor: mask whose flattened entries after the first
            ``frame_size**2`` describe the second frame; a falsy entry
            (False / 0) marks a visible token.
        frame_size: side length of the square token grid.

    Returns:
        Float tensor of shape ``(2, frame_size, frame_size)`` on the embedding's
        device. Channel 0 is the x displacement and channel 1 the y
        displacement, both in token-grid units; masked positions are 0.
    """
    num_tokens = frame_size ** 2
    second_frame_mask = mask_tensor.view(-1)[num_tokens:]
    visible_idx = torch.nonzero(torch.logical_not(second_frame_mask), as_tuple=True)[0]

    first_frame_embeddings = embedding_tensor[0, :num_tokens, :]
    second_frame_embeddings = embedding_tensor[0, num_tokens:, :]

    # Cosine similarity of each visible second-frame token against all first-frame tokens.
    dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
    norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
    # Clamp so an all-zero embedding gives similarity 0 instead of a NaN.
    norms = norms.clamp_min(torch.finfo(norms.dtype).tiny)
    best_match = (dot_product / norms).argmax(dim=-1)

    second_y = visible_idx // frame_size
    second_x = visible_idx % frame_size
    first_y = best_match // frame_size
    first_x = best_match % frame_size

    optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
    optical_flow[0, second_y, second_x] = (second_x - first_x).float()
    optical_flow[1, second_y, second_x] = (second_y - first_y).float()
    return optical_flow


def _tile_starts(size, crop=224):
    """Return start offsets of ``crop``-long windows that together tile ``[0, size)``."""
    if size <= crop:
        return [0]
    num = math.ceil(size / crop)
    step = (size - crop) // (num - 1)
    starts = [k * step for k in range(num)]
    # The floored step can leave up to num - 2 trailing pixels uncovered;
    # pin the last window to the far edge so the whole extent is tiled.
    starts[-1] = size - crop
    return starts


def get_minimal_224_crops_new_batched(video_tensor, N):
    """Tile the spatial dims of a (B, T, C, H, W) video with 224x224 windows.

    Windows form a regular grid that covers the full H x W extent (neighbours
    may overlap). If the grid has fewer than ``N`` windows and both H and W
    exceed 224, ``N - len(grid)`` extra windows at positions drawn with the
    ``random`` module are appended.

    Args:
        video_tensor: 5-D tensor unpacked as ``(B, T, C, H, W)``.
        N: requested minimum number of crops.

    Returns:
        ``(crops, positions)``: crops reshaped to ``(B, T, C, 224, 224)`` and the
        matching ``(start_h, start_w)`` offsets. A video smaller than 224 in
        either spatial dim is not supported (the final reshape fails).
    """
    B, T, C, H, W = video_tensor.shape
    starts_w = _tile_starts(W)

    cropped_tensors = []
    crop_positions = []
    for start_h in _tile_starts(H):
        for start_w in starts_w:
            end_h = min(start_h + 224, H)
            end_w = min(start_w + 224, W)
            cropped_tensors.append(video_tensor[:, :, :, start_h:end_h, start_w:end_w])
            crop_positions.append((start_h, start_w))

    D = len(cropped_tensors)
    if N > D and H > 224 and W > 224:
        for _ in range(N - D):
            start_h = random.randint(0, H - 224)
            start_w = random.randint(0, W - 224)
            cropped_tensors.append(video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)])
            crop_positions.append((start_h, start_w))

    cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
    return cropped_tensors, crop_positions


def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
                                                      mask_generator,
                                                      img1,
                                                      img2,
                                                      img3,
                                                      neg_back_flow=True,
                                                      num_scales=1,
                                                      min_scale=400,
                                                      N_mask_samples=100,
                                                      mask_ratio=0.8,
                                                      flow_frames='23'):
    """Multi-scale, multi-crop flow estimate from a 3-frame masked video encoder.

    For each of ``num_scales`` scales (factor ``alpha ** k``), the three frames
    are processed as follows:

    1. They are stacked into a 3-frame video: ``(img3, img2, img1)`` when
       ``neg_back_flow is True``, else ``(img1, img2, img3)``.
    2. The video is tiled into 224x224 crops.
    3. The crops are repeatedly encoded with masks from ``mask_generator``.
       Each sample gives a token-grid flow between the two frames selected by
       ``flow_frames`` (via :func:`compute_optical_flow`).
    4. The flows are averaged over the samples in which each position was
       visible, rescaled, and expanded to pixels by ``patch_size``.
    5. The crops are stitched back together with
       :func:`reconstruct_video_new_2_batched`.

    The per-scale maps are then resized to the first scale's resolution and
    averaged.

    Args:
        generator: model exposing ``patch_size``, ``encoder`` and
            ``encoder_to_decoder``.
        mask_generator: callable as ``mask_generator(num_frames=3)`` returning a
            token mask (inverted with ``~``, so presumably boolean). Its
            ``mask_ratio`` attribute is overwritten before every call.
        img1, img2, img3: 4-D image batches (asserted for img1); presumably
            (B, C, H, W) — TODO confirm.
        neg_back_flow: when ``True`` the frame order is reversed and the outputs
            are negated.
        num_scales: number of scales (>= 1).
        min_scale: sets the per-scale factor ``alpha``; must be below img2's
            height.
        N_mask_samples: minimum number of mask samples. Sampling continues until
            every flow-grid position has been visible at least once.
        mask_ratio: value assigned to ``mask_generator.mask_ratio``.
        flow_frames: ``'23'`` to match video frames 2 and 3, ``'12'`` for
            frames 1 and 2.

    Returns:
        ``(return_flow, per_scale)``. ``return_flow`` is the scale-averaged
        2-channel flow at the input resolution. ``per_scale`` is the list of
        per-scale flows resized to that resolution, each with a trailing
        singleton axis. Both are negated when ``neg_back_flow is True``.

    Raises:
        ValueError: if ``flow_frames`` is not ``'23'`` or ``'12'``.

    Requires CUDA: crops and accumulators are moved with ``.cuda()``.
    """
    # Validate up front: an unknown value used to leave `flow_mask` unbound and
    # fail with a NameError inside the sampling loop.
    if flow_frames not in ('23', '12'):
        raise ValueError(f"flow_frames must be '23' or '12', got {flow_frames!r}")

    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # Reference (full) resolution that flow from every scale is rescaled to.
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    # NOTE(review): the exponent is fixed at 1/4 and does not depend on
    # num_scales — confirm intent.
    alpha = (min_scale / img1.shape[-2]) ** (1 / 4)
    patch_size = generator.patch_size[-1]
    frame_size = 224 // patch_size
    tokens_per_frame = frame_size * frame_size
    # Token range (in both the mask and the encoder output) that holds the two
    # frames being matched.
    frame_pair = slice(tokens_per_frame, None) if flow_frames == '23' else slice(None, 2 * tokens_per_frame)

    all_fwd_flows_e2d = []
    for aidx in range(num_scales):
        scale = alpha ** aidx
        img1_scaled = resize(img1.clone(), scale)
        img2_scaled = resize(img2.clone(), scale)
        img3_scaled = resize(img3.clone(), scale)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]
        s_h = h1 / h2
        s_w = w1 / w2

        # compute_optical_flow measures second-frame minus first-frame positions,
        # so reversing the frame order yields the negative backward flow, which
        # is negated again at the end.
        if neg_back_flow is True:
            frames = (img3_scaled, img2_scaled, img1_scaled)
        else:
            frames = (img1_scaled, img2_scaled, img3_scaled)
        video = torch.stack(frames, 1)  # frames stacked along dim 1

        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)
        num_items = B * num_crops
        # Crop-major stacking: [crop 0's batch, crop 1's batch, ...].
        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(num_items, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        while i < N_mask_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio
            # One mask per sample, shared by every crop and batch element.
            mask = mask_generator(num_frames=3)[None]
            flow_mask = mask[:, frame_pair]
            # compute_optical_flow writes a flow value exactly where the second
            # frame of `flow_mask` is visible. Count those positions so the
            # final division averages over contributing samples for '23' and '12'.
            mask_counts += (~flow_mask[0, tokens_per_frame:]).reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):
                processed_x = crop.transpose(1, 2)
                encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(num_items, 1))
                encoder_to_decoder = generator.encoder_to_decoder(encoder_out)[:, frame_pair, :]
                sample_flows = torch.stack(
                    [compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
                     for b in range(num_items)],
                    0,
                )
                optical_flows_enc2dec += sample_flows
            i += 1

        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # NOTE(review): the flow is in token-grid units multiplied by these
        # factors; it is not multiplied by patch_size — confirm the intended
        # output units.
        scale_factor_y = video.shape[-2] / 224
        scale_factor_x = video.shape[-1] / 224
        scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
        scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
        scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h

        # Undo the crop-major stacking, expand each token to a
        # patch_size x patch_size block of pixels, then stitch the crops.
        crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
        pixel_flows = [
            f.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu()
            for f in crop_flows_enc2dec
        ]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            pixel_flows, c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # Bring every scale to the first scale's resolution and average them.
    target_h, target_w = all_fwd_flows_e2d[0].shape[-2:]
    all_fwd_flows_e2d_new = [upsample(r, target_h, target_w).unsqueeze(-1) for r in all_fwd_flows_e2d]
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-flow for flow in all_fwd_flows_e2d_new]
    return return_flow, all_fwd_flows_e2d_new