# File size: 7,015 Bytes | revision 7062b81 (file-viewer line-number gutter removed)
import torch
import torch.nn as nn
from networks.blocks.raft import (
coords_grid,
SmallUpdateBlock, BidirCorrBlock
)
from networks.blocks.feat_enc import (
SmallEncoder
)
from networks.blocks.ifrnet import (
resize,
Encoder,
InitDecoder,
IntermediateDecoder
)
from networks.blocks.multi_flow import (
multi_flow_combine,
MultiFlowDecoder
)
class Model(nn.Module):
    """Frame-interpolation network: coarse-to-fine decoders refined by correlation lookups.

    The forward pass encodes both input frames, builds a bidirectional
    correlation volume from a separate feature encoder, and then runs four
    decoder stages. After each of the first three stages the predicted
    bilateral flows and the intermediate feature are refined by an update
    block that is fed correlation values looked up along the current flows.
    The last stage predicts ``num_flows`` flow pairs, a mask and an image
    residual which are combined into the interpolated frame.
    """

    def __init__(self,
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=3,
                 channels=(20, 32, 44, 56),
                 skip_channels=20):
        """Build the encoders, decoders, update blocks and the combine head.

        Args:
            corr_radius: lookup radius forwarded to ``BidirCorrBlock`` and
                to every update block.
            corr_lvls: number of correlation levels forwarded to the same
                two places.
            num_flows: number of flow pairs predicted by the last decoder.
            channels: four channel widths for the pyramid encoder, ordered
                from the finest level to the coarsest. Any sequence is
                accepted; it is copied into a list.
            skip_channels: forwarded to every decoder.
        """
        super().__init__()
        # Copy so the module never aliases a caller-owned (or default) object.
        channels = list(channels)

        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows
        self.channels = channels
        self.skip_channels = skip_channels

        self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
        self.encoder = Encoder(channels)

        self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels)
        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows)

        # Each update block refines the feature emitted by the decoder of the
        # same stage, so its content width follows that decoder's output
        # width (44 / 32 / 20 for the default ``channels``).
        self.update4 = self._get_updateblock(channels[2])
        self.update3 = self._get_updateblock(channels[1], 2)
        self.update2 = self._get_updateblock(channels[0], 4)

        self.comb_block = nn.Sequential(
            nn.Conv2d(3 * num_flows, 6 * num_flows, 3, 1, 1),
            nn.PReLU(6 * num_flows),
            nn.Conv2d(6 * num_flows, 3, 3, 1, 1),
        )

    def _get_updateblock(self, cdim, scale_factor=None):
        """Create a ``SmallUpdateBlock`` for a feature with ``cdim`` channels.

        ``scale_factor`` is passed straight through to the block; the
        remaining dimensions are fixed for this model variant.
        """
        return SmallUpdateBlock(cdim=cdim, hidden_dim=76, flow_dim=20, corr_dim=64,
                                fc_dim=68, scale_factor=scale_factor,
                                corr_levels=self.corr_levels, radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1):
        """Look up bidirectional correlation along the current flows.

        Args:
            corr_fn: callable taking two coordinate tensors and returning a
                pair of correlation tensors.
            coord: base coordinate grid the flows are added to.
            flow0: current flow t -> 0.
            flow1: current flow t -> 1.
            embt: time step; a value of exactly 0 or 1 makes one of the
                scales below non-finite (or raises for a plain Python
                float) -- presumably callers keep 0 < embt < 1.
            downsample: factor by which the flows are finer than ``coord``;
                when not 1 the flows are resized and rescaled by its inverse.

        Returns:
            ``(corr, flow)``: both correlation results and both (possibly
            downsampled) flows, each concatenated along dim 1.
        """
        # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0
        # based on linear assumption
        t1_scale = 1. / embt
        t0_scale = 1. / (1. - embt)
        if downsample != 1:
            inv = 1 / downsample
            flow0 = inv * resize(flow0, scale_factor=inv)
            flow1 = inv * resize(flow1, scale_factor=inv)

        corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale)
        corr = torch.cat([corr0, corr1], dim=1)
        flow = torch.cat([flow0, flow1], dim=1)
        return corr, flow

    def _refine(self, update_block, corr_fn, coord, feat, flow0, flow1, embt, downsample):
        """Apply one residual refinement step to a flow pair and its feature.

        Looks up correlation along the current flows, lets ``update_block``
        predict residuals, and adds them. The flow residual is split in half
        along dim 1: first half for ``flow0``, second half for ``flow1``.

        Returns:
            ``(flow0, flow1, feat)`` after the residual update.
        """
        corr, flow = self._corr_scale_lookup(corr_fn, coord, flow0, flow1,
                                             embt, downsample=downsample)
        delta_feat, delta_flow = update_block(feat, flow, corr)
        delta_flow0, delta_flow1 = torch.chunk(delta_flow, 2, 1)
        return flow0 + delta_flow0, flow1 + delta_flow1, feat + delta_feat

    @staticmethod
    def _split_flows(flow, num_flows):
        """Reshape ``(B, 2 * num_flows, H, W)`` into ``(B, num_flows, 2, H, W)``.

        The spatial size is read from ``flow`` itself so the reshape stays
        valid after the flow has been resized back to the input resolution.
        """
        b, _, h, w = flow.shape
        return flow.reshape(b, num_flows, 2, h, w)

    def forward(self, img0, img1, embt, scale_factor=1.0, eval=False, **kwargs):
        """Predict the frame between ``img0`` and ``img1`` at time ``embt``.

        Args:
            img0: first frame, a 4-D tensor (unpacked as ``b, _, h, w``).
            img1: second frame, same layout as ``img0``.
            embt: time step passed to the first decoder and used to rescale
                flows for the correlation lookups.
            scale_factor: when not 1.0 the frames are resized by this factor
                before flow estimation, and the final flows, mask and
                residual are resized back by its inverse.
            eval: when truthy, only the predicted frame is returned.
            **kwargs: accepted and ignored.

        Returns:
            A dict with ``'imgt_pred'`` (clamped to [0, 1]). When ``eval``
            is falsy it also holds ``'flow0_pred'`` / ``'flow1_pred'``
            (finest first; the finest entry is reshaped to
            ``(B, num_flows, 2, H, W)``, the others are left at their
            working resolution) and ``'ft_pred'``.
        """
        # Remove the joint per-sample mean of both frames; it is handed to
        # multi_flow_combine together with the centred images.
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape
        coord = coords_grid(b, h // 8, w // 8, img0.device)

        # Feature maps for the correlation volume; the coordinate grid above
        # is built at H//8 x W//8, and the encoder is created with
        # output_dim=84.
        fmap0, fmap1 = self.feat_encoder([img0_, img1_])
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        # Original author's shape notes:
        # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4]
        # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16]
        f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_)
        f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_)

        ######################################### the 4th decoder #########################################
        up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, embt)
        # residue update with lookup corr
        up_flow0_4, up_flow1_4, ft_3_ = self._refine(
            self.update4, corr_fn, coord, ft_3_, up_flow0_4, up_flow1_4, embt, 1)

        ######################################### the 3rd decoder #########################################
        up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4)
        # residue update with lookup corr
        up_flow0_3, up_flow1_3, ft_2_ = self._refine(
            self.update3, corr_fn, coord, ft_2_, up_flow0_3, up_flow1_3, embt, 2)

        ######################################### the 2nd decoder #########################################
        up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3)
        # residue update with lookup corr
        up_flow0_2, up_flow1_2, ft_1_ = self._refine(
            self.update2, corr_fn, coord, ft_1_, up_flow0_2, up_flow1_2, embt, 4)

        ######################################### the 1st decoder #########################################
        up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2)

        if scale_factor != 1.0:
            # Bring the final predictions back to the input resolution; flow
            # magnitudes are rescaled by the same factor as the grid.
            inv_scale = 1.0 / scale_factor
            up_flow0_1 = resize(up_flow0_1, scale_factor=inv_scale) * inv_scale
            up_flow1_1 = resize(up_flow1_1, scale_factor=inv_scale) * inv_scale
            mask = resize(mask, scale_factor=inv_scale)
            img_res = resize(img_res, scale_factor=inv_scale)

        imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1,
                                       mask, img_res, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)

        if eval:
            return {'imgt_pred': imgt_pred, }

        # Use the flows' own spatial size: after the resize above it no
        # longer equals the (h, w) of the rescaled inputs.
        up_flow0_1 = self._split_flows(up_flow0_1, self.num_flows)
        up_flow1_1 = self._split_flows(up_flow1_1, self.num_flows)
        return {
            'imgt_pred': imgt_pred,
            'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4],
            'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4],
            'ft_pred': [ft_1_, ft_2_, ft_3_],
        }
|