|
import torch |
|
import os, sys |
|
import pickle |
|
import smplx |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
sys.path.append(os.path.dirname(__file__)) |
|
from customloss import (camera_fitting_loss, |
|
body_fitting_loss, |
|
camera_fitting_loss_3d, |
|
body_fitting_loss_3d, |
|
) |
|
from prior import MaxMixturePrior |
|
import config |
|
|
|
|
|
|
|
@torch.no_grad()
def guess_init_3d(model_joints,
                  j3d,
                  joints_category="orig"):
    """Estimate an initial translation as the mean torso-joint offset.

    For the four torso joints (right/left hip, right/left shoulder) the offset
    ``j3d[joint] - model_joints[joint]`` is averaged over the joints.

    :param model_joints: model joint positions, indexed as
        ``model_joints[:, config.JOINT_MAP[name]]`` (presumably (B, J, 3) —
        TODO confirm against callers)
    :param j3d: target joint positions, indexed with the joint map selected by
        ``joints_category`` as ``j3d[:, index]``
    :param joints_category: which joint map indexes ``j3d``; one of
        ``"orig"``, ``"AMASS"`` or ``"MMM"``
    :returns: per-batch translation, the mean of the torso offsets over dim 1
    :raises ValueError: if ``joints_category`` is not a known category
    """
    torso_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']

    # Pick the joint map used by the target skeleton. Validate first so an
    # unknown category fails with a clear error instead of an
    # UnboundLocalError further down.
    if joints_category == "orig":
        target_map = config.JOINT_MAP
    elif joints_category == "AMASS":
        target_map = config.AMASS_JOINT_MAP
    elif joints_category == "MMM":
        target_map = config.MMM_JOINT_MAP
    else:
        raise ValueError(f"NO SUCH JOINTS CATEGORY: {joints_category!r}")

    # The model side is always indexed with the original joint map.
    model_ind = [config.JOINT_MAP[name] for name in torso_joints]
    target_ind = [target_map[name] for name in torso_joints]

    offsets = j3d[:, target_ind] - model_joints[:, model_ind]
    return offsets.sum(dim=1) / len(torso_joints)
