# ZeroShape: model/compute_graph/graph_shape.py
import torch
import torch.nn as nn
from utils.util import EasyDict as edict
from utils.loss import Loss
from model.shape.implicit import Implicit
from model.shape.seen_coord_enc import CoordEncAtt, CoordEncRes
from model.shape.rgb_enc import RGBEncAtt, RGBEncRes
from model.depth.dpt_depth import DPTDepthModel
from utils.util import toggle_grad, interpolate_coordmap, get_child_state_dict
from utils.camera import unproj_depth, valid_norm_fac
from utils.layers import Bottleneck_Conv

class Graph(nn.Module):
def __init__(self, opt):
super().__init__()
# define the intrinsics head
self.intr_feat_channels = 768
self.intr_head = nn.Sequential(
Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
Bottleneck_Conv(self.intr_feat_channels, kernel_size=3),
)
self.intr_pool = nn.AdaptiveAvgPool2d((1, 1))
self.intr_proj = nn.Linear(self.intr_feat_channels, 3)
# init the last linear layer so it outputs zeros
nn.init.zeros_(self.intr_proj.weight)
nn.init.zeros_(self.intr_proj.bias)
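        # zero-initialized outputs mean intr_param2mtx starts from all-zero
        # parameters, i.e. the default intrinsics (unit focal scale, centered
        # principal point)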
        # define the depth prediction model (DPT), following Omnidata
        self.dpt_depth = DPTDepthModel(backbone='vitb_rn50_384')
        # load the pretrained depth model; when intrinsics are to be predicted,
        # the corresponding head is loaded as well
self.load_pretrained_depth(opt)
if opt.optim.fix_dpt:
toggle_grad(self.dpt_depth, False)
toggle_grad(self.intr_head, False)
toggle_grad(self.intr_proj, False)
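        # with fix_dpt set, the depth backbone and intrinsics head are frozen,
        # so only the shape-reconstruction branch below remains trainable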
        # encoder that encodes the seen surface into the implicit conditioning vector
if opt.arch.depth.encoder == 'resnet':
opt.arch.depth.dsp = 1
self.coord_encoder = CoordEncRes(opt)
else:
self.coord_encoder = CoordEncAtt(embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.depth.n_blocks,
num_heads=opt.arch.num_heads, win_size=opt.arch.win_size//opt.arch.depth.dsp)
        # RGB branch (not used in the final model; kept here for extension)
if opt.arch.rgb.encoder == 'resnet':
self.rgb_encoder = RGBEncRes(opt)
elif opt.arch.rgb.encoder == 'transformer':
self.rgb_encoder = RGBEncAtt(img_size=opt.H, embed_dim=opt.arch.latent_dim, n_blocks=opt.arch.rgb.n_blocks,
num_heads=opt.arch.num_heads, win_size=opt.arch.win_size)
else:
self.rgb_encoder = None
# implicit function
feat_res = opt.H // opt.arch.win_size
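        # feat_res**2 is the spatial token count of the conditioning sequence
        # (assuming square inputs, H == W)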
self.impl_network = Implicit(feat_res**2, latent_dim=opt.arch.latent_dim*2 if self.rgb_encoder else opt.arch.latent_dim,
semantic=self.rgb_encoder is not None, n_channels=opt.arch.impl.n_channels,
n_blocks_attn=opt.arch.impl.att_blocks, n_layers_mlp=opt.arch.impl.mlp_layers,
num_heads=opt.arch.num_heads, posenc_3D=opt.arch.impl.posenc_3D,
mlp_ratio=opt.arch.impl.mlp_ratio, skip_in=opt.arch.impl.skip_in,
pos_perlayer=opt.arch.impl.posenc_perlayer)
# loss functions
self.loss_fns = Loss(opt)
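
    # two checkpoint layouts are supported: our own checkpoints (opt.pretrain.depth)
    # store sub-modules under "graph", while the original Omnidata release
    # (opt.arch.depth.pretrained) stores weights under 'model_state_dict'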
def load_pretrained_depth(self, opt):
if opt.pretrain.depth:
            # load from our pretrained depth + intrinsics checkpoint
if opt.device == 0:
print("loading dpt depth from {}...".format(opt.pretrain.depth))
checkpoint = torch.load(opt.pretrain.depth, map_location="cuda:{}".format(opt.device))
self.dpt_depth.load_state_dict(get_child_state_dict(checkpoint["graph"], "dpt_depth"))
# load the intr head
if opt.device == 0:
print("loading pretrained intr from {}...".format(opt.pretrain.depth))
self.intr_head.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_head"))
self.intr_proj.load_state_dict(get_child_state_dict(checkpoint["graph"], "intr_proj"))
elif opt.arch.depth.pretrained:
            # load the original Omnidata weights
if opt.device == 0:
print("loading dpt depth from {}...".format(opt.arch.depth.pretrained))
checkpoint = torch.load(opt.arch.depth.pretrained, map_location="cuda:{}".format(opt.device))
state_dict = checkpoint['model_state_dict']
self.dpt_depth.load_state_dict(state_dict)

    def intr_param2mtx(self, opt, intr_params):
        '''
        Parameters:
            opt: config
            intr_params: [B, 3], [scale_f, delta_cx, delta_cy]
        Returns:
            intr: [B, 3, 3]
        '''
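        # worked example: intr_params = [0, 0, 0] gives tanh(0) = 0, so the
        # focal scale is 4**0 = 1 and both principal-point shifts are 0,
        # yielding the default intrinsics
        #   [[1.3875*W, 0,        W/2],
        #    [0,        1.3875*H, H/2],
        #    [0,        0,        1  ]]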
batch_size = len(intr_params)
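        # default focal length in units of the image side; with a centered
        # principal point this corresponds to a horizontal FoV of
        # 2*atan(0.5/1.3875) ~= 39.6 degrees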
f = 1.3875
        intr = torch.zeros(batch_size, 3, 3, dtype=torch.float32, device=intr_params.device)
intr[:, 2, 2] += 1
# scale the focal length
# range: [-1, 1], symmetric
scale_f = torch.tanh(intr_params[:, 0])
        # range: [1/4, 4], symmetric on a log scale
        scale_f = torch.pow(4., scale_f)
intr[:, 0, 0] += f * opt.W * scale_f
intr[:, 1, 1] += f * opt.H * scale_f
        # shift the optical center (at most to the image border)
shift_cx = torch.tanh(intr_params[:, 1]) * opt.W / 2
shift_cy = torch.tanh(intr_params[:, 2]) * opt.H / 2
intr[:, 0, 2] += opt.W / 2 + shift_cx
intr[:, 1, 2] += opt.H / 2 + shift_cy
return intr

    def forward(self, opt, var, training=False, get_loss=True):
batch_size = len(var.idx)
        # encode the RGB image, [B, 3, H, W] -> [B, 1+(H/ws)*(W/ws), C] (not used in the final model)
var.latent_semantic = self.rgb_encoder(var.rgb_input_map) if self.rgb_encoder else None
# predict the depth map and intrinsics
var.depth_pred, intr_feat = self.dpt_depth(var.rgb_input_map, get_feat=True)
depth_map = var.depth_pred
# predict the intrinsics
intr_feat = self.intr_head(intr_feat)
intr_feat = self.intr_pool(intr_feat).squeeze(-1).squeeze(-1)
intr_params = self.intr_proj(intr_feat)
# [B, 3, 3]
var.intr_pred = self.intr_param2mtx(opt, intr_params)
intr_forward = var.intr_pred
# record the validity mask, [B, H*W]
var.validity_mask = (var.mask_input_map>0.5).float().view(batch_size, -1)
# project the depth to 3D points in view-centric frame
# [B, H*W, 3], in camera coordinates
seen_points_3D_pred = unproj_depth(opt, depth_map, intr_forward)
# [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
seen_points_mean_pred, seen_points_scale_pred = valid_norm_fac(seen_points_3D_pred, var.mask_input_map > 0.5)
# normalize the seen surface, [B, H*W, 3]
var.seen_points = (seen_points_3D_pred - seen_points_mean_pred.unsqueeze(1)) / seen_points_scale_pred.unsqueeze(-1).unsqueeze(-1)
var.seen_points[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
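        # the implicit network thus conditions on a zero-mean, unit-scale partial
        # point cloud; the GT supervision points below are normalized the same
        # way (using GT seen-surface statistics) so both live in one frame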
# [B, 3, H, W]
seen_3D_map = var.seen_points.view(batch_size, opt.H, opt.W, 3).permute(0, 3, 1, 2).contiguous()
seen_3D_dsp, mask_dsp = interpolate_coordmap(seen_3D_map, var.mask_input_map, (opt.H//opt.arch.depth.dsp, opt.W//opt.arch.depth.dsp))
        # encode the seen-surface coordinate map, [B, 3, H/k, W/k] -> [B, 1+(H/ws)*(W/ws), C]
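        # note: the ResNet path expects channel-first coordinate maps plus a mask
        # map, whereas the transformer path expects channel-last coordinates and
        # a boolean validity mask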
if opt.arch.depth.encoder == 'resnet':
var.latent_depth = self.coord_encoder(seen_3D_dsp, mask_dsp)
else:
var.latent_depth = self.coord_encoder(seen_3D_dsp.permute(0, 2, 3, 1).contiguous(), mask_dsp.squeeze(1)>0.5)
var.pose = var.pose_gt
# forward for loss calculation (only during training)
if 'gt_sample_points' in var and 'gt_sample_sdf' in var:
with torch.no_grad():
# get the normalizing fac based on the GT seen surface
# project the GT depth to 3D points in view-centric frame
# [B, H*W, 3], in camera coordinates
seen_points_3D_gt = unproj_depth(opt, var.depth_input_map, var.intr)
# [B, H*W, 3], [B, 1, H, W] (boolean) -> [B, 3], [B]
seen_points_mean_gt, seen_points_scale_gt = valid_norm_fac(seen_points_3D_gt, var.mask_input_map > 0.5)
var.seen_points_gt = (seen_points_3D_gt - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
var.seen_points_gt[(var.mask_input_map<=0.5).view(batch_size, -1)] = 0
# transform the GT points accordingly
# [B, 3, 3]
R_gt = var.pose_gt[:, :, :3]
# [B, 3, 1]
T_gt = var.pose_gt[:, :, 3:]
# [B, 3, N]
gt_sample_points_transposed = var.gt_sample_points.permute(0, 2, 1).contiguous()
# camera coordinates, [B, N, 3]
gt_sample_points_cam = (R_gt @ gt_sample_points_transposed + T_gt).permute(0, 2, 1).contiguous()
# normalize with seen std and mean, [B, N, 3]
var.gt_points_cam = (gt_sample_points_cam - seen_points_mean_gt.unsqueeze(1)) / seen_points_scale_gt.unsqueeze(-1).unsqueeze(-1)
# get near-surface points for visualization
# [B, 100, 3]
close_surf_idx = torch.topk(var.gt_sample_sdf.abs(), k=100, dim=1, largest=False)[1].unsqueeze(-1).repeat(1, 1, 3)
# [B, 100, 3]
var.gt_surf_points = torch.gather(var.gt_points_cam, dim=1, index=close_surf_idx)
            # query the implicit network for the 3D loss: predicted occupancy
            # [B, N] and attention maps [B, N, 1+feat_res**2]
var.pred_sample_occ, attn = self.impl_network(var.latent_depth, var.latent_semantic, var.gt_points_cam)
# calculate the loss if needed
if get_loss:
loss = self.compute_loss(opt, var, training)
return var, loss
return var

    def compute_loss(self, opt, var, training=False):
loss = edict()
if opt.loss_weight.depth is not None:
loss.depth = self.loss_fns.depth_loss(var.depth_pred, var.depth_input_map, var.mask_input_map)
if opt.loss_weight.intr is not None and training:
loss.intr = self.loss_fns.intr_loss(var.seen_points, var.seen_points_gt, var.validity_mask)
if opt.loss_weight.shape is not None and training:
loss.shape = self.loss_fns.shape_loss(var.pred_sample_occ, var.gt_sample_sdf)
return loss
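
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code;
# `opt` must carry every config field accessed above, and `B` is a stand-in
# batch size):
#
#   graph = Graph(opt).cuda()
#   var = edict(idx=torch.arange(B),
#               rgb_input_map=rgb,      # [B, 3, H, W]
#               mask_input_map=mask,    # [B, 1, H, W]
#               depth_input_map=depth,  # [B, 1, H, W]
#               pose_gt=pose)           # [B, 3, 4]
#   var, loss = graph(opt, var, training=True, get_loss=True)
# ---------------------------------------------------------------------------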