rahulvenkk
app.py updated
6dfcb0f
def scaling_fixed_get_vmae_optical_flow_crop_batched_smoothed(generator,
                                                              mask_generator,
                                                              img1,
                                                              img2,
                                                              neg_back_flow=True,
                                                              num_scales=1,
                                                              min_scale=400,
                                                              N_mask_samples=100,
                                                              mask_ratio=0.8,
                                                              smoothing_factor=1):
    """Estimate flow between two frames, averaged over random masks and scales.

    For each of ``num_scales`` scales both frames are bicubically resized,
    stacked into a two-frame clip and cut into crops with
    ``get_minimal_224_crops_new_batched``. The crops are run through
    ``generator.predictor.encoder`` / ``encoder_to_decoder`` under repeatedly
    sampled masks; ``compute_optical_flow`` + ``average_crops`` turn each
    result into a per-crop flow which is accumulated and then divided by how
    often every second-frame position was left unmasked. The per-crop flows
    are upsampled to 224x224 and stitched back together with
    ``reconstruct_video_new_2_batched``. Finally every scale's flow is resized
    to the first scale's resolution, rescaled, and the scales are averaged.

    Args:
        generator: Model wrapper providing ``patch_size``, ``_preprocess`` and
            ``predictor.encoder`` / ``predictor.encoder_to_decoder``.
        mask_generator: Callable returning a mask; its ``mask_ratio``
            attribute is overwritten with ``mask_ratio``.
        img1: 4-D tensor holding the first frame(s); dim 0 is the batch.
        img2: Tensor holding the second frame(s); its last two dims define
            the reference height and width for all scaling.
        neg_back_flow: If ``True`` the clip is ordered ``(img2, img1)`` and
            the results are negated before being returned.
        num_scales: Number of scales to average over (>= 1).
        min_scale: Height of the smallest scale; must satisfy
            ``360 <= min_scale < img2.shape[-2]``.
        N_mask_samples: Minimum number of masks sampled per scale.
        mask_ratio: Value assigned to ``mask_generator.mask_ratio``.
        smoothing_factor: Forwarded to ``average_crops``.

    Returns:
        ``(return_flow, all_fwd_flows_e2d_new)``: the mean flow over scales,
        and the list of per-scale flows (each carrying a trailing singleton
        dimension) that were averaged.

    Raises:
        AssertionError: If ``img1`` is not 4-D, ``num_scales < 1`` or
            ``min_scale`` is outside the allowed range.

    Note:
        Tensors are moved with ``.cuda()``, so a CUDA device is required.
    """
    B = img1.shape[0]
    assert len(img1.shape) == 4, 'img1 must be a 4-D tensor'
    assert num_scales >= 1, 'num_scales must be >= 1'

    # Reference resolution for scaling is taken from the second frame.
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    # Below 360p, the flows look terrible
    assert min_scale < h1 and min_scale >= 360, \
        'min_scale must satisfy 360 <= min_scale < img2 height'

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    # Geometric scale step so that the last scale has height ``min_scale``.
    # Uses h1 (the height that is asserted on and resized from) rather than
    # img1's height, so the two stay consistent if the frames differ in size.
    alpha = (min_scale / h1) ** (1 / (num_scales - 1)) if num_scales > 1 else 1

    patch = generator.patch_size[-1]
    frame_size = 224 // patch

    # Constant for the whole call, so set it once instead of per mask sample.
    mask_generator.mask_ratio = mask_ratio

    all_fwd_flows_e2d = []
    s_hs = []
    s_ws = []

    for aidx in range(num_scales):
        scale = alpha ** aidx
        scaled_size = [int(scale * h1), int(scale * w1)]
        img1_scaled = F.interpolate(img1.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)
        img2_scaled = F.interpolate(img2.clone(), scaled_size,
                                    mode='bicubic', align_corners=True)

        # Ratio between the reference size and this scale, used later to
        # rescale the flow values.
        s_hs.append(h1 / img2_scaled.shape[-2])
        s_ws.append(w1 / img2_scaled.shape[-1])

        # Because technically the compute_optical_flow function returns neg back flow
        if neg_back_flow is True:
            video = torch.cat([img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1)], 1)

        # Should work, even if the incoming video is already 224x224
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)
        n_items = B * num_crops

        crop = torch.cat(crops1, 0).cuda()
        optical_flows_enc2dec = torch.zeros(n_items, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        # Keep sampling masks until at least N_mask_samples were drawn AND
        # every position has a non-zero count (avoids dividing by zero below).
        i = 0
        while i < N_mask_samples or (mask_counts == 0).any().item():
            # This would be that every sample has the same mask. For now that's okay I think
            mask = mask_generator()[None]
            # Entries past the first frame_size**2, inverted, counted per position.
            mask_2f = ~mask[0, frame_size * frame_size:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):
                # NOTE(review): ``crop`` does not change inside this loop, so
                # this preprocessing looks hoistable if ``_preprocess`` is
                # deterministic - confirm before moving it.
                processed_x = generator._preprocess(crop)
                encoder_out = generator.predictor.encoder(
                    processed_x.to(torch.float16), mask.repeat(n_items, 1))
                encoder_to_decoder = generator.predictor.encoder_to_decoder(encoder_out)

                # one per batch element for now
                optical_flow_e2d = torch.cat([
                    average_crops(
                        compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), mask, frame_size),
                        smoothing_factor,
                    ).unsqueeze(0)
                    for b in range(n_items)
                ], 0)
                optical_flows_enc2dec += optical_flow_e2d
            i += 1

        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # split the crops back up
        crop_flows_enc2dec = optical_flows_enc2dec.split(B, 0)
        T1 = [
            F.interpolate(crop_flow, [int(224), int(224)],
                          mode='bicubic', align_corners=True).unsqueeze(1).cpu()
            for crop_flow in crop_flows_enc2dec
        ]
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            T1, c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)
        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # Bring every scale to the resolution of the first one, then rescale the
    # flow values by the patch size and the per-scale resize ratio.
    target_size = [int(all_fwd_flows_e2d[0].shape[-2]),
                   int(all_fwd_flows_e2d[0].shape[-1])]
    all_fwd_flows_e2d_new = []
    for r, _sh, _sw in zip(all_fwd_flows_e2d, s_hs, s_ws):
        new_r = F.interpolate(r, target_size, mode='bicubic', align_corners=True)
        scaled_new_r = torch.zeros_like(new_r)
        # Channel 0 uses the width ratio, channel 1 the height ratio.
        scaled_new_r[:, 0, :, :] = new_r[:, 0, :, :] * patch * _sw
        scaled_new_r[:, 1, :, :] = new_r[:, 1, :, :] * patch * _sh
        all_fwd_flows_e2d_new.append(scaled_new_r.unsqueeze(-1))

    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-flow for flow in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new