# AiOS / detrsmpl/utils/transforms.py
# ttxskk
# update
# d7e58f0
# raw
# history blame
# 22.3 kB
import inspect
from typing import Union

import numpy
import torch

from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import (
    TRANSFORMATION_AA_TO_SJA,
    TRANSFORMATION_SJA_TO_AA,
)

from .logger import get_root_logger
# pytorch3d is optional: if it cannot be imported, log the call stack and
# the import error instead of failing at import time. Conversion helpers in
# this module that rely on the missing names will then raise NameError when
# they reach those names.
try:
    from pytorch3d.transforms import (
        axis_angle_to_matrix,
        axis_angle_to_quaternion,
        euler_angles_to_matrix,
        matrix_to_euler_angles,
        matrix_to_quaternion,
        matrix_to_rotation_6d,
        quaternion_to_axis_angle,
        quaternion_to_matrix,
        rotation_6d_to_matrix,
    )
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    import traceback
    logger = get_root_logger()
    # Drop importlib's "<frozen ...>" frames, they only add noise.
    stack_str = ''.join(frame + '\n'
                        for frame in traceback.format_stack()
                        if 'frozen' not in frame)
    import_exception = traceback.format_exc() + '\n'
    warning_msg = (stack_str + import_exception +
                   'If pytorch3d is not required,' +
                   ' this warning could be ignored.')
    logger.warning(warning_msg)
class Compose:
    """Chain rotation-conversion callables and apply them left to right.

    Accepts a ``torch.Tensor`` or a ``numpy.ndarray``. NumPy input is
    converted to a float32 tensor (``torch.FloatTensor``) before the
    transforms run, and the result is converted back to NumPy with
    ``.detach().cpu().numpy()``. Tensor input is passed through unchanged.
    This transform does not support torchscript.
    """

    def __init__(self, transforms: list):
        """Composes several transforms together.

        Args:
            transforms (list): Callables applied in order. Each one receives
                the current rotation as its first positional argument.
                Callables that declare a parameter named ``convention`` also
                receive the upper-cased convention string as their second
                positional argument.
        """
        self.transforms = transforms

    @staticmethod
    def _accepts_convention(fn) -> bool:
        """Return True if callable ``fn`` declares a ``convention`` param."""
        try:
            return 'convention' in inspect.signature(fn).parameters
        except (TypeError, ValueError):
            # No introspectable signature (e.g. some C builtins): fall back
            # to the code object's *argument* names, if it has one. Only the
            # leading argcount entries of co_varnames are parameters; the
            # rest are local variables.
            code = getattr(fn, '__code__', None)
            if code is None:
                return False
            n_args = code.co_argcount + code.co_kwonlyargcount
            return 'convention' in code.co_varnames[:n_args]

    def __call__(self,
                 rotation: Union[torch.Tensor, numpy.ndarray],
                 convention: str = 'xyz',
                 **kwargs) -> Union[torch.Tensor, numpy.ndarray]:
        """Apply every transform to ``rotation`` in sequence.

        Args:
            rotation (Union[torch.Tensor, numpy.ndarray]): Input rotations.
            convention (str, optional): Permutation of the letters 'x', 'y',
                'z' (case-insensitive). Passed upper-cased to transforms that
                declare a ``convention`` parameter. Defaults to 'xyz'.
            **kwargs: Forwarded unchanged to every transform.

        Returns:
            Union[torch.Tensor, numpy.ndarray]: The transformed rotations,
            returned as NumPy if the input was NumPy, otherwise as a tensor.

        Raises:
            ValueError: If ``convention`` is not a permutation of 'xyz'.
            TypeError: If ``rotation`` is neither a tensor nor an ndarray.
        """
        convention = convention.lower()
        # Validated even if no transform uses it, so bad values fail early.
        if not (set(convention) == set('xyz') and len(convention) == 3):
            raise ValueError(f'Invalid convention {convention}.')
        if isinstance(rotation, numpy.ndarray):
            data_type = 'numpy'
            rotation = torch.FloatTensor(rotation)
        elif isinstance(rotation, torch.Tensor):
            data_type = 'tensor'
        else:
            raise TypeError(
                'Type of rotation should be torch.Tensor or numpy.ndarray')
        for t in self.transforms:
            if self._accepts_convention(t):
                rotation = t(rotation, convention.upper(), **kwargs)
            else:
                rotation = t(rotation, **kwargs)
        if data_type == 'numpy':
            rotation = rotation.detach().cpu().numpy()
        return rotation
