import torch

from utils.transforms import quat2mat, repr6d2mat, euler2mat

class ForwardKinematics:
    def __init__(self, parents, offsets=None):
        self.parents = parents
        if offsets is not None and len(offsets.shape) == 2:
            offsets = offsets.unsqueeze(0)  # add a batch dimension: (1, bone_num, 3)
        self.offsets = offsets
    def forward(self, rots, offsets=None, global_pos=None):
        """
        Forward kinematics: returns a per-bone world transformation.
        @param rots: local joint rotations (batch_size, bone_num, 3, 3)
        @param offsets: bone offsets (batch_size, bone_num, 3), or None to use the offsets given at construction
        @param global_pos: global root position (batch_size, 3); defaults to the root entry of offsets
        @return: (batch_size, bone_num, 3, 4)
        """
        rots = rots.clone()
        if offsets is None:
            offsets = self.offsets.to(rots.device)
        if global_pos is None:
            global_pos = offsets[:, 0]

        pos = torch.zeros((rots.shape[0], rots.shape[1], 3), device=rots.device)
        rest_pos = torch.zeros_like(pos)
        res = torch.zeros((rots.shape[0], rots.shape[1], 3, 4), device=rots.device)

        pos[:, 0] = global_pos
        rest_pos[:, 0] = offsets[:, 0]
        for i, p in enumerate(self.parents):
            if i != 0:
                # Accumulate the parent's rotation and position down the kinematic chain
                rots[:, i] = torch.matmul(rots[:, p], rots[:, i])
                pos[:, i] = torch.matmul(rots[:, p], offsets[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, p]
                rest_pos[:, i] = rest_pos[:, p] + offsets[:, i]

            res[:, i, :3, :3] = rots[:, i]
            # Translation column maps a rest-pose point to its posed position
            res[:, i, :, 3] = torch.matmul(rots[:, i], -rest_pos[:, i].unsqueeze(-1)).squeeze(-1) + pos[:, i]
        return res
    def accumulate(self, local_rots):
        """
        Get global joint rotations from local rotations
        @param local_rots: (batch_size, n_bone, 3, 3)
        @return: global rotations (batch_size, n_bone, 3, 3)
        """
        res = torch.empty_like(local_rots)
        for i, p in enumerate(self.parents):
            if i == 0:
                res[:, i] = local_rots[:, i]
            else:
                res[:, i] = torch.matmul(res[:, p], local_rots[:, i])
        return res
    def unaccumulate(self, global_rots):
        """
        Get local joint rotations from global rotations
        @param global_rots: (batch_size, n_bone, 3, 3)
        @return: local rotations (batch_size, n_bone, 3, 3)
        """
        res = torch.empty_like(global_rots)
        inv = torch.empty_like(global_rots)  # inverses of the global rotations
        for i, p in enumerate(self.parents):
            if i == 0:
                inv[:, i] = global_rots[:, i].transpose(-2, -1)
                res[:, i] = global_rots[:, i]
                continue
            # local_i = global_parent^{-1} @ global_i; a rotation's inverse is its transpose
            res[:, i] = torch.matmul(inv[:, p], global_rots[:, i])
            inv[:, i] = torch.matmul(res[:, i].transpose(-2, -1), inv[:, p])
        return res
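
# Usage sketch (hypothetical example, not part of the original code): a toy
# three-joint chain rooted at joint 0. With identity local rotations, forward()
# reproduces the rest pose (rotation I, zero translation column), and
# unaccumulate(accumulate(rots)) round-trips back to the local rotations.
#
#     parents = [-1, 0, 1]
#     offsets = torch.tensor([[0., 0., 0.],
#                             [0., 1., 0.],
#                             [0., 1., 0.]])
#     fk = ForwardKinematics(parents, offsets)
#     rots = torch.eye(3).expand(1, 3, 3, 3)           # (batch, bone, 3, 3)
#     trans = fk.forward(rots)                         # (1, 3, 3, 4)
#     rots_rec = fk.unaccumulate(fk.accumulate(rots))  # equals rots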

class ForwardKinematicsJoint:
    def __init__(self, parents, offset):
        self.parents = parents
        self.offset = offset

    def forward(self, rotation: torch.Tensor, position: torch.Tensor, offset=None,
                world=True):
        '''
        rotation should have shape batch_size * Time * Joint_num * (3/4/6)
        position should have shape batch_size * Time * 3
        offset should have shape batch_size * Joint_num * 3
        output has shape batch_size * Time * Joint_num * 3
        '''
        if rotation.shape[-1] == 6:
            transform = repr6d2mat(rotation)
        elif rotation.shape[-1] == 4:
            norm = torch.norm(rotation, dim=-1, keepdim=True)
            rotation = rotation / norm  # normalize quaternions before converting
            transform = quat2mat(rotation)
        elif rotation.shape[-1] == 3:
            transform = euler2mat(rotation)
        else:
            raise Exception('Only Euler angle, quaternion and 6D rotation inputs are accepted')
        result = torch.empty(transform.shape[:-2] + (3,), device=position.device)

        if offset is None:
            offset = self.offset
        offset = offset.reshape((-1, 1, offset.shape[-2], offset.shape[-1], 1))

        result[..., 0, :] = position
        for i, pi in enumerate(self.parents):
            if pi == -1:
                assert i == 0
                continue

            # Rotate the bone offset by the parent's accumulated rotation
            result[..., i, :] = torch.matmul(transform[..., pi, :, :], offset[..., i, :, :]).squeeze(-1)
            transform[..., i, :, :] = torch.matmul(transform[..., pi, :, :].clone(), transform[..., i, :, :].clone())
            if world:
                result[..., i, :] += result[..., pi, :]
        return result
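
# Usage sketch (hypothetical example; assumes quat2mat expects w-first unit
# quaternions): one frame of a two-joint chain driven by identity rotations.
#
#     parents = [-1, 0]
#     offset = torch.tensor([[[0., 0., 0.],
#                             [0., 1., 0.]]])           # (1, 2, 3)
#     fk = ForwardKinematicsJoint(parents, offset)
#     rotation = torch.tensor([[[[1., 0., 0., 0.],
#                                [1., 0., 0., 0.]]]])   # (1, 1, 2, 4)
#     position = torch.zeros(1, 1, 3)                   # root trajectory
#     joints = fk.forward(rotation, position)           # (1, 1, 2, 3)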

class InverseKinematicsJoint:
    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains):
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position = positions.detach().clone()
        self.position.requires_grad_(True)

        self.parents = parents
        self.offset = offset
        self.constrains = constrains

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()

        self.fk = ForwardKinematicsJoint(parents, offset)

        self.glb = None

    def step(self):
        # One gradient step of optimization-based IK against the position constraints
        self.optimizer.zero_grad()
        glb = self.fk.forward(self.rotations, self.position)
        loss = self.criteria(glb, self.constrains)
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()
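
# Usage sketch (hypothetical example): fit rotations and root positions to
# target joint positions by repeatedly stepping the optimizer.
#
#     ik = InverseKinematicsJoint(init_rot, init_pos, offset, parents, target_pos)
#     for _ in range(500):  # iteration count is a placeholder
#         loss = ik.step()
#     posed = ik.glb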

class InverseKinematicsJoint2:
    def __init__(self, rotations: torch.Tensor, positions: torch.Tensor, offset, parents, constrains, cid,
                 lambda_rec_rot=1., lambda_rec_pos=1., use_velo=False):
        self.use_velo = use_velo
        self.rotations_ori = rotations.detach().clone()
        self.rotations = rotations.detach().clone()
        self.rotations.requires_grad_(True)
        self.position = positions.detach().clone()
        if self.use_velo:
            # Optimize per-frame root velocities instead of absolute positions
            self.position[1:] = self.position[1:] - self.position[:-1]
        # Store the reference in the same parameterization (velocity or position)
        # so the reconstruction loss compares like with like
        self.position_ori = self.position.detach().clone()
        self.position.requires_grad_(True)

        self.parents = parents
        self.offset = offset
        self.constrains = constrains.detach().clone()
        self.cid = cid

        self.lambda_rec_rot = lambda_rec_rot
        self.lambda_rec_pos = lambda_rec_pos

        self.optimizer = torch.optim.Adam([self.position, self.rotations], lr=1e-3, betas=(0.9, 0.999))
        self.criteria = torch.nn.MSELoss()

        self.fk = ForwardKinematicsJoint(parents, offset)

        self.glb = None
    def step(self):
        self.optimizer.zero_grad()
        if self.use_velo:
            # Integrate velocities back into absolute positions
            position = torch.cumsum(self.position, dim=0)
        else:
            position = self.position
        glb = self.fk.forward(self.rotations, position)
        self.constrain_loss = self.criteria(glb[:, self.cid], self.constrains)
        self.rec_loss_rot = self.criteria(self.rotations, self.rotations_ori)
        self.rec_loss_pos = self.criteria(self.position, self.position_ori)
        loss = self.constrain_loss + self.rec_loss_rot * self.lambda_rec_rot + self.rec_loss_pos * self.lambda_rec_pos
        loss.backward()
        self.optimizer.step()
        self.glb = glb
        return loss.item()
    def get_position(self):
        if self.use_velo:
            position = torch.cumsum(self.position.detach(), dim=0)
        else:
            position = self.position.detach()
        return position
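
# Usage sketch (hypothetical example): constrain only the joints listed in cid
# (e.g. end effectors) while regularizing toward the initial motion; read the
# optimized root trajectory back with get_position().
#
#     ik = InverseKinematicsJoint2(init_rot, init_pos, offset, parents,
#                                  target_pos, cid=[3, 7],  # joint indices are placeholders
#                                  lambda_rec_rot=1., lambda_rec_pos=1., use_velo=True)
#     for _ in range(500):
#         ik.step()
#     new_pos = ik.get_position()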