Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import torch | |
import torch.nn.functional as F | |
from torchvision import transforms | |
def sampling_grid(height, width):
    """
    Build a grid of integer pixel coordinates.

    Returns an int64 tensor of shape [1, height, width, 2] whose entry
    [0, row, col] holds (col, row), i.e. the (x, y) coordinate of that pixel.
    """
    ## x varies along columns, y varies along rows; expand gives broadcast views
    ## and stack materialises them into one contiguous tensor.
    cols = torch.arange(width).view(1, -1).expand(height, width)
    rows = torch.arange(height).view(-1, 1).expand(height, width)
    return torch.stack((cols, rows), dim=-1).unsqueeze(0)
def normalize_sampling_grid(coords):
    """
    Map pixel coordinates to the [-1, 1] range used by grid_sample(align_corners=True).

    coords: 4-D tensor whose last dimension holds (x, y) pixel coordinates; H and W
        are read from dimensions -3 and -2 (e.g. [B,H,W,2]).

    Returns a tensor of the same shape in which x = 0 maps to -1 and x = W-1 maps
    to 1, and likewise for y with H. For a size-1 axis the denominator is clamped
    to 1 so coordinate 0 maps to -1 instead of producing NaN/inf.
    """
    assert len(coords.shape) == 4, coords.shape
    assert coords.size(-1) == 2, coords.shape
    H, W = coords.shape[-3:-1]
    xs, ys = coords.split([1, 1], -1)
    ## W - 1 (or H - 1) is 0 for a single-pixel axis; dividing by it would yield
    ## 0/0 = NaN or x/0 = inf, which then poisons grid_sample and any mask.
    xs = 2 * xs / max(W - 1, 1) - 1
    ys = 2 * ys / max(H - 1, 1) - 1
    return torch.cat([xs, ys], -1)
def backward_warp(img2, flow, do_mask=False):
    """
    Grid sample from img2 using the flow from img1->img2 to get a prediction of img1.

    img2: [B,C,H,W] image to sample from.
    flow: [B,2,H',W'] in units of pixels at its current resolution. The two channels
        should be (x,y) where larger y values correspond to lower parts of the image.
        If (H',W') differs from img2's (H,W), the flow is resized to (H,W) and each
        component is rescaled so its values stay in pixel units of img2.
    do_mask: if True, also compute a mask that is 1 where the sample location
        (x+flow_x, y+flow_y) lies inside img2 (edge pixels included) and 0 elsewhere.

    Returns (img1_pred, mask): img1_pred is [B,C,H,W]; mask is [B,1,H,W] in img2's
    dtype, and is all ones when do_mask is False.
    """
    ## resize the flow to the image size.
    ## since flow has units of pixels, its values need to be rescaled accordingly.
    if list(img2.shape[-2:]) != list(flow.shape[-2:]):
        scale = [img2.size(-1) / flow.size(-1),  # x
                 img2.size(-2) / flow.size(-2)]  # y
        scale = torch.tensor(scale).view(1, 2, 1, 1).to(flow.device)
        flow = scale * transforms.Resize(img2.shape[-2:])(flow)  # defaults to bilinear

    B, C, H, W = img2.shape

    ## use flow to warp sampling grid
    grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)

    ## put grid in normalized image coordinates.
    ## grid_sample requires grid and input to share a dtype; without this cast a
    ## float64 flow, or a half flow promoted to float32 by the rescale above, fails.
    grid = normalize_sampling_grid(grid).to(img2.dtype)

    ## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y)
    img1_pred = F.grid_sample(img2, grid, align_corners=True)

    if not do_mask:
        ## same shape and dtype as the real mask so both branches agree
        return (img1_pred, torch.ones_like(grid[..., 0])[:, None])

    ## with align_corners=True, -1 and 1 are the centres of the edge pixels, so
    ## samples landing exactly on them are valid: use inclusive bounds. NaN
    ## coordinates compare False and are therefore masked out.
    inside = (grid >= -1) & (grid <= 1)
    mask = inside.all(dim=-1)[:, None].to(img2.dtype)
    return (img1_pred, mask)