def aa_to_rotmat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
        NumPy input yields NumPy output (computed in float32).

    Raises:
        ValueError: If the last dimension of ``axis_angle`` is not 3.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix])
    return t(axis_angle)
def aa_to_quat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to quaternions.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4), with
        components in pytorch3d's quaternion order (real part first, per
        the pytorch3d documentation).

    Raises:
        ValueError: If the last dimension of ``axis_angle`` is not 3.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis angles {axis_angle.shape}.')
    t = Compose([axis_angle_to_quaternion])
    return t(axis_angle)
def ee_to_rotmat(euler_angle: Union[torch.Tensor, numpy.ndarray],
                 convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).

    Raises:
        ValueError: If the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input euler angles shape {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix])
    return t(euler_angle, convention.upper())
def rotmat_to_ee(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to euler angles.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3),
            or ``convention`` is not a permutation of 'xyz'.
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_euler_angles])
    return t(matrix, convention.upper())
def rotmat_to_quat(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to quaternions.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4), in
        pytorch3d's quaternion component order.

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion])
    return t(matrix)
def rotmat_to_rot6d(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to rotation 6d representations.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_rotation_6d])
    return t(matrix)
def quat_to_aa(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to axis angles.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4), in pytorch3d's quaternion component order.
            ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_axis_angle])
    return t(quaternions)
def quat_to_rotmat(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4), in pytorch3d's quaternion component order.
            ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(
            f'Invalid input quaternions shape {quaternions.shape}.')
    t = Compose([quaternion_to_matrix])
    return t(quaternions)
def rot6d_to_rotmat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).

    Raises:
        ValueError: If the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix])
    return t(rotation_6d)
def aa_to_ee(axis_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to euler angles (via rotation matrices).

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``axis_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis_angle shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_euler_angles])
    return t(axis_angle, convention)
def aa_to_rot6d(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    Raises:
        ValueError: If the last dimension of ``axis_angle`` is not 3.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis_angle {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_rotation_6d])
    return t(axis_angle)
def ee_to_aa(euler_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to axis angles.

    The conversion goes euler -> rotation matrix -> quaternion -> axis angle.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([
        euler_angles_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(euler_angle, convention)
def ee_to_quat(euler_angle: Union[torch.Tensor, numpy.ndarray],
               convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to quaternions (via rotation matrices).

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4), in
        pytorch3d's quaternion component order.

    Raises:
        ValueError: If the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_quaternion])
    return t(euler_angle, convention)
def ee_to_rot6d(euler_angle: Union[torch.Tensor, numpy.ndarray],
                convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    Raises:
        ValueError: If the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_rotation_6d])
    return t(euler_angle, convention)
