File size: 14,227 Bytes
98e2c81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import os
import os.path as osp
import torch
import torch.nn.functional as F
import numpy as np
import itertools
from tensorboardX import SummaryWriter

from NN.losses import make_criteria
from utils.base import logger

class GPS:
    """Generative patch-based synthesis over a coarse-to-fine temporal pyramid.

    ``run`` builds a pyramid of the target clip(s), initializes a synthesized
    sequence at the coarsest level, and refines it level by level, either by
    gradient descent on a criteria (``backpropagate``) or by the criteria's own
    match-and-blend step (``match_and_blend``).
    """

    def __init__(self,
                 init_mode: str = 'random_synthesis',
                 noise_sigma: float = 1.0,
                 coarse_ratio: float = 0.2,
                 coarse_ratio_factor: float = 6,
                 pyr_factor: float = 0.75,
                 num_stages_limit: int = -1,
                 device: str = 'cuda:0',
                 silent: bool = False
                 ):
        '''
        Args:
            init_mode: how the length of the synthesized sequence is chosen.
                - 'random_synthesis': sum of the full lengths of all target clips.
                - 'random_synthesis/<k>x': k times that sum (e.g. 'random_synthesis/2x').
                - 'random_synthesis/<N>': exactly N frames.
                Any other value makes ``run`` raise ValueError.
            noise_sigma: std of the Gaussian noise added to the initial motion
                (the noisy result is then passed through torch.fmod(., 1.0));
                a value <= 0 disables the noise.
            coarse_ratio: length of the coarsest level relative to the finest
                one. -1 derives it from
                ext['criteria']['patch_size'] * coarse_ratio_factor / length.
            coarse_ratio_factor: multiplier used only when coarse_ratio == -1.
            pyr_factor: ratio between consecutive level lengths, in (0, 1].
            num_stages_limit: maximum number of pyramid levels; -1 means no limit.
            device: torch device the pyramid and synthesized motion live on.
            silent: if True, suppress the diagnostic prints.

        Raises:
            ValueError: if pyr_factor is outside (0, 1]; larger values can make
                the pyramid-length loop oscillate forever.
        '''
        if not 0 < pyr_factor <= 1:
            raise ValueError(f'pyr_factor must be in (0, 1], got {pyr_factor}')
        self.init_mode = init_mode
        self.noise_sigma = noise_sigma
        self.coarse_ratio = coarse_ratio
        self.coarse_ratio_factor = coarse_ratio_factor
        self.pyr_factor = pyr_factor
        self.num_stages_limit = num_stages_limit
        self.device = torch.device(device)
        self.silent = silent

    def _get_pyramid_lengths(self, dest, ext=None):
        """Return the per-level sequence lengths of a pyramid ending at ``dest``.

        Lengths start at ``round(dest * coarse_ratio)`` and grow by a factor of
        ``1 / pyr_factor`` per level. Each level is at least one frame longer
        than the previous one, and the last level is clamped to exactly ``dest``.

        When ``self.coarse_ratio == -1`` it is first derived from
        ``ext['criteria']['patch_size'] * coarse_ratio_factor / dest``, rounded
        to two decimals, and stored back on ``self``. Later calls, including
        ones without ``ext``, reuse that ratio.
        """
        if self.coarse_ratio == -1:
            self.coarse_ratio = np.around(ext['criteria']['patch_size'] * self.coarse_ratio_factor / dest, 2)

        lengths = [int(np.round(dest * self.coarse_ratio))]
        while lengths[-1] < dest:
            next_len = int(np.round(lengths[-1] / self.pyr_factor))
            # Guarantee progress when rounding would repeat the previous length.
            if next_len == lengths[-1]:
                next_len += 1
            lengths.append(next_len)
        lengths[-1] = dest
        return lengths

    def _get_target_pyramid(self, target, ext=None):
        """Build the target pyramid, ordered from coarsest to finest level.

        Args:
            target: non-empty sequence of target motion objects. Each must
                support ``len()`` and ``sample(size=...)`` returning a tensor
                (presumably the clip resampled to ``size`` frames; confirm
                against the motion class).
            ext: forwarded to ``_get_pyramid_lengths`` (needed when
                coarse_ratio == -1).

        Returns:
            ``pyramid[level][clip]``: the sampled tensors, on ``self.device``.

        Side effects:
            Sets ``self._num_target`` and ``self.pyraimd_lengths``, which holds
            one list of per-level lengths per clip. The attribute name is
            misspelled, but ``_get_initial_motion`` reads it, so it is kept.

        Raises:
            ValueError: if ``target`` is empty.
        """
        if len(target) == 0:
            raise ValueError('target must contain at least one motion')
        self._num_target = len(target)

        full_lengths = [self._get_pyramid_lengths(len(motion), ext) for motion in target]
        # All clips must share the same number of levels: keep the finest
        # `min_levels` of each (counted before the stage limit is applied).
        min_levels = min(len(clip_lengths) for clip_lengths in full_lengths)
        lengths = []
        for clip_lengths in full_lengths:
            if self.num_stages_limit != -1:
                # NOTE(review): this keeps the *coarsest* levels, so the
                # full-length level is dropped when the limit bites. Confirm
                # this is intended.
                clip_lengths = clip_lengths[:self.num_stages_limit]
            lengths.append(clip_lengths[-min_levels:])
        self.pyraimd_lengths = lengths

        n_levels = len(lengths[0])
        target_pyramid = [
            [motion.sample(size=clip_lengths[level]).to(self.device)
             for motion, clip_lengths in zip(target, lengths)]
            for level in range(n_levels)
        ]

        if not self.silent:
            print('Levels:', lengths)
            for level, clips in enumerate(target_pyramid):
                print(f'Number of clips in target pyramid {level} is {len(clips)}: {[[tgt.min(), tgt.max()] for tgt in clips]}')

        return target_pyramid

    def _get_initial_motion(self):
        """Build the coarsest-level starting point for synthesis.

        Sets ``self.synthesized_lengths`` to the pyramid of the output length
        chosen by ``init_mode``. It then linearly resamples the time-concatenated
        coarsest target clips to ``self.synthesized_lengths[0]`` frames and
        optionally adds noise (see ``noise_sigma``).

        Requires ``self.pyraimd_lengths``, ``self.target_pyramid`` and
        ``self._num_target`` to be set (``run`` does this).

        Raises:
            ValueError: if ``init_mode`` is not one of the supported forms.
        """
        if 'random_synthesis' not in str(self.init_mode):
            raise ValueError(f'Unsupported init_mode {self.init_mode}')

        suffix = self.init_mode.split('/')[-1]
        if suffix == 'random_synthesis':
            final_length = sum(clip_lengths[-1] for clip_lengths in self.pyraimd_lengths)
        elif 'x' in suffix:
            multiplier = int(suffix.replace('x', ''))
            final_length = multiplier * sum(clip_lengths[-1] for clip_lengths in self.pyraimd_lengths)
        elif suffix.isdigit():
            final_length = int(suffix)
        else:
            raise ValueError(f'incorrect init_mode: {self.init_mode}')
        self.synthesized_lengths = self._get_pyramid_lengths(final_length)

        # Concatenate the coarsest clips along time, then resample to the
        # coarsest synthesized length.
        coarsest_targets = torch.cat([self.target_pyramid[0][i] for i in range(self._num_target)], dim=-1)
        initial_motion = F.interpolate(coarsest_targets, size=self.synthesized_lengths[0],
                                       mode='linear', align_corners=True)
        if self.noise_sigma > 0:
            initial_motion_w_noise = initial_motion + torch.randn_like(initial_motion) * self.noise_sigma
            # fmod folds values into (-1, 1); presumably the data is normalized
            # to that range. NOTE(review): confirm.
            initial_motion_w_noise = torch.fmod(initial_motion_w_noise, 1.0)
        else:
            initial_motion_w_noise = initial_motion

        if not self.silent:
            print('Synthesized lengths:', self.synthesized_lengths)
            print('Initial motion:', initial_motion.min(), initial_motion.max())
            print('Initial motion with noise:', initial_motion_w_noise.min(), initial_motion_w_noise.max())

        return initial_motion_w_noise

    def run(self, target, mode="backpropagate", ext=None, debug_dir=None):
        '''
        Run the patch-based motion synthesis.

        Args:
            target: sequence of target motion objects (see _get_target_pyramid).
            mode (str): optimization mode, 'backpropagate' or 'match_and_blend'.
            ext (dict): extra data and constraints. 'criteria' and 'num_itrs'
                are required. 'pbar' is written into it, and the dict is passed
                on to the per-level optimizer.
            debug_dir (str): if truthy, per-level losses are logged to a
                SummaryWriter in this directory, and the writer is closed on exit.

        Returns:
            torch.Tensor: the synthesized motion at the finest level, detached.

        Raises:
            ValueError: on an unsupported mode or init_mode.
        '''
        if mode not in ("backpropagate", "match_and_blend"):
            raise ValueError(f'Unsupported mode: {mode}')
        ext = {} if ext is None else ext
        assert 'criteria' in ext, 'Please specify a criteria for synthsis.'

        # Prepare data.
        self.target_pyramid = self._get_target_pyramid(target, ext)
        self.synthesized = self._get_initial_motion()
        # Only the backpropagate mode optimizes the tensor through autograd.
        self.synthesized.requires_grad_(mode == "backpropagate")
        criteria = make_criteria(ext['criteria']).to(self.device)

        writer = SummaryWriter(log_dir=debug_dir) if debug_dir else None
        self.pbar = logger(ext['num_itrs'], len(self.target_pyramid))
        ext['pbar'] = self.pbar
        try:
            for lvl, lvl_target in enumerate(self.target_pyramid):
                self.pbar.new_lvl()
                if lvl > 0:
                    # Upsample the previous level's result to this level's length.
                    # NOTE(review): assumes synthesized_lengths has at least as
                    # many levels as the target pyramid.
                    with torch.no_grad():
                        self.synthesized = F.interpolate(self.synthesized.detach(),
                                                         size=self.synthesized_lengths[lvl], mode='linear')
                    if mode == "backpropagate":
                        self.synthesized.requires_grad_(True)

                if mode == "backpropagate":  # directly optimize the synthesized motion
                    self.synthesized, losses = GPS.backpropagate(self.synthesized, lvl_target, criteria, ext=ext)
                else:
                    self.synthesized, losses = GPS.match_and_blend(self.synthesized, lvl_target, criteria, ext=ext)

                criteria.clean_cache()
                if writer is not None:
                    for itr, loss in enumerate(losses):
                        writer.add_scalar(f'optimize/losses_lvl{lvl}', loss, itr)
        finally:
            self.pbar.pbar.close()
            if writer is not None:
                writer.close()

        return self.synthesized.detach()

    @staticmethod
    def backpropagate(synthesized, targets, criteria=None, ext=None):
        """
        Minimize criteria(synthesized, targets) with ext['num_itrs'] optimizer steps.

        Args:
            synthesized (torch.Tensor): leaf tensor optimized in place (must
                require grad), presumably (batch, channels, frames) as produced
                by ``run``.
            targets: target data for this level, passed straight to criteria.
            criteria: callable returning a scalar loss tensor. When None, it is
                built from ext['criteria'] via make_criteria.
            ext (dict): reads 'num_itrs' and 'lr'. Optionally reads:
                - 'optimizer': 'Adam', 'SGD' or 'RMSprop' (default RMSprop;
                  unknown names fall back to RMSprop with a message).
                - 'trajectory' + 'pos2velo': adds an L1 term between channels
                  [-3, -1] of synthesized and of pos2velo(trajectory resampled
                  to the synthesized length).
                - 'pbar': stepped once per iteration.

        Returns:
            (synthesized, losses): the same tensor object and the per-iteration
            loss values as Python floats.
        """
        if criteria is None:
            assert 'criteria' in ext.keys(), 'Criteria is not set'
            criteria = make_criteria(ext['criteria']).to(synthesized.device)

        optimizers = {
            'Adam': torch.optim.Adam,
            'SGD': torch.optim.SGD,
            'RMSprop': torch.optim.RMSprop,
        }
        optim_name = ext.get('optimizer', 'RMSprop')
        if optim_name not in optimizers:
            print('use default RMSprop optimizer')
            optim_name = 'RMSprop'
        optim = optimizers[optim_name]([synthesized], lr=ext['lr'])

        # Velocity constraint: the target depends only on the (fixed) output
        # length, so compute it once instead of once per step.
        trajectory = ext.get('trajectory')
        target_velo = None
        velo_mask = [-3, -1]
        if trajectory is not None:
            with torch.no_grad():
                target_traj = F.interpolate(trajectory, size=synthesized.shape[-1], mode='linear')
                target_velo = ext['pos2velo'](target_traj)

        losses = []
        for _ in range(ext['num_itrs']):
            optim.zero_grad()
            loss = criteria(synthesized, targets)
            if target_velo is not None:
                loss = loss + F.l1_loss(synthesized[:, velo_mask, :], target_velo[:, velo_mask, :])
            loss.backward()
            optim.step()

            losses.append(loss.item())
            if 'pbar' in ext:
                ext['pbar'].step()
                ext['pbar'].print()

        return synthesized, losses

    @staticmethod
    @torch.no_grad()
    def match_and_blend(synthesized, targets, criteria, ext):
        """
        Iteratively replace synthesized with the criteria's blended result.

        Args:
            synthesized (torch.Tensor): current synthesized motion.
            targets: target data for this level, passed to criteria.
            criteria: callable accepting (synthesized, targets, ext=...,
                return_blended_results=True) and returning (blended, loss_tensor).
            ext (dict): reads 'num_itrs'. Optionally reads 'parts_list' (runs
                the step independently per body part, then merges the parts)
                and 'pbar' (stepped once per iteration).

        Returns:
            (synthesized, losses): the final blended motion and a list with one
            loss per iteration (with parts_list, the sum over parts).
        """
        use_parts = 'parts_list' in ext
        losses = []
        for _ in range(ext['num_itrs']):
            if use_parts:
                synthesized, loss = GPS._match_and_blend_parts(synthesized, targets, criteria, ext)
            else:
                synthesized, loss_tensor = criteria(synthesized, targets, ext=ext, return_blended_results=True)
                loss = loss_tensor.item()

            losses.append(loss)
            if 'pbar' in ext:
                ext['pbar'].step()
                ext['pbar'].print()

        return synthesized, losses

    @staticmethod
    def _match_and_blend_parts(synthesized, targets, criteria, ext):
        """Run one match-and-blend step per body part; return (merged, summed loss)."""
        parts_list = ext['parts_list']
        synth_parts = GPS._extract_part_motions(synthesized, parts_list)
        target_parts = [GPS._extract_part_motions(target, parts_list) for target in targets]

        blended_parts = []
        total_loss = 0.0
        for j, synth_part in enumerate(synth_parts):
            part_targets = [clip_parts[j] for clip_parts in target_parts]
            blended, part_loss = criteria(synth_part, part_targets, ext=ext, return_blended_results=True)
            blended_parts.append(blended)
            total_loss += part_loss.item()

        return GPS._combine_part_motions(blended_parts, parts_list), total_loss

    @staticmethod
    def _extract_part_motions(motion, parts_list):
        """Split a motion into one tensor per part in ``parts_list``.

        The channel layout is taken from the reshape below: 6 channels per
        joint followed by 3 trailing channels (presumably rotation + root
        position; confirm). Part indices are shifted by -1. A part whose
        shifted indices contain 0 also carries the 3 trailing channels.
        """
        n_frames = motion.shape[-1]
        rot, pos = motion[:, :-3, :].reshape(-1, 6, n_frames), motion[:, -3:, :]

        part_motions = []
        for part in parts_list:
            idx = [i - 1 for i in part]
            if 0 in idx:
                part_motions.append(torch.cat([rot[idx].view(1, -1, n_frames), pos.view(1, -1, n_frames)], dim=1))
            else:
                part_motions.append(rot[idx].view(1, -1, n_frames))
        return part_motions

    @staticmethod
    def _combine_part_motions(part_motions, parts_list):
        """Inverse of ``_extract_part_motions``; joints in several parts are averaged.

        NOTE(review): joints covered by no part (and the trailing channels when
        no part includes shifted index 0) are divided by zero and become NaN.
        """
        assert len(part_motions) == len(parts_list)
        n_frames = part_motions[0].shape[-1]
        device = part_motions[0].device
        max_idx = max(itertools.chain.from_iterable(parts_list))

        rot = torch.zeros((max_idx + 1, 6, n_frames), device=device)
        pos = torch.zeros((1, 3, n_frames), device=device)
        div_rot = torch.zeros(max_idx + 1, device=device)
        div_pos = torch.zeros(1, device=device)

        for part_motion, part in zip(part_motions, parts_list):
            idx = [i - 1 for i in part]
            if 0 in idx:
                pos += part_motion[:, -3:, :]
                div_pos += 1
                rot[idx] += part_motion[:, :-3, :].view(-1, 6, n_frames)
            else:
                rot[idx] += part_motion.view(-1, 6, n_frames)
            div_rot[idx] += 1

        # Average each joint over the number of parts that contributed to it.
        rot = (rot.permute(1, 2, 0) / div_rot).permute(2, 0, 1)
        pos = pos / div_pos

        return torch.cat([rot.view(1, -1, n_frames), pos.view(1, 3, n_frames)], dim=1)