import torch
from utils.transforms import quat2mat, repr6d2mat, euler2mat


class ForwardKinematics:
    """Forward kinematics on 3x3 rotation matrices for a skeleton given as a parent list.

    The per-joint loops read a parent's (already accumulated) result before its child, so
    parents must precede their children (parents[i] < i for i > 0). The parent entry of
    joint 0 is never read.
    """

    def __init__(self, parents, offsets=None):
        """
        @param parents: parent index of every joint
        @param offsets: default bone offsets, (bone_num, 3) or (batch_size, bone_num, 3);
            a 2D tensor gets a leading batch axis of size 1. May be None, in which case
            offsets must be passed to forward().
        """
        self.parents = parents
        if offsets is not None and len(offsets.shape) == 2:
            offsets = offsets.unsqueeze(0)
        self.offsets = offsets

    def forward(self, rots, offsets=None, global_pos=None):
        """
        Forward Kinematics: returns a per-bone transformation
        @param rots: local joint rotations (batch_size, bone_num, 3, 3)
        @param offsets: (batch_size, bone_num, 3), (bone_num, 3), or None to use the
            offsets given to the constructor
        @param global_pos: root position (batch_size, 3); defaults to the root offset
        @return: (batch_size, bone_num, 3, 4). For bone i this is [R_i | t_i] where R_i is
            the accumulated (global) rotation and t_i = pos_i - R_i @ rest_pos_i, so it maps
            the bone's rest-pose position (cumulative sum of offsets) onto its posed
            position pos_i.
        @raise ValueError: if offsets are given neither here nor to the constructor
        """
        rots = rots.clone()
        if offsets is None:
            if self.offsets is None:
                raise ValueError('offsets must be given either to __init__ or to forward')
            offsets = self.offsets.to(rots.device)
        elif offsets.dim() == 2:
            # Same normalisation as __init__: add a batch axis so offsets[:, i] selects a bone.
            offsets = offsets.unsqueeze(0)
        if global_pos is None:
            global_pos = offsets[:, 0]

        batch_size, bone_num = rots.shape[0], rots.shape[1]
        # Allocate in the input dtype so e.g. float64 input is not truncated to float32.
        pos = torch.zeros((batch_size, bone_num, 3), dtype=rots.dtype, device=rots.device)
        rest_pos = torch.zeros_like(pos)
        res = torch.zeros((batch_size, bone_num, 3, 4), dtype=rots.dtype, device=rots.device)
        pos[:, 0] = global_pos
        rest_pos[:, 0] = offsets[:, 0]

        for i, p in enumerate(self.parents):
            if i != 0:
                # rots[:, p] is already global here because parents precede children.
                rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
                pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
                rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]
            res[:, i, :3, :3] = rots[:, i]
            res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]
        return res

    def accumulate(self, local_rots):
        """
        Get global joint rotation from local rotations
        @param local_rots: (batch_size, n_bone, 3, 3)
        @return: global_rotations, same shape; joint 0 is copied unchanged and every other
            joint is parent_global @ local
        """
        res = torch.empty_like(local_rots)
        for i, p in enumerate(self.parents):
            if i == 0:
                res[:, i] = local_rots[:, i]
            else:
                res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
        return res

    def unaccumulate(self, global_rots):
        """
        Get local joint rotation from global rotations (the inverse of accumulate)
        @param global_rots: (batch_size, n_bone, 3, 3)
        @return: local_rotations, same shape
        Inverses are formed by transposition, which assumes orthonormal rotation matrices.
        """
        res = torch.empty_like(global_rots)
        # inv[:, i] holds the inverse (transpose) of joint i's global rotation.
        inv = torch.empty_like(global_rots)
        for i, p in enumerate(self.parents):
            if i == 0:
                inv[:, i] = global_rots[:, i].transpose(-2, -1)
                res[:, i] = global_rots[:, i]
                continue
            res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
            # (inv_p @ G_i)^T @ inv_p == G_i^T when inv_p == G_p^T and G_p is orthonormal.
            inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])
        return res