|
|
|
|
|
|
|
class SMPLify3D():
    """Implementation of SMPLify, use 3D joints.

    Fitting runs in two stages:

    1. Camera stage: only ``global_orient`` and the translation are optimised,
       with body pose and betas frozen.
    2. Body stage: body pose, ``global_orient`` and the translation are
       optimised. Betas are optimised too, but only when ``seq_ind == 0``.
    """

    # Weights passed to body_fitting_loss_3d for every body-stage loss
    # evaluation (the LBFGS loop, the Adam loop and the reported final loss),
    # so that ``final_loss`` measures the objective that was minimised.
    JOINT_LOSS_WEIGHT = 600.0
    POSE_PRESERVE_WEIGHT = 5.0

    def __init__(self,
                 smplxmodel,
                 step_size=1e-2,
                 batch_size=1,
                 num_iters=100,
                 use_collision=False,
                 use_lbfgs=True,
                 joints_category="orig",
                 device=torch.device('cuda:0'),
                 ):
        """Configure the fitter.

        Input:
            smplxmodel: callable body model; it must expose ``faces_tensor``
            step_size: optimiser learning rate
            batch_size: stored on the instance; not otherwise used here
            num_iters: LBFGS ``max_iter`` per step, and the number of body-stage
                rounds/steps
            use_collision: add the mesh self-penetration term (requires the
                ``mesh_intersection`` package)
            use_lbfgs: use LBFGS with strong-Wolfe line search; otherwise Adam
            joints_category: ``"orig"`` or ``"AMASS"``; selects which model and
                target joint indices are compared
            device: device for the pose prior and face filter
        Raises:
            ValueError: if ``joints_category`` is not supported
        """
        # Fail fast. Previously an unknown category only printed a warning
        # and left the indices as None, which makes ``x[:, None]`` silently
        # add a dimension instead of selecting joints.
        if joints_category not in ("orig", "AMASS"):
            raise ValueError(f"NO SUCH JOINTS CATEGORY: {joints_category!r}")

        self.batch_size = batch_size
        self.device = device
        self.step_size = step_size
        self.num_iters = num_iters
        self.use_lbfgs = use_lbfgs

        self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
                                          num_gaussians=8,
                                          dtype=torch.float32).to(device)

        self.use_collision = use_collision
        # Only read when use_collision is set; None otherwise.
        self.part_segm_fn = config.Part_Seg_DIR if use_collision else None

        self.smpl = smplxmodel
        self.model_faces = smplxmodel.faces_tensor.view(-1)

        self.joints_category = joints_category
        if joints_category == "orig":
            self.smpl_index = config.full_smpl_idx
            self.corr_index = config.full_smpl_idx
        else:  # "AMASS"
            self.smpl_index = config.amass_smpl_idx
            self.corr_index = config.amass_idx

    def _collision_terms(self):
        """Return ``(search_tree, pen_distance, filter_faces)`` for the collision loss.

        All three are None when ``use_collision`` is off. ``filter_faces`` is
        also None when no part-segmentation file is configured.
        ``mesh_intersection`` is imported lazily, so it is only needed when
        collisions are penalised.
        """
        if not self.use_collision:
            return None, None, None

        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        search_tree = BVH(max_collisions=8)
        pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
            sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True)

        filter_faces = None
        if self.part_segm_fn:
            part_segm_fn = os.path.expandvars(self.part_segm_fn)
            # NOTE(security): pickle.load can execute arbitrary code; the path
            # comes from config.Part_Seg_DIR and must point at a trusted file.
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
            filter_faces = FilterFaces(
                faces_segm=face_segm_data['segm'],
                faces_parents=face_segm_data['parents'],
                ign_part_pairs=None).to(device=self.device)

        return search_tree, pen_distance, filter_faces

    def _optimize(self, params, loss_fn, lbfgs_rounds, adam_steps, show_progress=False):
        """Minimise ``loss_fn()`` over ``params`` in place.

        With ``use_lbfgs``, runs ``lbfgs_rounds`` LBFGS steps; each step does
        up to ``num_iters`` inner iterations. Otherwise runs ``adam_steps``
        Adam steps. ``show_progress`` wraps the LBFGS rounds in a tqdm bar.
        """
        if self.use_lbfgs:
            optimizer = torch.optim.LBFGS(params, max_iter=self.num_iters,
                                          lr=self.step_size,
                                          line_search_fn='strong_wolfe')

            def closure():
                # LBFGS re-evaluates the loss several times per step.
                optimizer.zero_grad()
                loss = loss_fn()
                loss.backward()
                return loss

            rounds = range(lbfgs_rounds)
            if show_progress:
                rounds = tqdm(rounds, desc="LBFGS iter: ")
            for _ in rounds:
                optimizer.step(closure)
        else:
            optimizer = torch.optim.Adam(params, lr=self.step_size, betas=(0.9, 0.999))
            for _ in range(adam_steps):
                loss = loss_fn()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    def _body_loss(self, smpl_output, body_pose, preserve_pose, betas,
                   camera_translation, j3d, conf_3d, collision):
        """Evaluate body_fitting_loss_3d for ``smpl_output`` with the class weights.

        ``collision`` is the triple returned by ``_collision_terms``.
        """
        search_tree, pen_distance, filter_faces = collision
        return body_fitting_loss_3d(body_pose, preserve_pose, betas,
                                    smpl_output.joints[:, self.smpl_index], camera_translation,
                                    j3d[:, self.corr_index], self.pose_prior,
                                    joints3d_conf=conf_3d,
                                    joint_loss_weight=self.JOINT_LOSS_WEIGHT,
                                    pose_preserve_weight=self.POSE_PRESERVE_WEIGHT,
                                    use_collision=self.use_collision,
                                    model_vertices=smpl_output.vertices,
                                    model_faces=self.model_faces,
                                    search_tree=search_tree, pen_distance=pen_distance,
                                    filter_faces=filter_faces)

    def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0):
        """Perform body fitting.

        Input:
            init_pose: SMPL pose estimate; columns ``[:, :3]`` are used as
                global_orient and ``[:, 3:]`` as body_pose
            init_betas: SMPL betas estimate
            init_cam_t: Camera translation estimate. NOTE: currently ignored;
                the translation is re-estimated from the torso joints with
                guess_init_3d.
            j3d: joints 3d aka keypoints
            conf_3d: confidence for 3d joints
            seq_ind: index of the sequence; betas are optimised only when 0
        Returns:
            vertices: Vertices of optimized shape (detached)
            joints: 3D joints of optimized shape (detached)
            pose: global_orient and body_pose concatenated on the last dim (detached)
            betas: SMPL beta parameters of optimized shape (detached)
            camera_translation: Camera translation (detached)
            final_loss: body-stage loss at the solution, computed without grad
        """
        collision = self._collision_terms()

        body_pose = init_pose[:, 3:].detach().clone()
        global_orient = init_pose[:, :3].detach().clone()
        betas = init_betas.detach().clone()
        # Anchor pose for the pose-preservation term.
        preserve_pose = init_pose[:, 3:].detach().clone()

        smpl_output = self.smpl(global_orient=global_orient,
                                body_pose=body_pose,
                                betas=betas)
        init_cam_t = guess_init_3d(smpl_output.joints, j3d, self.joints_category).detach()
        camera_translation = init_cam_t.clone()

        # Stage 1: orientation and translation only.
        body_pose.requires_grad = False
        betas.requires_grad = False
        global_orient.requires_grad = True
        camera_translation.requires_grad = True

        def camera_loss():
            # camera_fitting_loss_3d receives full joint arrays (as guess_init_3d
            # does) in both optimiser paths.
            out = self.smpl(global_orient=global_orient,
                            body_pose=body_pose,
                            betas=betas)
            return camera_fitting_loss_3d(out.joints, camera_translation,
                                          init_cam_t, j3d, self.joints_category)

        self._optimize([global_orient, camera_translation], camera_loss,
                       lbfgs_rounds=10, adam_steps=20)

        # Stage 2: full body; betas only for the first sequence element.
        body_pose.requires_grad = True
        global_orient.requires_grad = True
        camera_translation.requires_grad = True
        if seq_ind == 0:
            betas.requires_grad = True
            body_opt_params = [body_pose, betas, global_orient, camera_translation]
        else:
            betas.requires_grad = False
            body_opt_params = [body_pose, global_orient, camera_translation]

        def body_loss():
            out = self.smpl(global_orient=global_orient,
                            body_pose=body_pose,
                            betas=betas)
            return self._body_loss(out, body_pose, preserve_pose, betas,
                                   camera_translation, j3d, conf_3d, collision)

        self._optimize(body_opt_params, body_loss,
                       lbfgs_rounds=self.num_iters, adam_steps=self.num_iters,
                       show_progress=True)

        # Report the same objective that was optimised, without building a graph.
        with torch.no_grad():
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas, return_full_pose=True)
            final_loss = self._body_loss(smpl_output, body_pose, preserve_pose, betas,
                                         camera_translation, j3d, conf_3d, collision)

        vertices = smpl_output.vertices.detach()
        joints = smpl_output.joints.detach()
        pose = torch.cat([global_orient, body_pose], dim=-1).detach()
        betas = betas.detach()

        return vertices, joints, pose, betas, camera_translation.detach(), final_loss