# File size: 4,075 Bytes
# 6b1e9f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# Borrowed from the optimization code of https://github.com/wangsen1312/joints2smpl
import os
import argparse
import pickle

import h5py
import natsort
import smplx

import torch

from mld.transforms.joints2rots import config
from mld.transforms.joints2rots.smplify import SMPLify3D

def _str2bool(value):
    """Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(<str>)``, which is True for any
    non-empty string -- so ``--cuda False`` used to leave CUDA enabled. This
    parser accepts the usual spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean string.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (True/False), got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--pkl", type=str, default=None, help="pkl motion file")
parser.add_argument("--dir", type=str, default=None, help="pkl motion folder")
parser.add_argument("--num_smplify_iters", type=int, default=150, help="num of smplify iters")
parser.add_argument("--cuda", type=_str2bool, default=True, help="enables cuda (True/False)")
parser.add_argument("--gpu_ids", type=int, default=0, help="choose gpu ids")
parser.add_argument("--num_joints", type=int, default=22, help="joint number")
parser.add_argument("--joint_category", type=str, default="AMASS", help="use correspondence")
# This option used to be ``type=str, default="False"``. The string "False" is
# truthy, so foot/ankle up-weighting was applied whatever value was passed.
# The default stays True to preserve existing outputs; ``--fix_foot False``
# now actually disables it.
parser.add_argument(
    "--fix_foot",
    type=_str2bool,
    default=True,
    help="up-weight foot/ankle joints (True/False, default True)",
)
opt = parser.parse_args()
print(opt)

# Build the list of motion files to process. A single --pkl file takes
# precedence over --dir. In directory mode every "*.pkl" entry is taken in
# natural-sort order, except "*_mesh.pkl" files written by earlier runs.
if not (opt.pkl or opt.dir):
    raise ValueError(f'{opt.pkl} and {opt.dir} are both None!')

if opt.pkl:
    paths = [opt.pkl]
else:
    paths = [
        os.path.join(opt.dir, name)
        for name in natsort.natsorted(os.listdir(opt.dir))
        if name.endswith('.pkl') and not name.endswith("_mesh.pkl")
    ]

def _mesh_path(pkl_path):
    """Return the output file for *pkl_path*: its extension replaced by ``_mesh.pkl``.

    The previous ``path.replace('.pkl', '_mesh.pkl')`` also rewrote ``.pkl``
    occurring in directory names. For a path with no ``.pkl`` in it, it
    returned the input path unchanged, so the "already rendered" check matched
    the input file itself and the file was silently skipped.
    """
    return os.path.splitext(pkl_path)[0] + '_mesh.pkl'


# The device depends only on the CLI options, so resolve it once.
device = torch.device("cuda:" + str(opt.gpu_ids) if opt.cuda else "cpu")

for path in paths:
    save_file = _mesh_path(path)
    # Outputs from an earlier run are left untouched.
    if os.path.exists(save_file):
        print(f"{path} is rendered! skip!")
        continue

    # load joints
    # NOTE(review): pickle.load can execute arbitrary code stored in the file;
    # only process motion files from a trusted source.
    with open(path, 'rb') as f:
        data = pickle.load(f)

    # assumes joints is (num_frames, num_joints, 3) -- TODO confirm. The code
    # below uses shape[0] as the SMPL batch size and [:, 0, :] as the root.
    joints = data['joints']
    num_frames = joints.shape[0]

    # Per-joint confidence weights for the 3D fitting loss. Only the AMASS
    # layout is handled. Other categories used to fall through with
    # ``confidence_input`` undefined and fail later with a NameError.
    if opt.joint_category != "AMASS":
        raise ValueError(
            f"Unsupported joint_category {opt.joint_category!r}; only 'AMASS' is supported"
        )
    confidence_input = torch.ones(opt.num_joints)
    if opt.fix_foot:
        # make sure the foot and ankle joints are fitted closely
        confidence_input[[7, 8, 10, 11]] = 1.5

    # load predefined SMPL body model, one batch entry per frame
    print(config.SMPL_MODEL_DIR)
    smplxmodel = smplx.create(
        config.SMPL_MODEL_DIR,
        model_type="smpl",
        gender="neutral",
        ext="pkl",
        batch_size=num_frames,
    ).to(device)

    # Start the optimisation from the mean pose/shape stored in
    # SMPL_MEAN_FILE, repeated for every frame. Slicing with [:] copies the
    # datasets into memory, so the file can be closed right away. The old code
    # never closed it and leaked one handle per input.
    with h5py.File(config.SMPL_MEAN_FILE, "r") as mean_file:
        mean_pose = mean_file["pose"][:]
        mean_shape = mean_file["shape"][:]
    init_mean_pose = (
        torch.from_numpy(mean_pose)
        .unsqueeze(0).repeat(num_frames, 1)
        .float()
        .to(device)
    )
    init_mean_shape = (
        torch.from_numpy(mean_shape)
        .unsqueeze(0).repeat(num_frames, 1)
        .float()
        .to(device)
    )
    cam_trans_zero = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(device)

    # initialize SMPLify
    smplify = SMPLify3D(
        smplxmodel=smplxmodel,
        batch_size=num_frames,
        joints_category=opt.joint_category,
        num_iters=opt.num_smplify_iters,
        device=device,
    )
    print("initialize SMPLify3D done!")

    print("Start SMPLify!")
    keypoints_3d = torch.Tensor(joints).to(device).float()

    # ----- from initial to fitting -------
    # Only the optimised pose and betas are used below; the other outputs are
    # unpacked for readability and ignored.
    (
        _opt_vertices,
        _opt_joints,
        new_opt_pose,
        new_opt_betas,
        _opt_cam_t,
        _opt_joint_loss,
    ) = smplify(
        init_mean_pose.detach(),
        init_mean_shape.detach(),
        cam_trans_zero.detach(),
        keypoints_3d,
        conf_3d=confidence_input.to(device),
    )

    # fix shape: the optimised betas are discarded in favour of all-zero betas
    betas = torch.zeros_like(new_opt_betas)
    # the first input joint of every frame is used as the SMPL translation
    root = keypoints_3d[:, 0, :]

    # Inference-only forward pass: no autograd graph is needed for export.
    with torch.no_grad():
        output = smplxmodel(
            betas=betas,
            global_orient=new_opt_pose[:, :3],
            body_pose=new_opt_pose[:, 3:],
            transl=root,
            return_verts=True,
        )
    vertices = output.vertices.detach().cpu().numpy()
    data['vertices'] = vertices

    with open(save_file, 'wb') as f:
        pickle.dump(data, f)
    print(f'vertices saved in {save_file}')