def rotmat_to_aa(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to axis angles (via quaternions).

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Not used by this conversion; kept only
            so existing callers that pass it keep working. Defaults to
            'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last two dimensions of ``matrix`` are not (3, 3).
    """
    if matrix.shape[-1] != 3 or matrix.shape[-2] != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion, quaternion_to_axis_angle])
    return t(matrix)
def quat_to_ee(quaternions: Union[torch.Tensor, numpy.ndarray],
               convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to euler angles (via rotation matrices).

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4), in pytorch3d's quaternion component order.
            ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_euler_angles])
    return t(quaternions, convention)
def quat_to_rot6d(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation 6d representations.

    The conversion goes through rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4), in pytorch3d's quaternion component order.
            ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).

    Raises:
        ValueError: If the last dimension of ``quaternions`` is not 4.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_rotation_6d])
    return t(quaternions)
def rot6d_to_aa(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to axis angles.

    The conversion goes rot6d -> rotation matrix -> quaternion -> axis angle.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([
        rotation_6d_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(rotation_6d)
def rot6d_to_ee(rotation_6d: Union[torch.Tensor, numpy.ndarray],
                convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to euler angles.

    The conversion goes through rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            from {'x', 'y', 'z'} (case-insensitive). Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).

    Raises:
        ValueError: If the last dimension of ``rotation_6d`` is not 6, or
            ``convention`` is not a permutation of 'xyz'.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_euler_angles])
    return t(rotation_6d, convention)
def rot6d_to_quat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations to quaternions.

    The conversion goes through rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4), in
        pytorch3d's quaternion component order.

    Raises:
        ValueError: If the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(
            f'Invalid input rotation_6d shape {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_quaternion])
    return t(rotation_6d)
def aa_to_sja(
    axis_angle: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis-angles to standard joint angles.

    Each joint's axis-angle is turned into a rotation matrix ``R_aa``,
    re-expressed as ``R_t @ R_aa @ R_t_inv`` and then converted to euler
    angles with the 'XYZ' convention.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3), ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system,
            ndim of input is unlimited.
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system,
            ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).

    Raises:
        ValueError: If any input does not have the expected trailing shape.
    """
    def _aa_to_sja(aa, R_t, R_t_inv):
        R_aa = axis_angle_to_matrix(aa)
        R_sja = R_t @ R_aa @ R_t_inv
        sja = matrix_to_euler_angles(R_sja, convention='XYZ')
        return sja

    if axis_angle.shape[-2:] != (21, 3):
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    if R_t.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if R_t_inv.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    # Compose forwards R_t / R_t_inv as keyword arguments to _aa_to_sja.
    t = Compose([_aa_to_sja])
    return t(axis_angle, R_t=R_t, R_t_inv=R_t_inv)
def sja_to_aa(
    sja: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert standard joint angles to axis angles.

    Inverse of ``aa_to_sja``. 'XYZ' euler angles are turned into a rotation
    matrix ``R_sja`` and re-expressed as ``R_t_inv @ R_sja @ R_t``. The result
    is converted to axis angles through quaternions.

    Args:
        sja (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3). ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).

    Raises:
        ValueError: If any input does not have the expected trailing shape.
    """
    def _sja_to_aa(sja, R_t, R_t_inv):
        R_sja = euler_angles_to_matrix(sja, convention='XYZ')
        R_aa = R_t_inv @ R_sja @ R_t
        aa = quaternion_to_axis_angle(matrix_to_quaternion(R_aa))
        return aa

    if sja.shape[-2:] != (21, 3):
        raise ValueError(
            f'Invalid input standard joint angles shape {sja.shape}.')
    if R_t.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if R_t_inv.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    # Compose forwards R_t / R_t_inv as keyword arguments to _sja_to_aa.
    t = Compose([_sja_to_aa])
    return t(sja, R_t=R_t, R_t_inv=R_t_inv)
def make_homegeneous_rotmat_batch(input: torch.Tensor) -> torch.Tensor:
    """Append a ``[0, 0, 0, 1]`` row to a batch of 3x4 matrices.

    Args:
        input (torch.Tensor): Tensor of shape (batch_size, 3, 4).

    Returns:
        torch.Tensor: Tensor of shape (batch_size, 4, 4). The last row of
        every matrix is ``[0, 0, 0, 1]``.
    """
    batch_size = input.shape[0]
    # Create the padding row on the input's device so that the concatenation
    # also works for non-CPU tensors. The row stays float32 as before, and a
    # new tensor does not require grad by default.
    row_append = torch.tensor([0.0, 0.0, 0.0, 1.0],
                              dtype=torch.float,
                              device=input.device)
    padded_tensor = torch.cat(
        [input, row_append.view(1, 1, 4).repeat(batch_size, 1, 1)], dim=1)
    return padded_tensor
def make_homegeneous_rotmat(input: torch.Tensor) -> torch.Tensor:
    """Append a ``[0, 0, 0, 1]`` row to a 3x4 matrix.

    Leading batch dimensions are also accepted: input of shape (..., 3, 4)
    yields output of shape (..., 4, 4).

    Args:
        input (torch.Tensor): Tensor of shape (3, 4) or (..., 3, 4).

    Returns:
        torch.Tensor: Tensor of shape (4, 4) (or (..., 4, 4)) whose last row
        is ``[0, 0, 0, 1]``.

    Raises:
        ValueError: If the trailing two dimensions of ``input`` are not
            (3, 4).
    """
    if tuple(input.shape[-2:]) != (3, 4):
        raise ValueError(f'Invalid input shape {tuple(input.shape)}.')
    row_append = torch.tensor([0.0, 0.0, 0.0, 1.0],
                              dtype=torch.float,
                              device=input.device)
    # Broadcast the row to (..., 1, 4) and stack it under the existing rows;
    # torch.cat takes a sequence of tensors plus the axis to join along.
    row_append = row_append.expand(*input.shape[:-2], 1, 4)
    padded_tensor = torch.cat([input, row_append], dim=-2)
    return padded_tensor