rahulvenkk
app.py updated
6dfcb0f
import torch
import torch.nn.functional as F
from torchvision import transforms
def sampling_grid(height, width):
    """Build the identity pixel-coordinate grid for a height x width image.

    Returns an int64 tensor of shape [1, height, width, 2] such that
    ``grid[0, y, x] == (x, y)``: the last axis holds the column (x) index
    first and the row (y) index second.
    """
    col_idx = torch.arange(width).reshape(1, -1)   # [1, W]
    row_idx = torch.arange(height).reshape(-1, 1)  # [H, 1]
    xy = torch.stack(
        (col_idx.repeat(height, 1), row_idx.repeat(1, width)),
        dim=-1,
    )  # [H, W, 2]
    return xy.unsqueeze(0)
def normalize_sampling_grid(coords):
    """Map (x, y) pixel coordinates to the [-1, 1] range used by ``grid_sample``.

    ``coords`` must be a rank-4 tensor whose last axis has size 2 and holds
    (x, y); its axes are read as [B, H, W, 2]. x is rescaled so that
    0 -> -1 and W - 1 -> 1, and y likewise with H (the ``align_corners=True``
    convention). Positions outside the image land outside [-1, 1].

    NOTE: the divisors are W - 1 and H - 1, so a width or height of 1
    divides by zero and yields inf/nan values.
    """
    assert coords.dim() == 4, coords.shape
    assert coords.size(-1) == 2, coords.shape
    height, width = coords.shape[-3:-1]
    x, y = torch.split(coords, [1, 1], dim=-1)
    x_norm = (2 * x) / (width - 1) - 1
    y_norm = (2 * y) / (height - 1) - 1
    return torch.cat((x_norm, y_norm), dim=-1)
def backward_warp(img2, flow, do_mask=False):
    """Warp ``img2`` with the img1->img2 flow to obtain a prediction of img1.

    Output pixel (x, y) is sampled bilinearly from ``img2`` at
    (x + flow_x, y + flow_y) via ``F.grid_sample`` (``align_corners=True``,
    default zero padding outside the image).

    Args:
        img2: image batch of shape [B, C, H, W] to sample from.
        flow: [B, 2, H', W'] flow from img1 to img2, in pixels at its own
            resolution. Channel 0 is x and channel 1 is y, where larger y
            values correspond to lower parts of the image. If (H', W') differs
            from (H, W), the flow is bilinearly resized to (H, W) and its
            values are rescaled so they stay in pixel units of ``img2``.
        do_mask: if True, also compute a validity mask of sample locations.

    Returns:
        Tuple ``(img1_pred, mask)``. ``img1_pred`` is [B, C, H, W]. ``mask``
        is [B, 1, H, W] with ``img2``'s dtype. When ``do_mask`` is True it is
        1 where the sample location lies inside ``img2`` (border pixel centres
        included) and 0 elsewhere. Otherwise it is all ones.
    """
    _, _, H, W = img2.shape

    # Resize the flow to the image size. Because flow is measured in pixels,
    # its values must be scaled by the same factor as the resolution.
    if tuple(flow.shape[-2:]) != (H, W):
        scale = torch.tensor(
            [W / flow.size(-1), H / flow.size(-2)],  # (x, y)
            # Build the scale in flow's dtype. A float32 scale would promote
            # half-precision flow to float32, which then mismatches a
            # half-precision img2 inside grid_sample.
            dtype=flow.dtype,
            device=flow.device,
        ).view(1, 2, 1, 1)
        # Plain bilinear resize without antialiasing. This is what
        # transforms.Resize did on tensors before torchvision changed its
        # antialias default, so results no longer depend on that version.
        flow = scale * F.interpolate(
            flow, size=(H, W), mode="bilinear", align_corners=False
        )

    # Absolute sampling locations: identity pixel grid (x, y) plus flow.
    grid = sampling_grid(H, W).to(flow.device) + flow.permute(0, 2, 3, 1)
    # grid_sample expects coordinates normalised to [-1, 1].
    grid = normalize_sampling_grid(grid)
    # Backward warp: output pixel (x, y) samples img2 at (x + flow_x, y + flow_y).
    img1_pred = F.grid_sample(img2, grid, align_corners=True)

    if do_mask:
        # Inclusive bounds: with align_corners=True, -1 and +1 are exactly the
        # centres of the border pixels, so those samples are valid. A strict
        # test would flag every border pixel as invalid under zero flow.
        # NaN coordinates compare False and are therefore masked out.
        inside = (grid.abs() <= 1).all(dim=-1)
        mask = inside[:, None].to(img2.dtype)
    else:
        # Same shape and dtype as the do_mask branch.
        mask = torch.ones_like(grid[..., 0][:, None], dtype=img2.dtype)
    return img1_pred, mask