import torch
import torch.nn.functional as F
from torchvision import transforms


def sampling_grid(height, width):
    """Return a [1, H, W, 2] float grid of (x, y) pixel coordinates."""
    H, W = height, width
    grid = torch.stack([
        torch.arange(W).view(1, -1).repeat(H, 1),  # x coordinate of each pixel
        torch.arange(H).view(-1, 1).repeat(1, W)   # y coordinate of each pixel
    ], -1).float()
    grid = grid.view(1, H, W, 2)
    return grid


def normalize_sampling_grid(coords):
    """Map [B, H, W, 2] pixel coordinates to the [-1, 1] range expected by
    F.grid_sample with align_corners=True."""
    assert len(coords.shape) == 4, coords.shape
    assert coords.size(-1) == 2, coords.shape
    H, W = coords.shape[-3:-1]
    xs, ys = coords.split([1, 1], -1)
    xs = 2 * xs / (W - 1) - 1
    ys = 2 * ys / (H - 1) - 1
    return torch.cat([xs, ys], -1)


def backward_warp(img2, flow, do_mask=False):
    """
    Sample from img2 using the img1->img2 flow to reconstruct a prediction of img1.

    img2: [B,C,H,W] image tensor.
    flow: [B,2,H',W'] flow in units of pixels at its current resolution. The two
          channels are (x, y), where larger y values correspond to lower parts
          of the image.
    do_mask: if True, the returned mask marks pixels whose sampling location
             falls inside img2; otherwise the mask is all ones.

    Returns (img1_pred, mask) of shapes [B,C,H,W] and [B,1,H,W].
    """

    ## resize the flow to the image size.
    ## since flow has units of pixels, its values need to be rescaled accordingly.
    if list(img2.shape[-2:]) != list(flow.shape[-2:]):
        scale = [img2.size(-1) / flow.size(-1),  # x
                 img2.size(-2) / flow.size(-2)]  # y
        scale = torch.tensor(scale, dtype=flow.dtype, device=flow.device).view(1, 2, 1, 1)
        flow = scale * transforms.Resize(img2.shape[-2:])(flow)  # defaults to bilinear

    B, C, H, W = img2.shape

    ## use flow to warp sampling grid
    grid = sampling_grid(H, W).to(flow) + flow.permute(0, 2, 3, 1)

    ## put grid in normalized image coordinates
    grid = normalize_sampling_grid(grid)

    ## backward warp, i.e. sample pixel (x,y) from (x+flow_x, y+flow_y)
    img1_pred = F.grid_sample(img2, grid, align_corners=True)

    if do_mask:
        ## valid samples fall inside the normalized [-1, 1] extent of img2
        mask = (grid[..., 0] >= -1) & (grid[..., 0] <= 1) & \
               (grid[..., 1] >= -1) & (grid[..., 1] <= 1)
        mask = mask[:, None].to(img2.dtype)
        return (img1_pred, mask)

    else:
        return (img1_pred, torch.ones_like(grid[..., 0][:, None]).to(img2.dtype))
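

## Minimal usage sketch (an addition, not part of the original file): exercises
## backward_warp with a zero flow (the identity warp) and with a half-resolution
## flow to hit the resize-and-rescale path. Shapes and values here are assumptions.
if __name__ == "__main__":
    B, C, H, W = 2, 3, 64, 64
    img2 = torch.rand(B, C, H, W)

    ## zero flow is the identity warp, so the prediction should match img2
    flow = torch.zeros(B, 2, H, W)
    img1_pred, mask = backward_warp(img2, flow, do_mask=True)
    assert torch.allclose(img1_pred, img2, atol=1e-5)

    ## a half-resolution flow of 4 px becomes 8 px after rescaling to full size
    flow_half = 4.0 * torch.ones(B, 2, H // 2, W // 2)
    img1_pred, mask = backward_warp(img2, flow_half, do_mask=True)
    print(img1_pred.shape, mask.shape)  # [2, 3, 64, 64], [2, 1, 64, 64]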