import torch import torch.nn.functional as F from torchvision import transforms def sampling_grid(height, width): H, W = height, width grid = torch.stack([ torch.arange(W).view(1, -1).repeat(H, 1), torch.arange(H).view(-1, 1).repeat(1, W) ], -1) grid = grid.view(1, H, W, 2) return grid def normalize_sampling_grid(coords): assert len(coords.shape) == 4, coords.shape assert coords.size(-1) == 2, coords.shape H, W = coords.shape[-3:-1] xs, ys = coords.split([1, 1], -1) xs = 2 * xs / (W - 1) - 1 ys = 2 * ys / (H - 1) - 1 return torch.cat([xs, ys], -1) def backward_warp(img2, flow, do_mask=False): """ Grid sample from img2 using the flow from img1->img2 to get a prediction of img1. flow: [B,2,H',W'] in units of pixels at its current resolution. The two channels should be (x,y) where larger y values correspond to lower parts of the image. """ ## resize the flow to the image size. ## since flow has units of pixels, its values need to be rescaled accordingly. if list(img2.shape[-2:]) != list(flow.shape[-2:]): scale = [img2.size(-1) / flow.size(-1), # x img2.size(-2) / flow.size(-2)] # y scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device) flow = scale * transforms.Resize(img2.shape[-2:])(flow) # defaults to bilinear B, C, H, W = img2.shape ## use flow to warp sampling grid grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1) ## put grid in normalized image coordinates grid = normalize_sampling_grid(grid) ## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y) img1_pred = F.grid_sample(img2, grid, align_corners=True) if do_mask: mask = (grid[..., 0] > -1) & (grid[..., 0] < 1) & \ (grid[..., 1] > -1) & (grid[..., 1] < 1) mask = mask[:, None].to(img2.dtype) return (img1_pred, mask) else: return (img1_pred, torch.ones_like(grid[..., 0][:, None]).float())