|
from pickle import TRUE |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from lib.net.geometry import orthogonal |
|
|
|
|
|
class SelfAttention(torch.nn.Module):
    """3D convolution whose output is scaled elementwise by a sigmoid gate.

    Two parallel ``nn.Conv3d`` layers (kernel 3, padding 1, replicate
    padding) read the same input: ``conv`` produces the features and
    ``attention`` produces the gate logits.  The gate branch has no bias and
    its weights start at zero, so a freshly constructed module multiplies
    every feature by ``sigmoid(0) == 0.5``.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels produced by both branches.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Both branches use the same spatial configuration.
        shared = dict(kernel_size=3, padding=1, padding_mode='replicate')

        # Feature branch is created first so parameter order (and therefore
        # seeded initialisation) stays the same.
        self.conv = nn.Conv3d(in_channels, out_channels, **shared)
        self.attention = nn.Conv3d(in_channels, out_channels, bias=False, **shared)

        # Zero logits -> the gate is uniformly 0.5 until training moves it.
        nn.init.zeros_(self.attention.weight)

    def forward(self, x):
        """Return ``conv(x)`` multiplied elementwise by ``sigmoid(attention(x))``."""
        gate = self.attention(x).sigmoid()
        return self.conv(x) * gate
|
|
|
|
|
class IFGeoNet(nn.Module):
    """3D convolutional encoder plus pointwise decoder evaluated at query points.

    ``forward`` encodes ``batch["depth_voxels"]`` with a stack of 3D
    convolutions, samples the raw volume and every encoder level at the
    query points with ``F.grid_sample``, concatenates those samples with the
    point coordinates, and maps each point to a single value through a stack
    of kernel-size-1 ``nn.Conv1d`` layers.

    Args:
        cfg: accepted for interface compatibility; not read by this class.
        hidden_dim: width of the pointwise decoder (the first decoder layer
            is twice this wide).
    """

    def __init__(self, cfg, hidden_dim=256):
        super().__init__()

        def conv3x3x3(n_in, n_out):
            # Every volumetric convolution in the encoder shares this setup.
            return nn.Conv3d(n_in, n_out, 3, padding=1, padding_mode='replicate')

        # NOTE: modules are registered in the same order as before so that
        # parameter ordering and seeded initialisation are unchanged.
        self.conv_in_partial = conv3x3x3(1, 16)

        # Registered but not referenced by forward() in this class.
        self.SA = SelfAttention(4, 4)
        self.conv_0_fusion = conv3x3x3(16, 32)
        self.conv_0_1_fusion = conv3x3x3(32, 32)

        self.conv_0 = conv3x3x3(32, 32)
        self.conv_0_1 = conv3x3x3(32, 32)

        self.conv_1 = conv3x3x3(32, 64)
        self.conv_1_1 = conv3x3x3(64, 64)

        self.conv_2 = conv3x3x3(64, 128)
        self.conv_2_1 = conv3x3x3(128, 128)
        self.conv_3 = conv3x3x3(128, 128)
        self.conv_3_1 = conv3x3x3(128, 128)
        self.conv_4 = conv3x3x3(128, 128)
        self.conv_4_1 = conv3x3x3(128, 128)

        # Channels of the seven sampled tensors (raw volume + six encoder
        # levels); the extra 3 is the channel count p_features must have for
        # fc_0 to accept the concatenation built in forward().
        sampled_channels = (1, 32, 32, 64, 128, 128, 128)
        feature_size = sum(sampled_channels) + 3
        wide_dim = hidden_dim * 2
        self.fc_0 = nn.Conv1d(feature_size, wide_dim, 1)
        self.fc_1 = nn.Conv1d(wide_dim, hidden_dim, 1)
        self.fc_2 = nn.Conv1d(hidden_dim, hidden_dim, 1)
        self.fc_out = nn.Conv1d(hidden_dim, 1, 1)
        self.actvn = nn.ReLU(True)  # in-place ReLU

        self.maxpool = nn.MaxPool3d(2)

        self.partial_conv_in_bn = nn.InstanceNorm3d(16)
        # Registered but not referenced by forward() in this class.
        self.smpl_conv_in_bn = nn.InstanceNorm3d(4)

        self.conv0_1_bn_fusion = nn.InstanceNorm3d(32)
        self.conv0_1_bn = nn.InstanceNorm3d(32)
        self.conv1_1_bn = nn.InstanceNorm3d(64)
        self.conv2_1_bn = nn.InstanceNorm3d(128)
        self.conv3_1_bn = nn.InstanceNorm3d(128)
        self.conv4_1_bn = nn.InstanceNorm3d(128)

        self.l1_loss = nn.SmoothL1Loss()

    def forward(self, batch):
        """Predict one value per query point.

        Args:
            batch: mapping that provides ``"samples_geo"``, ``"calib"`` and
                ``"depth_voxels"``.  Assumes ``samples_geo`` is (B, 3, N) and
                ``depth_voxels`` is (B, D, H, W) -- TODO confirm with caller.

        Returns:
            The ``fc_out`` output with its channel dimension squeezed out.
        """
        # Project the query points with the project helper `orthogonal`;
        # the permutes flip the last two axes before and after the call.
        projected = orthogonal(batch["samples_geo"].permute(0, 2, 1), batch["calib"])
        p = projected.permute(0, 2, 1)

        # Coordinates as channels for the pointwise decoder.
        p_features = p.transpose(1, -1)
        # Two singleton axes after the batch axis form the sampling grid.
        grid = p[:, None, None]

        # Add a channel axis to the voxel volume.
        x = batch["depth_voxels"].unsqueeze(1)

        def sample(volume):
            return F.grid_sample(volume, grid, padding_mode='border', align_corners=True)

        # Level 0: sample the raw input volume.
        sampled = [sample(x)]

        net = self.actvn(self.conv_in_partial(x))
        net = self.maxpool(self.partial_conv_in_bn(net))

        # (first conv, second conv, norm, downsample afterwards?)
        stages = (
            (self.conv_0_fusion, self.conv_0_1_fusion, self.conv0_1_bn_fusion, False),
            (self.conv_0, self.conv_0_1, self.conv0_1_bn, True),
            (self.conv_1, self.conv_1_1, self.conv1_1_bn, True),
            (self.conv_2, self.conv_2_1, self.conv2_1_bn, True),
            (self.conv_3, self.conv_3_1, self.conv3_1_bn, True),
            (self.conv_4, self.conv_4_1, self.conv4_1_bn, False),
        )
        for conv_a, conv_b, norm, downsample in stages:
            net = self.actvn(conv_a(net))
            net = self.actvn(conv_b(net))
            net = norm(net)
            sampled.append(sample(net))
            if downsample:
                net = self.maxpool(net)

        # Stack all levels along channels, then fold the two sampling axes
        # that precede the point axis into the channel axis.
        features = torch.cat(sampled, dim=1)
        n_batch, n_chan, _, n_rows, n_pts = features.shape
        features = features.reshape(n_batch, n_chan * n_rows, n_pts)

        features = torch.cat((features, p_features), dim=1)

        # Pointwise decoder.
        net = features
        for layer in (self.fc_0, self.fc_1, self.fc_2):
            net = self.actvn(layer(net))

        return self.fc_out(net).squeeze(1)

    def compute_loss(self, prds, tgts):
        """Return the ``nn.SmoothL1Loss`` between predictions and targets."""
        return self.l1_loss(prds, tgts)
|
|