wyysf committed on
Commit 98e2c81
0 Parent(s):

Duplicate from wyysf/GenMM-test
.DS_Store ADDED
Binary file (6.15 kB).
 
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
GPS.py ADDED
@@ -0,0 +1,324 @@
+ import os
+ import os.path as osp
+ import torch
+ import torch.nn.functional as F
+ import numpy as np
+ import itertools
+ from tensorboardX import SummaryWriter
+
+ from NN.losses import make_criteria
+ from utils.base import logger
+
+ class GPS:
+     def __init__(self,
+                  init_mode: str = 'random_synthesis',
+                  noise_sigma: float = 1.0,
+                  coarse_ratio: float = 0.2,
+                  coarse_ratio_factor: float = 6,
+                  pyr_factor: float = 0.75,
+                  num_stages_limit: int = -1,
+                  device: str = 'cuda:0',
+                  silent: bool = False
+                  ):
+         '''
+         Args:
+             init_mode:
+                 - 'random_synthesis': initialize from the interpolated coarse target(s);
+                   an optional suffix sets the output length, e.g. 'random_synthesis/2x'
+                   (a multiple of the total target length) or 'random_synthesis/120'
+                   (an absolute number of frames).
+             noise_sigma: float = 1.0, sigma of the noise added to the initial guess.
+             coarse_ratio: float = 0.2, length ratio at the coarsest level
+                 (-1 = derive it from the patch size and coarse_ratio_factor).
+             pyr_factor: float = 0.75, pyramid downsampling factor.
+             num_stages_limit: int = -1, maximum number of pyramid stages (-1 = no limit).
+             device: str = 'cuda:0', default device.
+             silent: bool = False, mute the output.
+         '''
+         self.init_mode = init_mode
+         self.noise_sigma = noise_sigma
+         self.coarse_ratio = coarse_ratio
+         self.coarse_ratio_factor = coarse_ratio_factor
+         self.pyr_factor = pyr_factor
+         self.num_stages_limit = num_stages_limit
+         self.device = torch.device(device)
+         self.silent = silent
+
+     def _get_pyramid_lengths(self, dest, ext=None):
+         """Get a list of pyramid lengths"""
+         if self.coarse_ratio == -1:
+             self.coarse_ratio = np.around(ext['criteria']['patch_size'] * self.coarse_ratio_factor / dest, 2)
+
+         lengths = [int(np.round(dest * self.coarse_ratio))]
+         while lengths[-1] < dest:
+             lengths.append(int(np.round(lengths[-1] / self.pyr_factor)))
+             if lengths[-1] == lengths[-2]:
+                 lengths[-1] += 1
+         lengths[-1] = dest
+
+         return lengths
+
+     def _get_target_pyramid(self, target, ext=None):
+         """Read the target motion(s) and build a pyramid out of them, ordered by increasing size"""
+         self._num_target = len(target)
+         lengths = []
+         min_len = 10000
+         for i in range(len(target)):
+             new_length = self._get_pyramid_lengths(len(target[i]), ext)
+             min_len = min(min_len, len(new_length))
+             if self.num_stages_limit != -1:
+                 new_length = new_length[:self.num_stages_limit]
+             lengths.append(new_length)
+         for i in range(len(target)):
+             lengths[i] = lengths[i][-min_len:]
+         self.pyramid_lengths = lengths
+
+         target_pyramid = [[] for _ in range(len(lengths[0]))]
+         for step in range(len(lengths[0])):
+             for i in range(len(target)):
+                 length = lengths[i][step]
+                 motion = target[i]
+                 target_pyramid[step].append(motion.sample(size=length).to(self.device))
+
+         if not self.silent:
+             print('Levels:', lengths)
+             for i in range(len(target_pyramid)):
+                 print(f'Number of clips in target pyramid {i} is {len(target_pyramid[i])}: {[[tgt.min(), tgt.max()] for tgt in target_pyramid[i]]}')
+
+         return target_pyramid
+
+     def _get_initial_motion(self):
+         """Prepare the initial motion for optimization"""
+         if 'random_synthesis' in str(self.init_mode):
+             m = self.init_mode.split('/')[-1]
+             if m == 'random_synthesis':
+                 final_length = sum([i[-1] for i in self.pyramid_lengths])
+             elif 'x' in m:
+                 final_length = int(m.replace('x', '')) * sum([i[-1] for i in self.pyramid_lengths])
+             elif (self.init_mode.split('/')[-1]).isdigit():
+                 final_length = int(self.init_mode.split('/')[-1])
+             else:
+                 raise ValueError(f'incorrect init_mode: {self.init_mode}')
+
+             self.synthesized_lengths = self._get_pyramid_lengths(final_length)
+
+         else:
+             raise ValueError(f'Unsupported init_mode {self.init_mode}')
+
+         initial_motion = F.interpolate(torch.cat([self.target_pyramid[0][i] for i in range(self._num_target)], dim=-1),
+                                        size=self.synthesized_lengths[0], mode='linear', align_corners=True)
+         if self.noise_sigma > 0:
+             initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * self.noise_sigma
+             initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
+         else:
+             initial_motion_w_noise = initial_motion
+
+         if not self.silent:
+             print('Synthesized lengths:', self.synthesized_lengths)
+             print('Initial motion:', initial_motion.min(), initial_motion.max())
+             print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())
+
+         return initial_motion_w_noise
+
+     def run(self, target, mode="backpropagate", ext=None, debug_dir=None):
+         '''
+         Run the patch-based motion synthesis.
+
+         Args:
+             target (list): Target motion object(s).
+             mode (str): Optimization mode. Supports ['backpropagate', 'match_and_blend'].
+             ext (dict): Extra data or constraints.
+             debug_dir (str): Debug directory.
+         '''
+         # prepare data
+         self.target_pyramid = self._get_target_pyramid(target, ext)
+         self.synthesized = self._get_initial_motion()
+         if debug_dir is not None:
+             writer = SummaryWriter(log_dir=debug_dir)
+
+         # prepare configuration
+         if mode == "backpropagate":
+             self.synthesized.requires_grad_(True)
+             assert 'criteria' in ext.keys(), 'Please specify a criteria for synthesis.'
+             criteria = make_criteria(ext['criteria']).to(self.device)
+         elif mode == "match_and_blend":
+             self.synthesized.requires_grad_(False)
+             assert 'criteria' in ext.keys(), 'Please specify a criteria for synthesis.'
+             criteria = make_criteria(ext['criteria']).to(self.device)
+         else:
+             raise ValueError(f'Unsupported mode: {mode}')
+
+         # perform synthesis
+         self.pbar = logger(ext['num_itrs'], len(self.target_pyramid))
+         ext['pbar'] = self.pbar
+         for lvl, lvl_target in enumerate(self.target_pyramid):
+             self.pbar.new_lvl()
+             if lvl > 0:
+                 with torch.no_grad():
+                     self.synthesized = F.interpolate(self.synthesized.detach(), size=self.synthesized_lengths[lvl], mode='linear')
+                 if mode == "backpropagate":
+                     self.synthesized.requires_grad_(True)
+
+             if mode == "backpropagate":  # directly optimize the synthesized motion
+                 self.synthesized, losses = GPS.backpropagate(self.synthesized, lvl_target, criteria, ext=ext)
+             elif mode == "match_and_blend":
+                 self.synthesized, losses = GPS.match_and_blend(self.synthesized, lvl_target, criteria, ext=ext)
+
+             criteria.clean_cache()
+             if debug_dir:
+                 for itr in range(len(losses)):
+                     writer.add_scalar(f'optimize/losses_lvl{lvl}', losses[itr], itr)
+         self.pbar.pbar.close()
+
+         return self.synthesized.detach()
+
+     @staticmethod
+     def backpropagate(synthesized, targets, criteria=None, ext=None):
+         """
+         Minimizes criteria(synthesized, targets) for ext['num_itrs'] gradient steps.
+         Args:
+             targets (torch.Tensor): Target data.
+             ext (dict): Extra configurations.
+         """
+         if criteria is None:
+             assert 'criteria' in ext.keys(), 'Criteria is not set'
+             criteria = make_criteria(ext['criteria']).to(synthesized.device)
+
+         optim = None
+         if 'optimizer' in ext.keys():
+             if ext['optimizer'] == 'Adam':
+                 optim = torch.optim.Adam([synthesized], lr=ext['lr'])
+             elif ext['optimizer'] == 'SGD':
+                 optim = torch.optim.SGD([synthesized], lr=ext['lr'])
+             elif ext['optimizer'] == 'RMSprop':
+                 optim = torch.optim.RMSprop([synthesized], lr=ext['lr'])
+             else:
+                 print('use default RMSprop optimizer')
+         optim = torch.optim.RMSprop([synthesized], lr=ext['lr']) if optim is None else optim
+         lr_decay = np.exp(np.log(0.333) / ext['num_itrs'])  # exponential lr decay factor (currently unused)
+
+         # other constraints
+         trajectory = ext['trajectory'] if 'trajectory' in ext.keys() else None
+
+         losses = []
+         for _i in range(ext['num_itrs']):
+             optim.zero_grad()
+
+             loss = criteria(synthesized, targets)
+
+             if trajectory is not None:  # velocity constraint
+                 target_traj = F.interpolate(trajectory, size=synthesized.shape[-1], mode='linear')
+                 target_velo = ext['pos2velo'](target_traj)
+
+                 velo_mask = [-3, -1]
+                 loss += 1 * F.l1_loss(synthesized[:, velo_mask, :], target_velo[:, velo_mask, :])
+
+             loss.backward()
+             optim.step()
+
+             # update status
+             losses.append(loss.item())
+             if 'pbar' in ext.keys():
+                 ext['pbar'].step()
+                 ext['pbar'].print()
+
+         return synthesized, losses
+
+     @staticmethod
+     @torch.no_grad()
+     def match_and_blend(synthesized, targets, criteria, ext):
+         """
+         Minimizes criteria(synthesized, targets) by patch matching and blending.
+         Args:
+             targets (torch.Tensor): Target data.
+             ext (dict): Extra configurations.
+         """
+         losses = []
+         for _i in range(ext['num_itrs']):
+             if 'parts_list' in ext.keys():
+                 def extract_part_motions(motion, parts_list):
+                     part_motions = []
+                     n_frames = motion.shape[-1]
+                     rot, pos = motion[:, :-3, :].reshape(-1, 6, n_frames), motion[:, -3:, :]
+
+                     for part in parts_list:
+                         part = [i - 1 for i in part]  # 1-indexed joints -> 0-indexed
+
+                         if 0 in part:  # the root part also carries the global position
+                             part_motions += [torch.cat([rot[part].view(1, -1, n_frames), pos.view(1, -1, n_frames)], dim=1)]
+                         else:
+                             part_motions += [rot[part].view(1, -1, n_frames)]
+
+                     return part_motions
+
+                 def combine_part_motions(part_motions, parts_list):
+                     assert len(part_motions) == len(parts_list)
+                     n_frames = part_motions[0].shape[-1]
+                     l = max(list(itertools.chain(*parts_list)))
+                     rot = torch.zeros(((l + 1), 6, n_frames), device=part_motions[0].device)
+                     pos = torch.zeros((1, 3, n_frames), device=part_motions[0].device)
+                     div_rot = torch.zeros((l + 1), device=part_motions[0].device)
+                     div_pos = torch.zeros(1, device=part_motions[0].device)
+
+                     for part_motion, part in zip(part_motions, parts_list):
+                         part = [i - 1 for i in part]
+
+                         if 0 in part:
+                             pos += part_motion[:, -3:, :]
+                             div_pos += 1
+                             rot[part] += part_motion[:, :-3, :].view(-1, 6, n_frames)
+                             div_rot[part] += 1
+                         else:
+                             rot[part] += part_motion.view(-1, 6, n_frames)
+                             div_rot[part] += 1
+
+                     # average joints shared by several parts
+                     rot = (rot.permute(1, 2, 0) / div_rot).permute(2, 0, 1)
+                     pos = pos / div_pos
+
+                     return torch.cat([rot.view(1, -1, n_frames), pos.view(1, 3, n_frames)], dim=1)
+
+                 synthesized_part_motions = extract_part_motions(synthesized, ext['parts_list'])
+                 targets_part_motions = [extract_part_motions(target, ext['parts_list']) for target in targets]
+
+                 synthesized = []
+                 part_losses = []
+                 for _j in range(len(synthesized_part_motions)):
+                     synthesized_part_motion = synthesized_part_motions[_j]
+                     targets_part_motion = [target[_j] for target in targets_part_motions]
+                     blended, part_loss = criteria(synthesized_part_motion, targets_part_motion, ext=ext, return_blended_results=True)
+                     synthesized += [blended]
+                     part_losses += [part_loss]
+
+                 synthesized = combine_part_motions(synthesized, ext['parts_list'])
+                 loss = torch.stack(part_losses).mean()
+
+             else:
+                 synthesized, loss = criteria(synthesized, targets, ext=ext, return_blended_results=True)
+
+             # update status
+             losses.append(loss.item())
+             if 'pbar' in ext.keys():
+                 ext['pbar'].step()
+                 ext['pbar'].print()
+
+         return synthesized, losses
+
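
The level schedule is central to the coarse-to-fine loop above, so here is a small worked example (not part of the commit): with coarse_ratio=0.2 and pyr_factor=0.75, a 100-frame target yields the lengths below; each level is roughly the previous one divided by pyr_factor, and the last level is clamped to the target length.

    # Worked example for GPS._get_pyramid_lengths (runs on CPU):
    from GPS import GPS

    gps = GPS(coarse_ratio=0.2, pyr_factor=0.75, device='cpu', silent=True)
    print(gps._get_pyramid_lengths(100))   # [20, 27, 36, 48, 64, 85, 100]
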
NN/losses.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ import torch.nn as nn
+
+ from .utils import extract_patches, combine_patches, efficient_cdist, get_NNs_Dists
+
+ def make_criteria(conf):
+     if conf['type'] == 'PatchCoherentLoss':
+         return PatchCoherentLoss(conf['patch_size'], stride=conf['stride'], loop=conf['loop'], coherent_alpha=conf['coherent_alpha'])
+     elif conf['type'] == 'SWDLoss':
+         raise NotImplementedError('SWDLoss is not implemented')
+     else:
+         raise ValueError('Invalid criteria: {}'.format(conf['type']))
+
+ class PatchCoherentLoss(torch.nn.Module):
+     def __init__(self, patch_size=7, stride=1, loop=False, coherent_alpha=None, cache=False):
+         super(PatchCoherentLoss, self).__init__()
+         self.patch_size = patch_size
+         self.stride = stride
+         self.loop = loop
+         self.coherent_alpha = coherent_alpha
+         assert self.stride == 1, "Only support stride of 1"
+         # assert self.patch_size % 2 == 1, "Only support odd patch size"
+         self.cache = cache
+         if cache:
+             self.cached_data = None
+
+     def forward(self, X, Ys, dist_wrapper=None, ext=None, return_blended_results=False):
+         """For each patch in the input X, find its nearest neighbor in the targets Ys and sum their distances"""
+         assert X.shape[0] == 1, "Only support batch size of 1"
+         dist_fn = lambda X, Y: dist_wrapper(efficient_cdist, X, Y) if dist_wrapper is not None else efficient_cdist(X, Y)
+
+         x_patches = extract_patches(X, self.patch_size, self.stride, loop=self.loop)
+
+         if not self.cache or self.cached_data is None:
+             y_patches = []
+             for y in Ys:
+                 y_patches += [extract_patches(y, self.patch_size, self.stride, loop=False)]
+             y_patches = torch.cat(y_patches, dim=1)
+             self.cached_data = y_patches
+         else:
+             y_patches = self.cached_data
+
+         nnf, dist = get_NNs_Dists(dist_fn, x_patches.squeeze(0), y_patches.squeeze(0), self.coherent_alpha)
+
+         if return_blended_results:
+             return combine_patches(X.shape, y_patches[:, nnf, :], self.patch_size, self.stride, loop=self.loop), dist.mean()
+         else:
+             return dist.mean()
+
+     def clean_cache(self):
+         self.cached_data = None
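
As a quick orientation, the sketch below scores a candidate sequence against two target clips with the loss defined above; the tensors are dummy data, and with coherent_alpha=None the distances are plain dimension-normalized squared L2 patch distances.

    import torch
    from NN.losses import make_criteria

    criteria = make_criteria({'type': 'PatchCoherentLoss', 'patch_size': 7,
                              'stride': 1, 'loop': False, 'coherent_alpha': None})
    X = torch.randn(1, 6, 48)                     # candidate: (batch, channels, frames)
    Ys = [torch.randn(1, 6, 96), torch.randn(1, 6, 64)]
    print(criteria(X, Ys))                        # scalar: mean NN patch distance
    blended, dist = criteria(X, Ys, return_blended_results=True)
    print(blended.shape)                          # torch.Size([1, 6, 48])
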
NN/utils.py ADDED
@@ -0,0 +1,103 @@
+ import torch
+ import torch.nn.functional as F
+ import unfoldNd
+
+ def extract_patches(x, patch_size, stride, loop=False):
+     """Extract patches from a motion sequence"""
+     b, c, _t = x.shape
+
+     # manually pad to loop
+     if loop:
+         half = patch_size // 2
+         front, tail = x[:, :, :half], x[:, :, -half:]
+         x = torch.concat([tail, x, front], dim=-1)
+
+     x_patches = unfoldNd.unfoldNd(x, kernel_size=patch_size, stride=stride).transpose(1, 2).reshape(b, -1, c, patch_size)
+
+     return x_patches.view(b, -1, c * patch_size)
+
+ def combine_patches(x_shape, ys, patch_size, stride, loop=False):
+     """Combine motion patches"""
+     # manually handle the loop padding
+     out_shape = [*x_shape]
+     if loop:
+         padding = patch_size // 2
+         out_shape[-1] = out_shape[-1] + padding * 2
+
+     combined = unfoldNd.foldNd(ys.permute(0, 2, 1), output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
+
+     # normalize by the fold overlap count
+     input_ones = torch.ones(tuple(out_shape), dtype=ys.dtype, device=ys.device)
+     divisor = unfoldNd.unfoldNd(input_ones, kernel_size=patch_size, stride=stride)
+     divisor = unfoldNd.foldNd(divisor, output_size=tuple(out_shape[-1:]), kernel_size=patch_size, stride=stride)
+     combined = (combined / divisor).squeeze(dim=0).unsqueeze(0)
+
+     if loop:
+         half = patch_size // 2
+         front, tail = combined[:, :, :half], combined[:, :, -half:]
+         combined[:, :, half:2 * half] = (combined[:, :, half:2 * half] + tail) / 2
+         combined[:, :, -2 * half:-half] = (front + combined[:, :, -2 * half:-half]) / 2
+         combined = combined[:, :, half:-half]
+
+     return combined
+
+
+ def efficient_cdist(X, Y):
+     """
+     PyTorch-efficient computation of the squared distances between all vectors in X and Y, i.e. ((X[:, None] - Y[None, :])**2).sum(-1)
+     :param X: (n1, d) tensor
+     :param Y: (n2, d) tensor
+     Returns an (n1, n2) distance matrix
+     """
+     dist = (X * X).sum(1)[:, None] + (Y * Y).sum(1)[None, :] - 2.0 * torch.mm(X, torch.transpose(Y, 0, 1))
+     d = X.shape[1]
+     dist /= d  # normalize by the vector dimension to make dists independent of d (use the same alpha for all patch sizes)
+     return dist  # DO NOT use torch.sqrt
+
+
+ def get_col_mins_efficient(dist_fn, X, Y, b=1024):
+     """
+     Computes the squared l2 distance to the closest x for each y.
+     :param X: (n1, d) tensor
+     :param Y: (n2, d) tensor
+     Returns an n2-long array of distances
+     """
+     n_batches = len(Y) // b
+     mins = torch.zeros(Y.shape[0], dtype=X.dtype, device=X.device)
+     for i in range(n_batches):
+         mins[i * b:(i + 1) * b] = dist_fn(X, Y[i * b:(i + 1) * b]).min(0)[0]
+     if len(Y) % b != 0:
+         mins[n_batches * b:] = dist_fn(X, Y[n_batches * b:]).min(0)[0]
+
+     return mins
+
+
+ def get_NNs_Dists(dist_fn, X, Y, alpha=None, b=1024):
+     """
+     Get the nearest-neighbor index in Y for each X.
+     Avoids holding an (n1 * n2) matrix in order to reduce the memory footprint to (b * max(n1, n2)).
+     :param X: (n1, d) tensor
+     :param Y: (n2, d) tensor
+     Returns n1-long arrays of indices and distances
+     """
+     if alpha is not None:
+         normalizing_row = get_col_mins_efficient(dist_fn, X, Y, b=b)
+         normalizing_row = alpha + normalizing_row[None, :]
+     else:
+         normalizing_row = 1
+
+     NNs = torch.zeros(X.shape[0], dtype=torch.long, device=X.device)
+     Dists = torch.zeros(X.shape[0], dtype=torch.float, device=X.device)
+
+     n_batches = len(X) // b
+     for i in range(n_batches):
+         dists = dist_fn(X[i * b:(i + 1) * b], Y) / normalizing_row
+         NNs[i * b:(i + 1) * b] = dists.min(1)[1]
+         Dists[i * b:(i + 1) * b] = dists.min(1)[0]
+     if len(X) % b != 0:
+         dists = dist_fn(X[n_batches * b:], Y) / normalizing_row
+         NNs[n_batches * b:] = dists.min(1)[1]
+         Dists[n_batches * b:] = dists.min(1)[0]
+
+     return NNs, Dists
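
A useful sanity check on the two patch helpers above (a sketch, assuming unfoldNd is installed and the repo root is on PYTHONPATH): with stride 1 and unmodified patches, the overlap-normalized fold in combine_patches inverts extract_patches.

    import torch
    from NN.utils import extract_patches, combine_patches

    x = torch.randn(1, 6, 32)                               # (batch, channels, frames)
    patches = extract_patches(x, patch_size=7, stride=1)    # (1, 26, 42)
    recon = combine_patches(x.shape, patches, patch_size=7, stride=1)
    print(torch.allclose(recon, x, atol=1e-5))              # True
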
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: GenMM
+ emoji: 🌍
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.33.1
+ app_file: app.py
+ pinned: false
+ duplicated_from: wyysf/GenMM-test
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,54 @@
+ import json
+ import time
+
+ from dataset.tracks_motion import TracksMotion
+ from GPS import GPS
+ import gradio as gr
+
+ def _synthesis(synthesis_setting, motion_data):
+     model = GPS(
+         init_mode = f"random_synthesis/{synthesis_setting['frames']}",
+         noise_sigma = synthesis_setting['noise_sigma'],
+         coarse_ratio = 0.2,
+         pyr_factor = synthesis_setting['pyr_factor'],
+         num_stages_limit = -1,
+         silent=True,
+         device='cpu'
+     )
+
+     synthesized_motion = model.run(
+         motion_data,
+         mode="match_and_blend",
+         ext={
+             'criteria': {
+                 'type': 'PatchCoherentLoss',
+                 'patch_size': synthesis_setting['patch_size'],
+                 'stride': synthesis_setting['stride'] if 'stride' in synthesis_setting.keys() else 1,
+                 'loop': synthesis_setting['loop'],
+                 'coherent_alpha': synthesis_setting['alpha'] if synthesis_setting['completeness'] else None,
+             },
+             'optimizer': "match_and_blend",
+             'num_itrs': synthesis_setting['num_steps'],
+         }
+     )
+
+     return synthesized_motion
+
+ def synthesis(data):
+     data = json.loads(data)
+     # create the track object
+     data['setting']['coarse_ratio'] = -1
+     motion_data = TracksMotion(data['tracks'], scale=data['scale'])
+     start = time.time()
+     synthesized_motion = _synthesis(
+         data['setting'],
+         [motion_data]
+     )
+     end = time.time()
+     data['time'] = end - start
+     data['tracks'] = motion_data.parse(synthesized_motion)
+
+     return data
+
+ demo = gr.Interface(fn=synthesis, inputs="json", outputs="json")
+ demo.launch()
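
The actual track payload is produced by the web front-end, which is not part of this commit; the hypothetical request below is inferred from TracksParser and _synthesis: track 0 is the root position ('vector'), the remaining tracks are per-joint 'quaternion' rotations, and 'frames' may be an absolute length or a multiplier such as '2x'.

    import json

    payload = json.dumps({
        'scale': 1.0,
        'setting': {'frames': '2x', 'noise_sigma': 0.0, 'pyr_factor': 0.75,
                    'patch_size': 11, 'loop': False,
                    'completeness': True, 'alpha': 0.01, 'num_steps': 3},
        'tracks': [
            {'name': 'Hips.position', 'type': 'vector',
             'times': [],    # per-frame timestamps (elided)
             'values': []},  # flattened xyz floats (elided)
            {'name': 'Hips.quaternion', 'type': 'quaternion',
             'times': [],    # per-frame timestamps (elided)
             'values': []},  # flattened quaternion floats (elided)
        ],
    })
    # result = synthesis(payload)
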
configs/random_synthesis.yaml ADDED
@@ -0,0 +1,36 @@
+ output_dir: './output/random_synthesis'
+
+ # for GANimator BVH data
+ skeleton_aware: true
+ use_velo: true
+ repr: 'repr6d'
+ contact: true
+ keep_y_pos: true
+ joint_reduction: true
+
+
+ # for synthesis
+ coarse_ratio: -1
+ coarse_ratio_factor: 10
+ pyr_factor: 0.75
+ num_stages_limit: -1
+ noise_sigma: 10.0
+ patch_size: 11
+ loop: false
+ loss_type: 'PatchCoherent'
+ coherent_alpha: 0.01
+ optimizer: 'RMSprop'
+ lr: 0.01
+ num_steps: 3
+ decay_rate: 0.9
+ decay_steps: 0.9
+
+ # for visualization (only for blender render)
+ visualization: true
+ fbx_path: null
+ reso: '[1920, 1080]'
+ samples: 64
+ fps: 30
+ frame_end: -1
+ camera_pos: '[0, -8, 2.5]'
+ target_pos: '[0, 2, 0.5]'
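
A config like the one above is typically consumed through ConfigParser from utils/base.py (defined later in this commit); the sketch below assumes it is run from the repo root, and shows how YAML keys are merged into the argparse namespace, with non-None CLI values taking precedence.

    import argparse
    from utils.base import ConfigParser

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='configs/random_synthesis.yaml')
    parser.add_argument('--pyr_factor', type=float, default=None)  # CLI overrides YAML
    cfg = ConfigParser(parser)
    print(cfg['pyr_factor'], cfg.noise_sigma)   # 0.75 10.0
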
dataset/.DS_Store ADDED
Binary file (6.15 kB).
 
dataset/tracks_motion.py ADDED
@@ -0,0 +1,183 @@
+ import os
+ from os.path import join as pjoin
+ import numpy as np
+ import copy
+ import torch
+ import torch.nn.functional as F
+ from utils.transforms import quat2repr6d, quat2euler, repr6d2quat
+
+ class TracksParser():
+     def __init__(self, tracks_json, scale=1.0, requires_contact=False, joint_reduction=False):
+         assert requires_contact == False, 'contact is not implemented for tracks data yet!!!'
+
+         self.tracks_json = tracks_json
+         self.scale = scale
+         self.requires_contact = requires_contact
+         self.joint_reduction = joint_reduction
+
+         self.skeleton_names = []
+         self.rotations = []
+         for i, track in enumerate(self.tracks_json):
+             self.skeleton_names.append(track['name'])
+             if i == 0:
+                 assert track['type'] == 'vector'
+                 self.position = np.array(track['values']).reshape(-1, 3) * self.scale
+                 self.num_frames = self.position.shape[0]
+             else:
+                 assert track['type'] == 'quaternion'  # DEFAULT: quaternion
+                 rotation = np.array(track['values']).reshape(-1, 4)
+                 if rotation.shape[0] == 0:
+                     rotation = np.zeros((self.num_frames, 4))
+                 elif rotation.shape[0] < self.num_frames:
+                     rotation = np.repeat(rotation, self.num_frames // rotation.shape[0], axis=0)
+                 elif rotation.shape[0] > self.num_frames:
+                     rotation = rotation[:self.num_frames]
+                 self.rotations += [rotation]
+         self.rotations = np.array(self.rotations, dtype=np.float32)
+
+     def to_tensor(self, repr='euler', rot_only=False):
+         if repr not in ['euler', 'quat', 'quaternion', 'repr6d']:
+             raise Exception('Unknown rotation representation')
+         rotations = self.get_rotation(repr=repr)
+         positions = self.get_position()
+
+         if rot_only:
+             return rotations.reshape(rotations.shape[0], -1)
+
+         if self.requires_contact:
+             virtual_contact = torch.zeros_like(rotations[:, :len(self.skeleton.contact_id)])
+             virtual_contact[..., 0] = self.contact_label
+             rotations = torch.cat([rotations, virtual_contact], dim=1)
+
+         rotations = rotations.reshape(rotations.shape[0], -1)
+         return torch.cat((rotations, positions), dim=-1)
+
+     def get_rotation(self, repr='quat'):
+         rotations = torch.tensor(self.rotations, dtype=torch.float).transpose(0, 1)
+         if repr == 'repr6d':
+             rotations = quat2repr6d(rotations)
+         elif repr == 'euler':
+             rotations = quat2euler(rotations)
+         return rotations
+
+     def get_position(self):
+         return torch.tensor(self.position, dtype=torch.float32)
+
+ class TracksMotion:
+     def __init__(self, tracks_json, scale=1.0, repr='repr6d', padding=False,
+                  use_velo=True, contact=False, keep_y_pos=True, joint_reduction=False):
+         self.scale = scale
+         self.tracks = TracksParser(tracks_json, scale, requires_contact=contact, joint_reduction=joint_reduction)
+         self.raw_motion = self.tracks.to_tensor(repr=repr)
+         self.extra = {}
+
+         self.repr = repr
+         if repr == 'quat':
+             self.n_rot = 4
+         elif repr == 'repr6d':
+             self.n_rot = 6
+         elif repr == 'euler':
+             self.n_rot = 3
+         self.padding = padding
+         self.use_velo = use_velo
+         self.contact = contact
+         self.keep_y_pos = keep_y_pos
+         self.joint_reduction = joint_reduction
+
+         self.raw_motion = self.raw_motion.permute(1, 0).unsqueeze_(0)  # Shape = (1, n_channel, n_frames)
+         self.extra['global_pos'] = self.raw_motion[:, -3:, :]
+
+         if padding:
+             self.n_pad = self.n_rot - 3  # pad position channels
+             paddings = torch.zeros_like(self.raw_motion[:, :self.n_pad])
+             self.raw_motion = torch.cat((self.raw_motion, paddings), dim=1)
+         else:
+             self.n_pad = 0
+         self.raw_motion = torch.cat((self.raw_motion[:, :-3 - self.n_pad], self.raw_motion[:, -3 - self.n_pad:]), dim=1)
+
+         if self.use_velo:
+             self.msk = [-3, -2, -1] if not keep_y_pos else [-3, -1]
+             self.raw_motion = self.pos2velo(self.raw_motion)
+
+         self.n_contact = len(self.tracks.skeleton.contact_id) if contact else 0
+
+     @property
+     def n_channels(self):
+         return self.raw_motion.shape[1]
+
+     def __len__(self):
+         return self.raw_motion.shape[-1]
+
+     def pos2velo(self, pos):
+         msk = [i - self.n_pad for i in self.msk]
+         velo = pos.detach().clone().to(pos.device)
+         velo[:, msk, 1:] = pos[:, msk, 1:] - pos[:, msk, :-1]
+         self.begin_pos = pos[:, msk, 0].clone()
+         velo[:, msk, 0] = pos[:, msk, 1]
+         return velo
+
+     def velo2pos(self, velo):
+         msk = [i - self.n_pad for i in self.msk]
+         pos = velo.detach().clone().to(velo.device)
+         pos[:, msk, 0] = self.begin_pos.to(velo.device)
+         pos[:, msk] = torch.cumsum(velo[:, msk], dim=-1)
+         return pos
+
+     def motion2pos(self, motion):
+         if not self.use_velo:
+             return motion
+         else:
+             return self.velo2pos(motion.clone())
+
+     def sample(self, size=None, slerp=False, align_corners=False):
+         if size is None:
+             return {'motion': self.raw_motion, 'extra': self.extra}
+         else:
+             if slerp:
+                 raise NotImplementedError('slerp is not implemented yet!!!')
+             else:
+                 motion = F.interpolate(self.raw_motion, size=size, mode='linear', align_corners=align_corners)
+                 extra = {}
+                 if 'global_pos' in self.extra.keys():
+                     extra['global_pos'] = F.interpolate(self.extra['global_pos'], size=size, mode='linear', align_corners=align_corners)
+
+             return motion
+             # return {'motion': motion, 'extra': extra}
+
+     def parse(self, motion, keep_velo=False):
+         """
+         No batch support here!!!
+         :returns tracks_json
+         """
+         motion = motion.clone()
+
+         if self.use_velo and not keep_velo:
+             motion = self.velo2pos(motion)
+         if self.n_pad:
+             motion = motion[:, :-self.n_pad]
+         if self.contact:
+             raise NotImplementedError('contact is not implemented yet!!!')
+
+         motion = motion.squeeze().permute(1, 0)
+         pos = motion[..., -3:] / self.scale
+         rot = motion[..., :-3].reshape(motion.shape[0], -1, self.n_rot)
+         if self.repr == 'repr6d':
+             rot = repr6d2quat(rot)
+         elif self.repr == 'euler':
+             raise NotImplementedError('parse "euler" is not implemented yet!!!')
+
+         times = []
+         out_tracks_json = copy.deepcopy(self.tracks.tracks_json)
+         for i, _track in enumerate(out_tracks_json):
+             if i == 0:
+                 times = [j * out_tracks_json[i]['times'][1] for j in range(motion.shape[0])]
+                 out_tracks_json[i]['values'] = pos.flatten().detach().cpu().numpy().tolist()
+             else:
+                 out_tracks_json[i]['values'] = rot[:, i - 1, :].flatten().detach().cpu().numpy().tolist()
+             out_tracks_json[i]['times'] = times
+
+         return out_tracks_json
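
Putting the pieces together, here is a minimal end-to-end sketch (hypothetical data, not part of the commit): build a TracksMotion from a constant 64-frame two-track clip, resynthesize it with GPS in match_and_blend mode, and parse the result back into track JSON.

    from dataset.tracks_motion import TracksMotion
    from GPS import GPS

    n = 64
    tracks = [
        {'name': 'Hips.position', 'type': 'vector',
         'times': [i / 30 for i in range(n)], 'values': [0.0, 1.0, 0.0] * n},
        {'name': 'Hips.quaternion', 'type': 'quaternion',
         'times': [i / 30 for i in range(n)], 'values': [1.0, 0.0, 0.0, 0.0] * n},
    ]
    motion = TracksMotion(tracks, scale=1.0)   # raw_motion: (1, 9, 64)
    model = GPS(init_mode='random_synthesis', noise_sigma=0.0,
                coarse_ratio=0.2, pyr_factor=0.75, device='cpu', silent=True)
    out = model.run([motion], mode='match_and_blend',
                    ext={'criteria': {'type': 'PatchCoherentLoss', 'patch_size': 7,
                                      'stride': 1, 'loop': False, 'coherent_alpha': 0.01},
                         'num_itrs': 2})
    print(out.shape)               # torch.Size([1, 9, 64])
    print(len(motion.parse(out)))  # 2 tracks, back in JSON form
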
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ filterpy==1.4.5
+ torchvision==0.12.0
+ tensorboardX==2.5
+ protobuf==3.20.1
+ scipy==1.7.3
+ tqdm==4.62.3
+ unfoldNd
+ flask==2.1.3
+ flask-cors==3.0.10
+ pyyaml>=5.3.1
+ requests
+ tensorboard
+ transforms3d
+ imageio
+ imageio-ffmpeg
utils/.DS_Store ADDED
Binary file (6.15 kB).
 
utils/base.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import os.path as osp
+ import sys
+ import time
+ import yaml
+ import imageio
+ import random
+ import shutil
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+
+ class ConfigParser():
+     def __init__(self, args):
+         """
+         Class to parse the configuration.
+         """
+         args = args.parse_args()
+         self.cfg = self.merge_config_file(args)
+
+         # set random seed
+         self.set_seed()
+
+     def __str__(self):
+         return str(self.cfg.__dict__)
+
+     def __getattr__(self, name):
+         """
+         Access items using dot notation.
+         """
+         return self.cfg.__dict__[name]
+
+     def __getitem__(self, name):
+         """
+         Access items like an ordinary dict.
+         """
+         return self.cfg.__dict__[name]
+
+     def merge_config_file(self, args, allow_invalid=True):
+         """
+         Load the YAML config file and merge it with the command-line arguments.
+         """
+         assert args.config is not None
+         with open(args.config, 'r') as f:
+             cfg = yaml.safe_load(f)
+         if 'config' in cfg.keys():
+             del cfg['config']
+         invalid_args = list(set(cfg.keys()) - set(dir(args)))
+         if invalid_args and not allow_invalid:
+             raise ValueError(f"Invalid args {invalid_args} in {args.config}.")
+
+         for k in list(cfg.keys()):
+             if k in args.__dict__.keys() and args.__dict__[k] is not None:
+                 print('=========> overwrite config: {} = {}'.format(k, args.__dict__[k]))
+                 del cfg[k]
+
+         args.__dict__.update(cfg)
+
+         return args
+
+     def set_seed(self):
+         ''' Set the random seed for random, numpy and torch. '''
+         if 'seed' not in self.cfg.__dict__.keys():
+             return
+         if self.cfg.seed is None:
+             self.cfg.seed = int(time.time()) % 1000000
+         print('=========> set random seed: {}'.format(self.cfg.seed))
+         # fix random seeds for reproducibility
+         random.seed(self.cfg.seed)
+         np.random.seed(self.cfg.seed)
+         torch.manual_seed(self.cfg.seed)
+         torch.cuda.manual_seed(self.cfg.seed)
+
+     def save_codes_and_config(self, save_path):
+         """
+         Save the code and config to $save_path.
+         """
+         cur_codes_path = osp.dirname(osp.dirname(os.path.abspath(__file__)))
+         if os.path.exists(save_path):
+             shutil.rmtree(save_path)
+         shutil.copytree(cur_codes_path, osp.join(save_path, 'codes'), \
+             ignore=shutil.ignore_patterns('*debug*', '*data*', '*output*', '*exps*', '*.txt', '*.json', '*.mp4', '*.png', '*.jpg', '*.bvh', '*.csv', '*.pth', '*.tar', '*.npz'))
+
+         with open(osp.join(save_path, 'config.yaml'), 'w') as f:
+             f.write(yaml.dump(self.cfg.__dict__))
+
+
+ # other utils
+ class logger:
+     """Keeps track of the levels and steps of optimization. Logs it via TQDM"""
+     def __init__(self, n_steps, n_lvls):
+         self.n_steps = n_steps
+         self.n_lvls = n_lvls
+         self.lvl = -1
+         self.lvl_step = 0
+         self.steps = 0
+         self.pbar = tqdm(total=self.n_lvls * self.n_steps, desc='Starting')
+
+     def step(self):
+         self.pbar.update(1)
+         self.steps += 1
+         self.lvl_step += 1
+
+     def new_lvl(self):
+         self.lvl += 1
+         self.lvl_step = 0
+
+     def print(self):
+         self.pbar.set_description(f'Lvl {self.lvl}/{self.n_lvls-1}, step {self.lvl_step}/{self.n_steps}')
+
+
+ def set_seed(seed):
+     if seed is not None:
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+
+
+ # debug utils
+ def draw_trajectory(trajectory, save_path=None, anim=True):
+     r = max(abs(trajectory.min()), trajectory.max())
+     if anim:
+         imgs = []
+         for i in tqdm(range(1, trajectory.shape[0])):
+             plt.plot(trajectory[:i, 0], trajectory[:i, 2], color='red')
+             plt.xlim(-r - 1, r + 1)
+             plt.ylim(-r - 1, r + 1)
+             plt.savefig(save_path + '.png')
+             imgs += [imageio.imread(save_path + '.png')]
+         imageio.mimwrite(save_path + '.mp4', imgs)
+         plt.close()
+     else:
+         plt.plot(trajectory[:, 0], trajectory[:, 2], color='red')
+         plt.xlim(-r * 1.5, r * 1.5)
+         plt.ylim(-r * 1.5, r * 1.5)
+         if save_path is not None:
+             plt.savefig(save_path + '.png')
+         plt.close()
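
The logger class above is how GPS.run reports progress; a tiny usage sketch: one tqdm bar is shared across all pyramid levels, so the total is n_lvls * n_steps.

    from utils.base import logger

    pbar = logger(n_steps=3, n_lvls=2)    # total = 6 steps
    for lvl in range(2):
        pbar.new_lvl()
        for _ in range(3):
            pbar.step()
            pbar.print()                  # e.g. 'Lvl 1/1, step 3/3' at the end
    pbar.pbar.close()
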
utils/contact.py ADDED
@@ -0,0 +1,103 @@
+ import torch
+
+
+ def foot_contact_by_height(pos):
+     eps = 0.25
+     return (-eps < pos[..., 1]) * (pos[..., 1] < eps)
+
+
+ def velocity(pos, padding=False):
+     velo = pos[1:, ...] - pos[:-1, ...]
+     velo_norm = torch.norm(velo, dim=-1)
+     if padding:
+         pad = torch.zeros_like(velo_norm[:1, :])
+         velo_norm = torch.cat([pad, velo_norm], dim=0)
+     return velo_norm
+
+
+ def foot_contact(pos, ref_height=1., threshold=0.018):
+     velo_norm = velocity(pos)
+     contact = velo_norm < threshold
+     contact = contact.int()
+     padding = torch.zeros_like(contact)
+     contact = torch.cat([padding[:1, :], contact], dim=0)
+     return contact
+
+
+ def alpha(t):
+     return 2.0 * t * t * t - 3.0 * t * t + 1
+
+
+ def lerp(a, l, r):
+     return (1 - a) * l + a * r
+
+
+ def constrain_from_contact(contact, glb, fid='TBD', L=5):
+     """
+     :param contact: contact label
+     :param glb: original global position
+     :param fid: joint id to fix, corresponding to the order in contact
+     :param L: frame to look forward/backward
+     :return:
+     """
+     T = glb.shape[0]
+
+     for i, fidx in enumerate(fid):  # fidx: index of the foot joint
+         fixed = contact[:, i]  # [T]
+         s = 0
+         while s < T:
+             while s < T and fixed[s] == 0:
+                 s += 1
+             if s >= T:
+                 break
+             t = s
+             avg = glb[t, fidx].clone()
+             while t + 1 < T and fixed[t + 1] == 1:
+                 t += 1
+                 avg += glb[t, fidx].clone()
+             avg /= (t - s + 1)
+
+             for j in range(s, t + 1):
+                 glb[j, fidx] = avg.clone()
+             s = t + 1
+
+         for s in range(T):
+             if fixed[s] == 1:
+                 continue
+             l, r = None, None
+             consl, consr = False, False
+             for k in range(L):
+                 if s - k - 1 < 0:
+                     break
+                 if fixed[s - k - 1]:
+                     l = s - k - 1
+                     consl = True
+                     break
+             for k in range(L):
+                 if s + k + 1 >= T:
+                     break
+                 if fixed[s + k + 1]:
+                     r = s + k + 1
+                     consr = True
+                     break
+             if not consl and not consr:
+                 continue
+             if consl and consr:
+                 litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
+                             glb[s, fidx], glb[l, fidx])
+                 ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
+                             glb[s, fidx], glb[r, fidx])
+                 itp = lerp(alpha(1.0 * (s - l + 1) / (r - l + 1)),
+                            ritp, litp)
+                 glb[s, fidx] = itp.clone()
+                 continue
+             if consl:
+                 litp = lerp(alpha(1.0 * (s - l + 1) / (L + 1)),
+                             glb[s, fidx], glb[l, fidx])
+                 glb[s, fidx] = litp.clone()
+                 continue
+             if consr:
+                 ritp = lerp(alpha(1.0 * (r - s + 1) / (L + 1)),
+                             glb[s, fidx], glb[r, fidx])
+                 glb[s, fidx] = ritp.clone()
+     return glb
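
A short sketch of foot_contact on dummy data: a joint counts as "in contact" on frames where its velocity norm stays below the threshold (the first frame is padded with zeros, i.e. no contact).

    import torch
    from utils.contact import foot_contact

    pos = torch.zeros(16, 2, 3)             # (frames, foot joints, xyz), static feet
    pos[8:, 0, 0] = torch.arange(8) * 0.1   # left foot starts moving at frame 8
    labels = foot_contact(pos, threshold=0.018)
    print(labels.shape, int(labels[:, 0].sum()))   # torch.Size([16, 2]) 8
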
utils/kinematics.py ADDED
@@ -0,0 +1,203 @@
+ import torch
+ from utils.transforms import quat2mat, repr6d2mat, euler2mat
+
+
+ class ForwardKinematics:
+     def __init__(self, parents, offsets=None):
+         self.parents = parents
+         if offsets is not None and len(offsets.shape) == 2:
+             offsets = offsets.unsqueeze(0)
+         self.offsets = offsets
+
+     def forward(self, rots, offsets=None, global_pos=None):
+         """
+         Forward Kinematics: returns a per-bone transformation
+         @param rots: local joint rotations (batch_size, bone_num, 3, 3)
+         @param offsets: (batch_size, bone_num, 3) or None
+         @param global_pos: global position (batch_size, 3), or keep it as in offsets (default)
+         @return: (batch_size, bone_num, 3, 4)
+         """
+         rots = rots.clone()
+         if offsets is None:
+             offsets = self.offsets.to(rots.device)
+         if global_pos is None:
+             global_pos = offsets[:, 0]
+
+         pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)
+         rest_pos = torch.zeros_like(pos)
+         res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)
+
+         pos[:, 0] = global_pos
+         rest_pos[:, 0] = offsets[:, 0]
+
+         for i, p in enumerate(self.parents):
+             if i != 0:
+                 rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
+                 pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
+                 rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]
+
+             res[:, i, :3, :3] = rots[:, i]
+             res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]
+
+         return res
+
+     def accumulate(self, local_rots):
+         """
+         Get global joint rotations from local rotations
+         @param local_rots: (batch_size, n_bone, 3, 3)
+         @return: global rotations
+         """
+         res = torch.empty_like(local_rots)
+         for i, p in enumerate(self.parents):
+             if i == 0:
+                 res[:, i] = local_rots[:, i]
+             else:
+                 res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
+         return res
+
+     def unaccumulate(self, global_rots):
+         """
+         Get local joint rotations from global rotations
+         @param global_rots: (batch_size, n_bone, 3, 3)
+         @return: local rotations
+         """
+         res = torch.empty_like(global_rots)
+         inv = torch.empty_like(global_rots)
+
+         for i, p in enumerate(self.parents):
+             if i == 0:
+                 inv[:, i] = global_rots[:, i].transpose(-2, -1)
+                 res[:, i] = global_rots[:, i]
+                 continue
+             res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
+             inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])
+
+         return res
+
+
+ class ForwardKinematicsJoint:
+     def __init__(self, parents, offset):
+         self.parents = parents
+         self.offset = offset
+
+     '''
+     rotation should have shape batch_size * Time * Joint_num * (3/4/6)
+     position should have shape batch_size * Time * 3
+     offset should have shape batch_size * Joint_num * 3
+     output has shape batch_size * Time * Joint_num * 3
+     '''
+
+     def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
+                 world=True):
+         '''
+         if not quater and rotation.shape[-2] != 3: raise Exception('Unexpected shape of rotation')
+         if quater and rotation.shape[-2] != 4: raise Exception('Unexpected shape of rotation')
+         rotation = rotation.permute(0, 3, 1, 2)
+         position = position.permute(0, 2, 1)
+         '''
+         if rotation.shape[-1] == 6:
+             transform = repr6d2mat(rotation)
+         elif rotation.shape[-1] == 4:
+             norm = torch.norm(rotation, dim=-1, keepdim=True)
+             rotation = rotation / norm
+             transform = quat2mat(rotation)
+         elif rotation.shape[-1] == 3:
+             transform = euler2mat(rotation)
+         else:
+             raise Exception('Unsupported rotation representation')
+         result = torch.empty(transform.shape[:-2] + (3,), device=position.device)
+
+         if offset is None:
+             offset = self.offset
+         offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))
+
+         result[..., 0, :] = position
+         for i, pi in enumerate(self.parents):
+             if pi == -1:
+                 assert i == 0
+                 continue
+
+             result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze()
+             transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())
+             if world: result[..., i, :] += result[..., pi, :]
+         return result
+
+
+ class InverseKinematicsJoint:
+     def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
+         self.rotations = rotations.detach().clone()
+         self.rotations.requires_grad_(True)
+         self.position = positions.detach().clone()
+         self.position.requires_grad_(True)
+
+         self.parents = parents
+         self.offset = offset
+         self.constrains = constrains
+
+         self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
+         self.criteria = torch.nn.MSELoss()
+
+         self.fk = ForwardKinematicsJoint(parents, offset)
+
+         self.glb = None
+
+     def step(self):
+         self.optimizer.zero_grad()
+         glb = self.fk.forward(self.rotations, self.position)
+         loss = self.criteria(glb, self.constrains)
+         loss.backward()
+         self.optimizer.step()
+         self.glb = glb
+         return loss.item()
+
+
+ class InverseKinematicsJoint2:
+     def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
+                  lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
+         self.use_velo = use_velo
+         self.rotations_ori = rotations.detach().clone()
+         self.rotations = rotations.detach().clone()
+         self.rotations.requires_grad_(True)
+         self.position_ori = positions.detach().clone()
+         self.position = positions.detach().clone()
+         if self.use_velo:
+             self.position[1:] = self.position[1:] - self.position[:-1]
+         self.position.requires_grad_(True)
+
+         self.parents = parents
+         self.offset = offset
+         self.constrains = constrains.detach().clone()
+         self.cid = cid
+
+         self.lambda_rec_rot = lambda_rec_rot
+         self.lambda_rec_pos = lambda_rec_pos
+
+         self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
+         self.criteria = torch.nn.MSELoss()
+
+         self.fk = ForwardKinematicsJoint(parents, offset)
+
+         self.glb = None
+
+     def step(self):
+         self.optimizer.zero_grad()
+         if self.use_velo:
+             position = torch.cumsum(self.position, dim=0)
+         else:
+             position = self.position
+         glb = self.fk.forward(self.rotations, position)
+         self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
+         self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
+         self.rec_loss_pos = self.criteria(self.position, self.position_ori)
+         loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
+         loss.backward()
+         self.optimizer.step()
+         self.glb = glb
+         return loss.item()
+
+     def get_position(self):
+         if self.use_velo:
+             position = torch.cumsum(self.position.detach(), dim=0)
+         else:
+             position = self.position.detach()
+         return position
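
A minimal sketch of ForwardKinematicsJoint on a 3-joint chain (dummy data; quat2mat from utils.transforms does the conversion, and the quaternion component order follows that module's convention):

    import torch
    from utils.kinematics import ForwardKinematicsJoint

    parents = [-1, 0, 1]                    # root -> joint1 -> joint2
    offset = torch.tensor([[0., 0., 0.],
                           [0., 1., 0.],
                           [0., 1., 0.]])
    fk = ForwardKinematicsJoint(parents, offset)

    rot = torch.zeros(1, 8, 3, 4)           # (batch, time, joints, quaternion)
    rot[..., 0] = 1.0                       # unit quaternions
    pos = torch.zeros(1, 8, 3)              # root trajectory
    joints = fk.forward(rot, pos)           # world-space joint positions
    print(joints.shape)                     # torch.Size([1, 8, 3, 3])
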
utils/skeleton.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import numpy as np
6
+
7
+
8
+ class SkeletonConv(nn.Module):
9
+ def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num, stride=1, padding=0,
10
+ bias=True, padding_mode='zeros', add_offset=False, in_offset_channel=0):
11
+ super(SkeletonConv, self).__init__()
12
+
13
+ if in_channels % joint_num != 0 or out_channels % joint_num != 0:
14
+ raise Exception('in/out channels should be divided by joint_num')
15
+ self.in_channels_per_joint = in_channels // joint_num
16
+ self.out_channels_per_joint = out_channels // joint_num
17
+
18
+ if padding_mode == 'zeros': padding_mode = 'constant'
19
+
20
+ self.expanded_neighbour_list = []
21
+ self.expanded_neighbour_list_offset = []
22
+ self.neighbour_list = neighbour_list
23
+ self.add_offset = add_offset
24
+ self.joint_num = joint_num
25
+
26
+ self.stride = stride
27
+ self.dilation = 1
28
+ self.groups = 1
29
+ self.padding = padding
30
+ self.padding_mode = padding_mode
31
+ self._padding_repeated_twice = (padding, padding)
32
+
33
+ for neighbour in neighbour_list:
34
+ expanded = []
35
+ for k in neighbour:
36
+ for i in range(self.in_channels_per_joint):
37
+ expanded.append(k * self.in_channels_per_joint + i)
38
+ self.expanded_neighbour_list.append(expanded)
39
+
40
+ if self.add_offset:
41
+ self.offset_enc = SkeletonLinear(neighbour_list, in_offset_channel * len(neighbour_list), out_channels)
42
+
43
+ for neighbour in neighbour_list:
44
+ expanded = []
45
+ for k in neighbour:
46
+ for i in range(add_offset):
47
+ expanded.append(k * in_offset_channel + i)
48
+ self.expanded_neighbour_list_offset.append(expanded)
49
+
50
+ self.weight = torch.zeros(out_channels, in_channels, kernel_size)
51
+ if bias:
52
+ self.bias = torch.zeros(out_channels)
53
+ else:
54
+ self.register_parameter('bias', None)
55
+
56
+ self.mask = torch.zeros_like(self.weight)
57
+ for i, neighbour in enumerate(self.expanded_neighbour_list):
58
+ self.mask[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...] = 1
59
+ self.mask = nn.Parameter(self.mask, requires_grad=False)
60
+
61
+ self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
62
+ 'joint_num={}, stride={}, padding={}, bias={})'.format(
63
+ in_channels // joint_num, out_channels // joint_num, kernel_size, joint_num, stride, padding, bias
64
+ )
65
+
66
+ self.reset_parameters()
67
+
68
+ def reset_parameters(self):
69
+ for i, neighbour in enumerate(self.expanded_neighbour_list):
70
+ """ Use temporary variable to avoid assign to copy of slice, which might lead to un expected result """
71
+ tmp = torch.zeros_like(self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
72
+ neighbour, ...])
73
+ nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
74
+ self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1),
75
+ neighbour, ...] = tmp
76
+ if self.bias is not None:
77
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
78
+ self.weight[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1), neighbour, ...])
79
+ bound = 1 / math.sqrt(fan_in)
80
+ tmp = torch.zeros_like(
81
+ self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)])
82
+ nn.init.uniform_(tmp, -bound, bound)
83
+ self.bias[self.out_channels_per_joint * i: self.out_channels_per_joint * (i + 1)] = tmp
84
+
85
+ self.weight = nn.Parameter(self.weight)
86
+ if self.bias is not None:
87
+ self.bias = nn.Parameter(self.bias)
88
+
89
+ def set_offset(self, offset):
90
+ if not self.add_offset: raise Exception('Wrong Combination of Parameters')
91
+ self.offset = offset.reshape(offset.shape[0], -1)
92
+
93
+ def forward(self, input):
94
+ weight_masked = self.weight * self.mask
95
+ res = F.conv1d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
96
+ weight_masked, self.bias, self.stride,
97
+ 0, self.dilation, self.groups)
98
+
99
+ if self.add_offset:
100
+ offset_res = self.offset_enc(self.offset)
101
+ offset_res = offset_res.reshape(offset_res.shape + (1, ))
102
+ res += offset_res / 100
103
+ return res
104
+
105
+ def __repr__(self):
106
+ return self.description
107
+
108
+
109
+ class SkeletonLinear(nn.Module):
110
+ def __init__(self, neighbour_list, in_channels, out_channels, extra_dim1=False):
111
+ super(SkeletonLinear, self).__init__()
112
+ self.neighbour_list = neighbour_list
113
+ self.in_channels = in_channels
114
+ self.out_channels = out_channels
115
+ self.in_channels_per_joint = in_channels // len(neighbour_list)
116
+ self.out_channels_per_joint = out_channels // len(neighbour_list)
117
+ self.extra_dim1 = extra_dim1
118
+ self.expanded_neighbour_list = []
119
+
120
+ for neighbour in neighbour_list:
121
+ expanded = []
122
+ for k in neighbour:
123
+ for i in range(self.in_channels_per_joint):
124
+ expanded.append(k * self.in_channels_per_joint + i)
125
+ self.expanded_neighbour_list.append(expanded)
126
+
127
+ self.weight = torch.zeros(out_channels, in_channels)
128
+ self.mask = torch.zeros(out_channels, in_channels)
129
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
130
+
131
+ self.reset_parameters()
132
+
133
+ def reset_parameters(self):
134
+ for i, neighbour in enumerate(self.expanded_neighbour_list):
135
+ tmp = torch.zeros_like(
136
+ self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour]
137
+ )
138
+ self.mask[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = 1
139
+ nn.init.kaiming_uniform_(tmp, a=math.sqrt(5))
140
+ self.weight[i*self.out_channels_per_joint: (i + 1)*self.out_channels_per_joint, neighbour] = tmp
141
+
142
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
143
+ bound = 1 / math.sqrt(fan_in)
144
+ nn.init.uniform_(self.bias, -bound, bound)
145
+
146
+ self.weight = nn.Parameter(self.weight)
147
+ self.mask = nn.Parameter(self.mask, requires_grad=False)
148
+
149
+ def forward(self, input):
150
+ input = input.reshape(input.shape[0], -1)
151
+ weight_masked = self.weight * self.mask
152
+ res = F.linear(input, weight_masked, self.bias)
153
+ if self.extra_dim1: res = res.reshape(res.shape + (1,))
154
+ return res
155
+
156
+
157
+ class SkeletonPoolJoint(nn.Module):
158
+ def __init__(self, topology, pooling_mode, channels_per_joint, last_pool=False):
159
+ super(SkeletonPoolJoint, self).__init__()
160
+
161
+ if pooling_mode != 'mean':
162
+ raise Exception('Unimplemented pooling mode in matrix_implementation')
163
+
164
+ self.joint_num = len(topology)
165
+ self.parent = topology
166
+ self.pooling_list = []
167
+ self.pooling_mode = pooling_mode
168
+
169
+ self.pooling_map = [-1 for _ in range(len(self.parent))]
170
+ self.child = [-1 for _ in range(len(self.parent))]
171
+ children_cnt = [0 for _ in range(len(self.parent))]
172
+ for x, pa in enumerate(self.parent):
173
+ if pa < 0: continue
174
+ children_cnt[pa] += 1
175
+ self.child[pa] = x
176
+ self.pooling_map[0] = 0
+         for x in range(len(self.parent)):
+             if children_cnt[x] == 0 or (children_cnt[x] == 1 and children_cnt[self.child[x]] > 1):
+                 while children_cnt[x] <= 1:
+                     pa = self.parent[x]
+                     if last_pool:
+                         seq = [x]
+                         while pa != -1 and children_cnt[pa] == 1:
+                             seq = [pa] + seq
+                             x = pa
+                             pa = self.parent[x]
+                         self.pooling_list.append(seq)
+                         break
+                     else:
+                         if pa != -1 and children_cnt[pa] == 1:
+                             self.pooling_list.append([pa, x])
+                             x = self.parent[pa]
+                         else:
+                             self.pooling_list.append([x, ])
+                             break
+             elif children_cnt[x] > 1:
+                 self.pooling_list.append([x, ])
+
+         self.description = 'SkeletonPoolJoint(in_joint_num={}, out_joint_num={})'.format(
+             len(topology), len(self.pooling_list),
+         )
+
+         self.pooling_list.sort(key=lambda x: x[0])
+         for i, a in enumerate(self.pooling_list):
+             for j in a:
+                 self.pooling_map[j] = i
+
+         self.output_joint_num = len(self.pooling_list)
+         self.new_topology = [-1 for _ in range(len(self.pooling_list))]
+         for i, x in enumerate(self.pooling_list):
+             if i < 1: continue
+             self.new_topology[i] = self.pooling_map[self.parent[x[0]]]
+
+         self.weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)
+
+         for i, pair in enumerate(self.pooling_list):
+             for j in pair:
+                 for c in range(channels_per_joint):
+                     self.weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(pair)
+
+         self.weight = nn.Parameter(self.weight, requires_grad=False)
+
+     def forward(self, input: torch.Tensor):
+         return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
+
+
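+ # Pooling is expressed as a single fixed averaging matrix of shape
+ # (out_joints * C, in_joints * C): entry (i, j) is 1/len(group_i) whenever
+ # original joint j belongs to pooled group i, so forward() is one matmul.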
+ class SkeletonPool(nn.Module):
+     def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
+         super(SkeletonPool, self).__init__()
+
+         if pooling_mode != 'mean':
+             raise Exception('Unimplemented pooling mode in matrix_implementation')
+
+         self.channels_per_edge = channels_per_edge
+         self.pooling_mode = pooling_mode
+         self.edge_num = len(edges) + 1
+         self.seq_list = []
+         self.pooling_list = []
+         self.new_edges = []
+         degree = [0] * 100  # fixed upper bound; assumes the skeleton has fewer than 100 joints
+
+         for edge in edges:
+             degree[edge[0]] += 1
+             degree[edge[1]] += 1
+
+         def find_seq(j, seq):
+             nonlocal self, degree, edges
+
+             if degree[j] > 2 and j != 0:
+                 self.seq_list.append(seq)
+                 seq = []
+
+             if degree[j] == 1:
+                 self.seq_list.append(seq)
+                 return
+
+             for idx, edge in enumerate(edges):
+                 if edge[0] == j:
+                     find_seq(edge[1], seq + [idx])
+
+         find_seq(0, [])
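+         # Each seq in seq_list is a chain of consecutive edges between
+         # branching points; below, adjacent edge pairs in a chain are merged
+         # (stride-2 pooling), a lone leading edge is kept when the chain has
+         # odd length, and whole chains collapse when last_pool is set.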
+         for seq in self.seq_list:
+             if last_pool:
+                 self.pooling_list.append(seq)
+                 continue
+             if len(seq) % 2 == 1:
+                 self.pooling_list.append([seq[0]])
+                 self.new_edges.append(edges[seq[0]])
+                 seq = seq[1:]
+             for i in range(0, len(seq), 2):
+                 self.pooling_list.append([seq[i], seq[i + 1]])
+                 self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])
+
+         # add global position
+         self.pooling_list.append([self.edge_num - 1])
+
+         self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
+             len(edges), len(self.pooling_list)
+         )
+
+         self.weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
+
+         for i, pair in enumerate(self.pooling_list):
+             for j in pair:
+                 for c in range(channels_per_edge):
+                     self.weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(pair)
+
+         self.weight = nn.Parameter(self.weight, requires_grad=False)
+
+     def forward(self, input: torch.Tensor):
+         return torch.matmul(self.weight, input)
+
+
+ class SkeletonUnpool(nn.Module):
+     def __init__(self, pooling_list, channels_per_edge):
+         super(SkeletonUnpool, self).__init__()
+         self.pooling_list = pooling_list
+         self.input_joint_num = len(pooling_list)
+         self.output_joint_num = 0
+         self.channels_per_edge = channels_per_edge
+         for t in self.pooling_list:
+             self.output_joint_num += len(t)
+
+         self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
+             self.input_joint_num, self.output_joint_num,
+         )
+
+         self.weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)
+
+         for i, pair in enumerate(self.pooling_list):
+             for j in pair:
+                 for c in range(channels_per_edge):
+                     self.weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1
+
+         self.weight = nn.Parameter(self.weight)
+         self.weight.requires_grad_(False)
+
+     def forward(self, input: torch.Tensor):
+         return torch.matmul(self.weight, input.unsqueeze(-1)).squeeze(-1)
+
+
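+ # Note: SkeletonUnpool is the scatter counterpart of the pooling layers: its
+ # 0/1 weight matrix is the transpose pattern of the pooling matrix, copying
+ # each pooled group's channels back to every original joint in that group.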
+ def find_neighbor_joint(parents, threshold):
+     n_joint = len(parents)
+     dist_mat = np.empty((n_joint, n_joint), dtype=int)  # np.int is removed in NumPy >= 1.24
+     dist_mat[:, :] = 100000
+     for i, p in enumerate(parents):
+         dist_mat[i, i] = 0
+         if i != 0:
+             dist_mat[i, p] = dist_mat[p, i] = 1
+
+     # Floyd-Warshall: all-pairs shortest path distances over the kinematic tree
+     for k in range(n_joint):
+         for i in range(n_joint):
+             for j in range(n_joint):
+                 dist_mat[i, j] = min(dist_mat[i, j], dist_mat[i, k] + dist_mat[k, j])
+
+     neighbor_list = []
+     for i in range(n_joint):
+         neighbor = []
+         for j in range(n_joint):
+             if dist_mat[i, j] <= threshold:
+                 neighbor.append(j)
+         neighbor_list.append(neighbor)
+
+     return neighbor_list
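+ # Example (hypothetical skeleton): parents = [-1, 0, 1, 1] with threshold=1
+ # gives neighbor_list = [[0, 1], [0, 1, 2, 3], [1, 2], [1, 3]] -- each joint,
+ # itself plus everything at most one hop away in the tree.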
utils/transforms.py ADDED
@@ -0,0 +1,399 @@
+ import numpy as np
+ import torch
+
+
+ def batch_mm(matrix, matrix_batch):
+     """
+     https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
+     :param matrix: Sparse or dense matrix, size (m, n).
+     :param matrix_batch: Batched dense matrices, size (b, n, k).
+     :return: The batched matrix-matrix product, size (m, n) x (b, n, k) = (b, m, k).
+     """
+     batch_size = matrix_batch.shape[0]
+     # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
+     vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
+
+     # A matrix-matrix product is a batched matrix-vector product of the columns.
+     # And then reverse the reshaping. (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
+     return matrix.mm(vectors).reshape(matrix.shape[0], batch_size, -1).transpose(1, 0)
+
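+ # This folds the batch into columns so a single (possibly sparse) mm covers
+ # the whole batch; e.g. a (m, n) skeleton pooling matrix applied to (b, n, k)
+ # features yields (b, m, k) without expanding the matrix per batch element.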
+
+ def aa2quat(rots, form='wxyz', unified_orient=True):
+     """
+     Convert angle-axis representation to wxyz quaternion, restricted to the half space w >= 0
+     @param rots: angle-axis rotations, (*, 3)
+     @param form: quaternion format, either 'wxyz' or 'xyzw'
+     @param unified_orient: use a unified orientation (the unit quaternions double-cover SO(3), so q and -q encode the same rotation)
+     :return:
+     """
+     angles = rots.norm(dim=-1, keepdim=True)
+     norm = angles.clone()
+     norm[norm < 1e-8] = 1
+     axis = rots / norm
+     quats = torch.empty(rots.shape[:-1] + (4,), device=rots.device, dtype=rots.dtype)
+     angles = angles * 0.5
+     if form == 'wxyz':
+         quats[..., 0] = torch.cos(angles.squeeze(-1))
+         quats[..., 1:] = torch.sin(angles) * axis
+     elif form == 'xyzw':
+         quats[..., :3] = torch.sin(angles) * axis
+         quats[..., 3] = torch.cos(angles.squeeze(-1))
+
+     if unified_orient:
+         idx = quats[..., 0] < 0
+         quats[idx, :] *= -1
+
+     return quats
+
+
+ def quat2aa(quats):
+     """
+     Convert wxyz quaternions to angle-axis representation
+     :param quats:
+     :return:
+     """
+     _cos = quats[..., 0]
+     xyz = quats[..., 1:]
+     _sin = xyz.norm(dim=-1)
+     norm = _sin.clone()
+     norm[norm < 1e-7] = 1
+     axis = xyz / norm.unsqueeze(-1)
+     angle = torch.atan2(_sin, _cos) * 2
+     return axis * angle.unsqueeze(-1)
+
+
+ def quat2mat(quats: torch.Tensor):
+     """
+     Convert (w, x, y, z) quaternions to 3x3 rotation matrix
+     :param quats: quaternions of shape (..., 4)
+     :return: rotation matrices of shape (..., 3, 3)
+     """
+     qw = quats[..., 0]
+     qx = quats[..., 1]
+     qy = quats[..., 2]
+     qz = quats[..., 3]
+
+     x2 = qx + qx
+     y2 = qy + qy
+     z2 = qz + qz
+     xx = qx * x2
+     yy = qy * y2
+     wx = qw * x2
+     xy = qx * y2
+     yz = qy * z2
+     wy = qw * y2
+     xz = qx * z2
+     zz = qz * z2
+     wz = qw * z2
+
+     m = torch.empty(quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
+     m[..., 0, 0] = 1.0 - (yy + zz)
+     m[..., 0, 1] = xy - wz
+     m[..., 0, 2] = xz + wy
+     m[..., 1, 0] = xy + wz
+     m[..., 1, 1] = 1.0 - (xx + zz)
+     m[..., 1, 2] = yz - wx
+     m[..., 2, 0] = xz - wy
+     m[..., 2, 1] = yz + wx
+     m[..., 2, 2] = 1.0 - (xx + yy)
+
+     return m
+
+
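+ # Note: the closed form above is only exact for unit quaternions; callers
+ # should normalize q first if it may have drifted from unit norm.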
+ def quat2euler(q, order='xyz', degrees=True):
+     """
+     Convert (w, x, y, z) quaternions to xyz euler angles. This is used for bvh output.
+     """
+     q0 = q[..., 0]
+     q1 = q[..., 1]
+     q2 = q[..., 2]
+     q3 = q[..., 3]
+     es = torch.empty(q0.shape + (3,), device=q.device, dtype=q.dtype)
+
+     if order == 'xyz':
+         es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2), q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
+         es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
+         es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3), q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
+     else:
+         raise NotImplementedError('Cannot convert to ordering %s' % order)
+
+     if degrees:
+         es = es * 180 / np.pi
+
+     return es
+
+
+ def euler2mat(rots, order='xyz'):
+     axis = {'x': torch.tensor((1, 0, 0), device=rots.device),
+             'y': torch.tensor((0, 1, 0), device=rots.device),
+             'z': torch.tensor((0, 0, 1), device=rots.device)}
+
+     rots = rots / 180 * np.pi  # input angles are expected in degrees
+     mats = []
+     for i in range(3):
+         aa = axis[order[i]] * rots[..., i].unsqueeze(-1)
+         mats.append(aa2mat(aa))
+     return mats[0] @ (mats[1] @ mats[2])
+
+
+ def aa2mat(rots):
+     """
+     Convert angle-axis representation to rotation matrix
+     :param rots: angle-axis representation
+     :return:
+     """
+     quat = aa2quat(rots)
+     mat = quat2mat(quat)
+     return mat
+
+
+ def mat2quat(R) -> torch.Tensor:
+     '''
+     https://github.com/duolu/pyrotation/blob/master/pyrotation/pyrotation.py
+     Convert a rotation matrix to a unit quaternion.
+
+     This uses Shepperd's method for numerical stability.
+     '''
+
+     # The rotation matrix must be orthonormal
+
+     w2 = (1 + R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2])
+     x2 = (1 + R[..., 0, 0] - R[..., 1, 1] - R[..., 2, 2])
+     y2 = (1 - R[..., 0, 0] + R[..., 1, 1] - R[..., 2, 2])
+     z2 = (1 - R[..., 0, 0] - R[..., 1, 1] + R[..., 2, 2])
+
+     yz = (R[..., 1, 2] + R[..., 2, 1])
+     xz = (R[..., 2, 0] + R[..., 0, 2])
+     xy = (R[..., 0, 1] + R[..., 1, 0])
+
+     wx = (R[..., 2, 1] - R[..., 1, 2])
+     wy = (R[..., 0, 2] - R[..., 2, 0])
+     wz = (R[..., 1, 0] - R[..., 0, 1])
+
+     w = torch.empty_like(x2)
+     x = torch.empty_like(x2)
+     y = torch.empty_like(x2)
+     z = torch.empty_like(x2)
+
+     flagA = (R[..., 2, 2] < 0) * (R[..., 0, 0] > R[..., 1, 1])
+     flagB = (R[..., 2, 2] < 0) * (R[..., 0, 0] <= R[..., 1, 1])
+     flagC = (R[..., 2, 2] >= 0) * (R[..., 0, 0] < -R[..., 1, 1])
+     flagD = (R[..., 2, 2] >= 0) * (R[..., 0, 0] >= -R[..., 1, 1])
+
+     x[flagA] = torch.sqrt(x2[flagA])
+     w[flagA] = wx[flagA] / x[flagA]
+     y[flagA] = xy[flagA] / x[flagA]
+     z[flagA] = xz[flagA] / x[flagA]
+
+     y[flagB] = torch.sqrt(y2[flagB])
+     w[flagB] = wy[flagB] / y[flagB]
+     x[flagB] = xy[flagB] / y[flagB]
+     z[flagB] = yz[flagB] / y[flagB]
+
+     z[flagC] = torch.sqrt(z2[flagC])
+     w[flagC] = wz[flagC] / z[flagC]
+     x[flagC] = xz[flagC] / z[flagC]
+     y[flagC] = yz[flagC] / z[flagC]
+
+     w[flagD] = torch.sqrt(w2[flagD])
+     x[flagD] = wx[flagD] / w[flagD]
+     y[flagD] = wy[flagD] / w[flagD]
+     z[flagD] = wz[flagD] / w[flagD]
+
+     # if R[..., 2, 2] < 0:
+     #
+     #     if R[..., 0, 0] > R[..., 1, 1]:
+     #
+     #         x = torch.sqrt(x2)
+     #         w = wx / x
+     #         y = xy / x
+     #         z = xz / x
+     #
+     #     else:
+     #
+     #         y = torch.sqrt(y2)
+     #         w = wy / y
+     #         x = xy / y
+     #         z = yz / y
+     #
+     # else:
+     #
+     #     if R[..., 0, 0] < -R[..., 1, 1]:
+     #
+     #         z = torch.sqrt(z2)
+     #         w = wz / z
+     #         x = xz / z
+     #         y = yz / z
+     #
+     #     else:
+     #
+     #         w = torch.sqrt(w2)
+     #         x = wx / w
+     #         y = wy / w
+     #         z = wz / w
+
+     res = [w, x, y, z]
+     res = [z.unsqueeze(-1) for z in res]
+
+     return torch.cat(res, dim=-1) / 2
+
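+ # The four flags pick whichever of w, x, y, z is largest in magnitude as the
+ # pivot (computed via sqrt), recover the other components as ratios, and the
+ # final division by 2 normalizes the result; the commented block preserves
+ # the original scalar form of the same four branches.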
+
+ def quat2repr6d(quat):
+     mat = quat2mat(quat)
+     res = mat[..., :2, :]
+     res = res.reshape(res.shape[:-2] + (6, ))
+     return res
+
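+ # The 6D representation keeps the first two rows of the rotation matrix; it
+ # is continuous over SO(3) (Zhou et al., "On the Continuity of Rotation
+ # Representations in Neural Networks", CVPR 2019), which makes it friendlier
+ # for regression than quaternions or Euler angles.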
+
+ def repr6d2mat(repr):
+     x = repr[..., :3]
+     y = repr[..., 3:]
+     x = x / x.norm(dim=-1, keepdim=True)
+     z = torch.cross(x, y, dim=-1)  # specify dim explicitly; the batched default is deprecated
+     z = z / z.norm(dim=-1, keepdim=True)
+     y = torch.cross(z, x, dim=-1)
+     res = [x, y, z]
+     res = [v.unsqueeze(-2) for v in res]
+     mat = torch.cat(res, dim=-2)
+     return mat
+
260
+
261
+ def repr6d2quat(repr) -> torch.Tensor:
262
+ x = repr[..., :3]
263
+ y = repr[..., 3:]
264
+ x = x / x.norm(dim=-1, keepdim=True)
265
+ z = torch.cross(x, y)
266
+ z = z / z.norm(dim=-1, keepdim=True)
267
+ y = torch.cross(z, x)
268
+ res = [x, y, z]
269
+ res = [v.unsqueeze(-2) for v in res]
270
+ mat = torch.cat(res, dim=-2)
271
+ return mat2quat(mat)
272
+
+
+ def inv_affine(mat):
+     """
+     Calculate the inverse of any affine transformation
+     """
+     # append the homogeneous row [0, 0, 0, 1]; allocate on mat's device/dtype
+     affine = torch.zeros(mat.shape[:2] + (1, 4), device=mat.device, dtype=mat.dtype)
+     affine[..., 3] = 1
+     vert_mat = torch.cat((mat, affine), dim=2)
+     vert_mat_inv = torch.inverse(vert_mat)
+     return vert_mat_inv[..., :3, :]
+
+
+ def inv_rigid_affine(mat):
+     """
+     Calculate the inverse of a rigid affine transformation
+     """
+     res = mat.clone()
+     res[..., :3] = mat[..., :3].transpose(-2, -1)
+     res[..., 3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
+     return res
+
+
+ def generate_pose(batch_size, device, uniform=False, factor=1, root_rot=False, n_bone=None, ee=None):
+     if n_bone is None: n_bone = 24
+     if ee is not None:
+         if root_rot:
+             ee.append(0)
+         n_bone_ = n_bone
+         n_bone = len(ee)
+     axis = torch.randn((batch_size, n_bone, 3), device=device)
+     axis /= axis.norm(dim=-1, keepdim=True)
+     if uniform:
+         angle = torch.rand((batch_size, n_bone, 1), device=device) * np.pi
+     else:
+         angle = torch.randn((batch_size, n_bone, 1), device=device) * np.pi / 6 * factor
+     angle.clamp_(-np.pi, np.pi)  # in-place; plain clamp() returns a new tensor and would be a no-op here
+     poses = axis * angle
+     if ee is not None:
+         res = torch.zeros((batch_size, n_bone_, 3), device=device)
+         for i, id in enumerate(ee):
+             res[:, id] = poses[:, i]
+         poses = res
+     poses = poses.reshape(batch_size, -1)
+     if not root_rot:
+         poses[..., :3] = 0
+     return poses
+
+
+ def slerp(l, r, t, unit=True):
+     """
+     :param l: shape = (*, n)
+     :param r: shape = (*, n)
+     :param t: shape = (*)
+     :param unit: whether l and r are already unit vectors
+     :return:
+     """
+     eps = 1e-8
+     if not unit:
+         l_n = l / torch.norm(l, dim=-1, keepdim=True)
+         r_n = r / torch.norm(r, dim=-1, keepdim=True)
+     else:
+         l_n = l
+         r_n = r
+     omega = torch.acos((l_n * r_n).sum(dim=-1).clamp(-1, 1))
+     dom = torch.sin(omega)
+
+     flag = dom < eps
+
+     # fall back to linear interpolation where the angle is numerically zero
+     res = torch.empty_like(l_n)
+     t_t = t[flag].unsqueeze(-1)
+     res[flag] = (1 - t_t) * l_n[flag] + t_t * r_n[flag]
+
+     flag = ~flag
+
+     t_t = t[flag]
+     d_t = dom[flag]
+     va = torch.sin((1 - t_t) * omega[flag]) / d_t
+     vb = torch.sin(t_t * omega[flag]) / d_t
+     res[flag] = (va.unsqueeze(-1) * l_n[flag] + vb.unsqueeze(-1) * r_n[flag])
+     return res
+
+
+ def slerp_quat(l, r, t):
+     """
+     slerp for unit quaternions
+     :param l: (*, 4) unit quaternion
+     :param r: (*, 4) unit quaternion
+     :param t: (*) scalar between 0 and 1
+     """
+     t = t.expand(l.shape[:-1])
+     # flip one endpoint where the dot product is negative so interpolation
+     # takes the shorter arc (q and -q are the same rotation)
+     flag = (l * r).sum(dim=-1) >= 0
+     res = torch.empty_like(l)
+     res[flag] = slerp(l[flag], r[flag], t[flag])
+     flag = ~flag
+     res[flag] = slerp(-l[flag], r[flag], t[flag])
+     return res
+
+
+ # def slerp_6d(l, r, t):
+ #     l_q = repr6d2quat(l)
+ #     r_q = repr6d2quat(r)
+ #     res_q = slerp_quat(l_q, r_q, t)
+ #     return quat2repr6d(res_q)
+
+
+ def interpolate_6d(input, size):
+     """
+     :param input: (batch_size, n_channels, length)
+     :param size: required output size for temporal axis
+     :return:
+     """
+     batch = input.shape[0]
+     length = input.shape[-1]
+     input = input.reshape((batch, -1, 6, length))
+     input = input.permute(0, 1, 3, 2)  # (batch_size, n_joint, length, 6)
+     input_q = repr6d2quat(input)
+     idx = torch.tensor(list(range(size)), device=input_q.device, dtype=torch.float) / size * (length - 1)
+     idx_l = torch.floor(idx)
+     t = idx - idx_l
+     idx_l = idx_l.long()
+     idx_r = idx_l + 1
+     t = t.reshape((1, 1, -1))
+     res_q = slerp_quat(input_q[..., idx_l, :], input_q[..., idx_r, :], t)
+     res = quat2repr6d(res_q)  # shape = (batch_size, n_joint, t, 6)
+     res = res.permute(0, 1, 3, 2)
+     res = res.reshape((batch, -1, size))
+     return res
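+ # Rough usage sketch (hypothetical shapes): for a motion clip stored as
+ # (batch, n_joint * 6, 64) in 6D rotations, interpolate_6d(motion, 128)
+ # resamples it to 128 frames, slerping per joint in quaternion space so the
+ # result stays on the rotation manifold instead of lerping raw 6D channels.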