import torch
import torch.nn.functional as F

# Relies on module-level helpers defined elsewhere in this file:
# get_minimal_224_crops_new_batched, compute_optical_flow, average_crops,
# reconstruct_video_new_2_batched.


def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    """Estimate optical flow between img1 and img2 with a masked video autoencoder.

    Flow is extracted from the encoder-to-decoder features of `generator`,
    accumulated over at least `N_mask_samples` random mask samples per crop,
    averaged over a pyramid of `num_scales` image scales, and rescaled from
    patch units back to pixels at the input resolution.

    Returns:
        return_flow: (B, 2, H, W) flow averaged over all scales.
        all_fwd_flows_e2d_new: list of per-scale flows, each (B, 2, H, W, 1).
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # Input resolution; the scale pyramid shrinks from (h1, w1) down to min_scale.
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1 and min_scale >= 360  # below 360p, the flows look terrible

    if not neg_back_flow:
        print('WARNING: Not calculating negative backward flow')

    # Per-step downscaling factor so that scale (num_scales - 1) lands on min_scale.
    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    # Side length of the token grid for one 224x224 crop.
    frame_size = 224 // generator.patch_size[-1]

    all_fwd_flows_e2d = []
    s_hs = []  # per-scale height ratios back to the input resolution
    s_ws = []  # per-scale width ratios back to the input resolution

    for aidx in range(num_scales):
        img1_scaled = F.interpolate(img1.clone(),
                                    [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(),
                                    [int((alpha ** aidx) * h1), int((alpha ** aidx) * w1)],
                                    mode='bicubic', align_corners=True)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]

        s_hs.append(h1 / h2)
        s_ws.append(w1 / w2)

        # compute_optical_flow technically returns the negative backward flow,
        # so feed the frames in reverse order when neg_back_flow is requested.
        if neg_back_flow:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)

        # Cover the frame with minimal 224x224 crops; this works even if the
        # incoming video is already 224x224.
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)

        N_samples = N_mask_samples
        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        # Per-token count of how many mask samples contributed a flow estimate,
        # used to normalize the accumulated flow below.
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        # Keep sampling until at least N_samples masks have been drawn and
        # every token has contributed at least once.
        i = 0
        while i < N_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio

            # Every crop in the batch shares the same mask for a given sample;
            # that is acceptable for now.
            mask = mask_generator()[None]
            mask_2f = ~mask[0, frame_size * frame_size:]  # second-frame tokens
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):
                processed_x = generator._preprocess(crop)
                encoder_out = generator.predictor.encoder(
                    processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)

                # One flow map per batch element (and per crop) for now.
                optical_flow_e2d = []
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(
                        encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
                    optical_flow_e2d.append(
                        average_crops(batch_flow, smoothing_factor).unsqueeze(0))

                optical_flows_enc2dec += torch.cat(optical_flow_e2d, 0)
            i += 1

        # Normalize each token's flow by the number of samples that hit it.
        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # Split the stacked crops back up, upsample each to 224x224, and
        # stitch them into a full-frame flow map.
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
        T1 = [F.interpolate(f, [224, 224], mode='bicubic', align_corners=True).unsqueeze(1).cpu()
              for f in crop_flows_enc2dec]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            T1, c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # Resize every scale's flow to the finest (input) resolution and convert
    # from patch units to pixels at that resolution.
    all_fwd_flows_e2d_new = []
    for ridx, r in enumerate(all_fwd_flows_e2d):
        _sh = s_hs[ridx]
        _sw = s_ws[ridx]
        _sfy = generator.patch_size[-1]
        _sfx = generator.patch_size[-1]

        new_r = F.interpolate(r,
                              [all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1]],
                              mode='bicubic', align_corners=True)

        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))

    # Average the per-scale flows; negate if the frames were fed in reverse order.
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)
    if neg_back_flow:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-f for f in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new
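

# ---------------------------------------------------------------------------
# Hypothetical usage, assuming `generator` (a masked video autoencoder that
# exposes patch_size, _preprocess, predictor.encoder, and
# predictor.encoder_to_decoder, as used above) and `mask_generator` have been
# constructed elsewhere in this repo:
#
#     img1 = torch.rand(1, 3, 480, 854).cuda()
#     img2 = torch.rand(1, 3, 480, 854).cuda()
#     flow, per_scale = scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(
#         generator, mask_generator, img1, img2, num_scales=2, min_scale=400)
#
# The sketch below is a minimal, self-contained illustration of the
# mask-sampling average at the heart of the loop above: per-token flow
# estimates are accumulated over random masks and then normalized by how many
# samples contributed to each token. All names, shapes, and the constant
# "flow" are illustrative stand-ins, not part of the real model.
def _demo_masked_average(frame_size=14, N_samples=100, mask_ratio=0.8):
    flow_sum = torch.zeros(2, frame_size, frame_size)
    contrib_counts = torch.zeros(frame_size, frame_size)

    i = 0
    # Same stopping rule as above: at least N_samples draws, and every token
    # covered at least once so the normalization below is well defined.
    while i < N_samples or (contrib_counts == 0).any().item():
        # Toy mask: True marks tokens that contribute a flow estimate.
        contributes = torch.rand(frame_size, frame_size) < mask_ratio

        # Toy stand-in for compute_optical_flow(): a constant (1, 2) flow.
        fake_flow = torch.stack([torch.ones(frame_size, frame_size),
                                 2 * torch.ones(frame_size, frame_size)])
        flow_sum += fake_flow * contributes  # accumulate only where sampled
        contrib_counts += contributes
        i += 1

    return flow_sum / contrib_counts  # recovers ~(1, 2) at every token


if __name__ == '__main__':
    demo = _demo_masked_average()
    print(demo[0].mean().item(), demo[1].mean().item())  # prints 1.0 and 2.0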