class ForwardKinematicsJoint:
    """Forward kinematics that returns joint positions from a rotation parameterisation."""

    def __init__(self, parents, offset):
        """
        @param parents: parent index of every joint; the root is marked with -1 and must be
            joint 0, and parents must precede their children
        @param offset: default bone offsets, (batch_size, Joint_num, 3) or (Joint_num, 3)
        """
        self.parents = parents
        self.offset = offset

    def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None, world=True):
        """
        Compute joint positions by forward kinematics.

        @param rotation: (..., Joint_num, C) local rotations, presumably laid out as
            (batch_size, Time, Joint_num, C) or (Time, Joint_num, C) -- confirm against
            callers. C selects the parameterisation:
            - 6: 6D representation (repr6d2mat)
            - 4: quaternion, normalised here before quat2mat
            - 3: Euler angles as understood by euler2mat
        @param position: root positions written as joint 0, with the same leading axes as
            rotation, e.g. (batch_size, Time, 3)
        @param offset: (batch_size, Joint_num, 3), (Joint_num, 3), or None for the
            constructor's offsets; broadcast against the leading axes of rotation
        @param world: if True, each joint position is accumulated onto its parent's
            (global positions); if False, joint i holds its parent-rotated offset only
        @return: (..., Joint_num, 3) joint positions with the leading axes of rotation
        @raise ValueError: if the last dimension of rotation is not 6, 4 or 3
        """
        if rotation.shape[-1] == 6:
            transform = repr6d2mat(rotation)
        elif rotation.shape[-1] == 4:
            norm = torch.norm(rotation, dim=-1, keepdim=True)
            rotation = rotation / norm
            transform = quat2mat(rotation)
        elif rotation.shape[-1] == 3:
            transform = euler2mat(rotation)
        else:
            raise ValueError('Expected rotation with last dimension 6 (6D), 4 (quaternion) '
                             'or 3 (Euler), got %d' % rotation.shape[-1])

        # Match the rotation dtype instead of torch's default float32.
        result = torch.empty(transform.shape[:-2] + (3,), dtype=transform.dtype, device=position.device)
        if offset is None:
            offset = self.offset
        # Insert a length-1 axis to broadcast over the axis after the batch, and a trailing
        # axis so each offset is a column vector for matmul.
        offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))

        result[..., 0, :] = position
        for i, pi in enumerate(self.parents):
            if pi == -1:
                assert i == 0
                continue

            # squeeze(-1) drops only the column axis; a bare squeeze() would also drop
            # size-1 batch/time axes and break this assignment (e.g. when Time == 1).
            result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze(-1)
            # The clones keep the operands autograd saves for this matmul from being views
            # of `transform`, which is written in place here.
            transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(),
                                                   transform[..., i, :, :].clone())
            if world:
                result[..., i, :] += result[..., pi, :]
        return result


class InverseKinematicsJoint:
    """Fit rotations and root positions so forward kinematics matches target joint positions.

    Both the rotations and the root positions are optimised with Adam against an MSE loss
    between the full FK output and `constrains`.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
        """
        @param rotations: initial local rotations (copied; see ForwardKinematicsJoint.forward)
        @param positions: initial root positions (copied)
        @param offset: bone offsets used by the internal ForwardKinematicsJoint
        @param parents: joint parent list
        @param constrains: target joint positions compared against the full FK output
        """
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position = positions.detach().clone()
        self.position.requires_grad_(True)
        self.parents = parents
        self.offset = offset
        self.constrains = constrains
        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()
        self.fk = ForwardKinematicsJoint(parents, offset)
        # FK output of the most recent step (computed before that step's update).
        self.glb = None

    def step(self):
        """Run one optimisation step and return the loss value as a Python float."""
        self.optimizer.zero_grad()
        glb = self.fk.forward(self.rotations, self.position)
        loss = self.criteria(glb, self.constrains)
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()


class InverseKinematicsJoint2:
    """IK on a subset of joints with regularisation towards the initial pose.

    The loss is:
        MSE(FK(...)[:, cid], constrains)
        + lambda_rec_rot * MSE(rotations, initial rotations)
        + lambda_rec_pos * MSE(root positions, initial root positions)
    With use_velo=True the root trajectory is optimised as per-step differences along dim 0
    (presumably the time axis) and reconstructed with a cumulative sum.
    """

    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
                 lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
        """
        @param rotations: initial local rotations (copied)
        @param positions: initial root positions (copied); dim 0 is differenced if use_velo
        @param offset: bone offsets used by the internal ForwardKinematicsJoint
        @param parents: joint parent list
        @param constrains: target positions for the joints selected by cid (copied)
        @param cid: index of the constrained joints, applied as glb[:, cid]
        @param lambda_rec_rot: weight of the rotation reconstruction loss
        @param lambda_rec_pos: weight of the root position reconstruction loss
        @param use_velo: optimise root positions as differences along dim 0
        """
        self.use_velo = use_velo
        self.rotations_ori = rotations.detach().clone()
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position_ori = positions.detach().clone()
        self.position = positions.detach().clone()
        if self.use_velo:
            # Keep the first entry absolute and store the rest as consecutive differences,
            # so torch.cumsum(self.position, dim=0) reconstructs the absolute positions.
            self.position[1:] = self.position[1:] - self.position[:-1]
        self.position.requires_grad_(True)
        self.parents = parents
        self.offset = offset
        self.constrains = constrains.detach().clone()
        self.cid = cid
        self.lambda_rec_rot = lambda_rec_rot
        self.lambda_rec_pos = lambda_rec_pos
        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()
        self.fk = ForwardKinematicsJoint(parents, offset)
        # FK output of the most recent step (computed before that step's update).
        self.glb = None

    def _absolute_position(self, position):
        """Convert the optimised root parameter to absolute positions."""
        if self.use_velo:
            return torch.cumsum(position, dim=0)
        return position

    def step(self):
        """Run one optimisation step and return the total loss as a Python float.

        The individual terms are kept in constrain_loss, rec_loss_rot and rec_loss_pos.
        """
        self.optimizer.zero_grad()
        position = self._absolute_position(self.position)
        glb = self.fk.forward(self.rotations, position)
        self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
        self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
        # Compare absolute positions: with use_velo, self.position holds differences, which
        # must not be compared against the absolute position_ori.
        self.rec_loss_pos = self.criteria(position, self.position_ori)
        loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()

    def get_position(self):
        """Return the current absolute root positions, detached from the autograd graph."""
        return self._absolute_position(self.position.detach())