# NOTE: this file was scraped from a Hugging Face Space; the page banner
# ("Spaces: Running on CPU Upgrade") was captured along with the code.
def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              N_mask_samples=100,
                                                              min_scale=400,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    """Estimate dense optical flow between two frame batches with a masked
    video-MAE generator, averaged over random masks, 224x224 crops, and an
    optional multi-scale image pyramid.

    Args:
        generator: model exposing ``patch_size``, ``_preprocess`` and a
            ``predictor`` with ``encoder`` / ``encoder_to_decoder`` modules.
        mask_generator: callable producing a token mask; its ``mask_ratio``
            attribute is set from ``mask_ratio`` before sampling.
        img1, img2: float tensors of shape (B, C, H, W) — flow is computed
            from ``img1`` to ``img2``.
        neg_back_flow: if True (default), the frames are fed in reverse order
            (the underlying ``compute_optical_flow`` returns negated backward
            flow) and the result is negated back at the end.
        num_scales: number of pyramid levels (>= 1); level 0 is full
            resolution, the last level is at ``min_scale`` height.
        min_scale: height (pixels) of the smallest pyramid level; must be
            >= 360 and smaller than the input height.
        N_mask_samples: minimum number of random masks averaged per crop.
        mask_ratio: fraction of tokens masked per sample.
        smoothing_factor: passed through to ``average_crops`` when smoothing
            each per-sample flow estimate.

    Returns:
        (return_flow, all_fwd_flows_e2d_new): the flow averaged over all
        scales as a (B, 2, H, W) tensor, and the list of per-scale flow
        tensors (each with a trailing singleton scale axis).
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # Reference (full) resolution, taken from img2.
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1 and min_scale >= 360  # Below 360p, the flows look terrible

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    # Geometric step so that scale index (num_scales - 1) lands exactly at
    # min_scale height; degenerate single-scale case uses factor 1.
    alpha = (min_scale / img1.shape[-2]) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    # Flow is predicted on the token grid: 224px crop / patch edge.
    frame_size = 224 // generator.patch_size[-1]

    # Invariant across all mask samples and scales — set once.
    mask_generator.mask_ratio = mask_ratio

    all_fwd_flows_e2d = []
    s_hs = []  # per-scale height ratios (full / scaled)
    s_ws = []  # per-scale width ratios (full / scaled)

    for aidx in range(num_scales):
        scale = alpha ** aidx
        img1_scaled = F.interpolate(img1.clone(), [int(scale * h1), int(scale * w1)],
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), [int(scale * h1), int(scale * w1)],
                                    mode='bicubic', align_corners=True)
        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]
        s_hs.append(h1 / h2)
        s_ws.append(w1 / w2)

        # compute_optical_flow technically returns negated backward flow, so
        # by default the frames go in reversed order.
        if neg_back_flow is True:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)

        # Should work even if the incoming video is already 224x224.
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)
        crop = torch.cat(crops1, 0).cuda()

        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        # Counts how often each token cell was *visible* (unmasked in frame 2);
        # used below both as the averaging denominator and as the stopping
        # condition (every cell must be sampled at least once).
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        while i < N_mask_samples or (mask_counts == 0).any().item():
            # NOTE: one mask is shared by every element of the batch/crop stack.
            mask = mask_generator()[None]
            mask_2f = ~mask[0, frame_size * frame_size:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):
                processed_x = generator._preprocess(crop)
                encoder_out = generator.predictor.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)

            # One flow estimate per batch-crop element, smoothed individually.
            optical_flow_e2d = []
            for b in range(B * num_crops):
                batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size)
                optical_flow_e2d.append(average_crops(batch_flow, smoothing_factor).unsqueeze(0))

            optical_flows_enc2dec += torch.cat(optical_flow_e2d, 0)
            i += 1

        # Safe: the while-loop guarantees every cell of mask_counts is > 0.
        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # Split the crop stack back into per-crop batches, upsample each crop
        # flow to 224x224, and stitch the crops back into a full-frame flow.
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
        T1 = [F.interpolate(_, [int(224), int(224)], mode='bicubic', align_corners=True).unsqueeze(1).cpu() for _ in
              crop_flows_enc2dec]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(T1, c_pos1, (
            B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # Bring every scale back to the level-0 resolution and convert the flow
    # from token units to full-resolution pixel units: multiply by the patch
    # edge (token -> crop pixels) and by the scale ratio (crop -> full pixels).
    all_fwd_flows_e2d_new = []
    for ridx, r in enumerate(all_fwd_flows_e2d):
        _sh = s_hs[ridx]
        _sw = s_ws[ridx]
        _sfy = generator.patch_size[-1]
        _sfx = generator.patch_size[-1]
        new_r = F.interpolate(r, [int(all_fwd_flows_e2d[0].shape[-2]), int(all_fwd_flows_e2d[0].shape[-1])],
                              mode='bicubic', align_corners=True)
        scaled_new_r = torch.zeros_like(new_r)
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * _sfx * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * _sfy * _sh
        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))

    # Average across scales (stacked on the trailing axis).
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new