Spaces:
Running
on
L40S
Running
on
L40S
from typing import Union | |
import numpy | |
import torch | |
from detrsmpl.core.conventions.joints_mapping.standard_joint_angles import ( | |
TRANSFORMATION_AA_TO_SJA, | |
TRANSFORMATION_SJA_TO_AA, | |
) | |
from .logger import get_root_logger | |
try: | |
from pytorch3d.transforms import ( | |
axis_angle_to_matrix, | |
axis_angle_to_quaternion, | |
euler_angles_to_matrix, | |
matrix_to_euler_angles, | |
matrix_to_quaternion, | |
matrix_to_rotation_6d, | |
quaternion_to_axis_angle, | |
quaternion_to_matrix, | |
rotation_6d_to_matrix, | |
) | |
except (ImportError, ModuleNotFoundError): | |
import traceback | |
logger = get_root_logger() | |
stack_str = '' | |
for line in traceback.format_stack(): | |
if 'frozen' not in line: | |
stack_str += line + '\n' | |
import_exception = traceback.format_exc() + '\n' | |
warning_msg = stack_str + import_exception + \ | |
'If pytorch3d is not required,' +\ | |
' this warning could be ignored.' | |
logger.warning(warning_msg) | |
class Compose:
    """Composes several rotation transforms together.

    Handles the numpy <-> torch round trip around the chain of transforms.
    This transform does not support torchscript.
    """

    def __init__(self, transforms: list):
        """Composes several transforms together. This transform does not
        support torchscript.

        Args:
            transforms (list): transform callables applied in order; each
                receives the output of the previous one.
        """
        self.transforms = transforms

    @staticmethod
    def _accepts_convention(transform) -> bool:
        """Return True if ``transform`` has a positional ``convention`` arg."""
        code = getattr(transform, '__code__', None)
        if code is None:
            # e.g. functools.partial or builtins: nothing to introspect.
            return False
        # co_varnames lists the parameters first and then local variables;
        # only the positional-parameter prefix says what the callable takes.
        return 'convention' in code.co_varnames[:code.co_argcount]

    def __call__(self,
                 rotation: Union[torch.Tensor, numpy.ndarray],
                 convention: str = 'xyz',
                 **kwargs):
        """Apply every transform in order.

        Args:
            rotation (Union[torch.Tensor, numpy.ndarray]): input rotations.
                A NumPy array is converted to a float32 tensor before the
                transforms run, and the result is converted back to NumPy.
            convention (str, optional): euler convention; must use each of
                'x', 'y', 'z' exactly once (case-insensitive). It is passed
                upper-cased as the second positional argument to transforms
                declaring a ``convention`` parameter. Defaults to 'xyz'.
            **kwargs: forwarded to every transform.

        Returns:
            Union[torch.Tensor, numpy.ndarray]: same container type as the
                input.

        Raises:
            ValueError: if ``convention`` is not a permutation of 'xyz'.
            TypeError: if ``rotation`` is neither a tensor nor an ndarray.
        """
        convention = convention.lower()
        if not (set(convention) == set('xyz') and len(convention) == 3):
            raise ValueError(f'Invalid convention {convention}.')
        if isinstance(rotation, numpy.ndarray):
            data_type = 'numpy'
            rotation = torch.FloatTensor(rotation)
        elif isinstance(rotation, torch.Tensor):
            data_type = 'tensor'
        else:
            raise TypeError(
                'Type of rotation should be torch.Tensor or numpy.ndarray')
        for t in self.transforms:
            if self._accepts_convention(t):
                rotation = t(rotation, convention.upper(), **kwargs)
            else:
                rotation = t(rotation, **kwargs)
        if data_type == 'numpy':
            rotation = rotation.detach().cpu().numpy()
        return rotation
def aa_to_rotmat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation matrices.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix])
    return t(axis_angle)
def aa_to_quat(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to quaternions.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis angles {axis_angle.shape}.')
    t = Compose([axis_angle_to_quaternion])
    return t(axis_angle)
def ee_to_rotmat(euler_angle: Union[torch.Tensor, numpy.ndarray],
                 convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation matrices.

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input euler angles shape {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix])
    return t(euler_angle, convention.upper())
def rotmat_to_ee(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to euler angles.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not
            (3, 3), or ``convention`` is not a permutation of 'xyz'.
    """
    # Slicing (instead of indexing [-2]) also rejects 1-D input with a
    # ValueError rather than an IndexError.
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_euler_angles])
    return t(matrix, convention.upper())
def rotmat_to_quat(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to quaternions.

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not (3, 3).
    """
    # Slicing also rejects 1-D input with a ValueError, not an IndexError.
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion])
    return t(matrix)
def rotmat_to_rot6d(
    matrix: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to rotation 6d representations [1].

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not (3, 3).

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    # Slicing also rejects 1-D input with a ValueError, not an IndexError.
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_rotation_6d])
    return t(matrix)
def quat_to_aa(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to axis angles.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_axis_angle])
    return t(quaternions)
def quat_to_rotmat(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation matrices.

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(
            f'Invalid input quaternions shape {quaternions.shape}.')
    t = Compose([quaternion_to_matrix])
    return t(quaternions)
def rot6d_to_rotmat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations [1] to rotation matrices.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3, 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix])
    return t(rotation_6d)
def aa_to_ee(axis_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to euler angles (via rotation matrices).

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(
            f'Invalid input axis_angle shape {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_euler_angles])
    return t(axis_angle, convention)
def aa_to_rot6d(
    axis_angle: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis angles to rotation 6d representations [1].

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if axis_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input axis_angle {axis_angle.shape}.')
    t = Compose([axis_angle_to_matrix, matrix_to_rotation_6d])
    return t(axis_angle)
def ee_to_aa(euler_angle: Union[torch.Tensor, numpy.ndarray],
             convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to axis angles (via matrices and quaternions).

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([
        euler_angles_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(euler_angle, convention)
def ee_to_quat(euler_angle: Union[torch.Tensor, numpy.ndarray],
               convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to quaternions (via rotation matrices).

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_quaternion])
    return t(euler_angle, convention)
def ee_to_rot6d(euler_angle: Union[torch.Tensor, numpy.ndarray],
                convention='xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert euler angles to rotation 6d representations [1].

    Args:
        euler_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``euler_angle`` is not 3, or
            ``convention`` is not a permutation of 'xyz'.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if euler_angle.shape[-1] != 3:
        raise ValueError(f'Invalid input euler_angle {euler_angle.shape}.')
    t = Compose([euler_angles_to_matrix, matrix_to_rotation_6d])
    return t(euler_angle, convention)
def rotmat_to_aa(
        matrix: Union[torch.Tensor, numpy.ndarray],
        convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation matrices to axis angles (via quaternions).

    Args:
        matrix (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 3, 3). ndim of input is unlimited.
        convention (str, optional): Not used by this conversion (it is
            neither forwarded nor validated); kept only for backward
            compatibility of the signature. Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the trailing dimensions of ``matrix`` are not (3, 3).
    """
    # Slicing also rejects 1-D input with a ValueError, not an IndexError.
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    t = Compose([matrix_to_quaternion, quaternion_to_axis_angle])
    return t(matrix)
def quat_to_ee(quaternions: Union[torch.Tensor, numpy.ndarray],
               convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to euler angles (via rotation matrices).

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4, or
            ``convention`` is not a permutation of 'xyz'.
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_euler_angles])
    return t(quaternions, convention)
def quat_to_rot6d(
    quaternions: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert quaternions to rotation 6d representations [1].

    Args:
        quaternions (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 4). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 6).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``quaternions`` is not 4.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if quaternions.shape[-1] != 4:
        raise ValueError(f'Invalid input quaternions {quaternions.shape}.')
    t = Compose([quaternion_to_matrix, matrix_to_rotation_6d])
    return t(quaternions)
def rot6d_to_aa(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations [1] to axis angles.

    The conversion goes 6d -> matrix -> quaternion -> axis angle.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([
        rotation_6d_to_matrix, matrix_to_quaternion, quaternion_to_axis_angle
    ])
    return t(rotation_6d)
def rot6d_to_ee(rotation_6d: Union[torch.Tensor, numpy.ndarray],
                convention: str = 'xyz') -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations [1] to euler angles.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.
        convention (str, optional): Convention string of three letters
            using each of 'x', 'y', 'z' exactly once (case-insensitive).
            Defaults to 'xyz'.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 3).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6, or
            ``convention`` is not a permutation of 'xyz'.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(f'Invalid input rotation_6d {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_euler_angles])
    return t(rotation_6d, convention)
def rot6d_to_quat(
    rotation_6d: Union[torch.Tensor, numpy.ndarray]
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert rotation 6d representations [1] to quaternions.

    Args:
        rotation_6d (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 6). ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 4).
            A NumPy input is computed as a float32 tensor and returned as
            a NumPy array.

    Raises:
        ValueError: if the last dimension of ``rotation_6d`` is not 6.

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    if rotation_6d.shape[-1] != 6:
        raise ValueError(
            f'Invalid input rotation_6d shape {rotation_6d.shape}.')
    t = Compose([rotation_6d_to_matrix, matrix_to_quaternion])
    return t(rotation_6d)
def aa_to_sja(
    axis_angle: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert axis-angles to standard joint angles.

    Each joint rotation ``R_aa`` is re-expressed as ``R_t @ R_aa @ R_t_inv``
    and then decomposed into 'XYZ' euler angles.

    Args:
        axis_angle (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3), ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system,
            ndim of input is unlimited.
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system,
            ndim of input is unlimited.

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
            A NumPy ``axis_angle`` is computed as a float32 tensor and the
            result is returned as a NumPy array.

    Raises:
        ValueError: if ``axis_angle``, ``R_t`` or ``R_t_inv`` has the wrong
            trailing shape.
    """

    def _aa_to_sja(aa, R_t, R_t_inv):
        # Cast the transformation matrices to the input's dtype/device so
        # NumPy matrices (allowed above) and matrices living on a different
        # device or dtype than the input can be multiplied with it.
        R_t = torch.as_tensor(R_t, dtype=aa.dtype, device=aa.device)
        R_t_inv = torch.as_tensor(R_t_inv, dtype=aa.dtype, device=aa.device)
        R_aa = axis_angle_to_matrix(aa)
        R_sja = R_t @ R_aa @ R_t_inv
        sja = matrix_to_euler_angles(R_sja, convention='XYZ')
        return sja

    if axis_angle.shape[-2:] != (21, 3):
        raise ValueError(
            f'Invalid input axis angles shape {axis_angle.shape}.')
    if R_t.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if R_t_inv.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    t = Compose([_aa_to_sja])
    return t(axis_angle, R_t=R_t, R_t_inv=R_t_inv)
def sja_to_aa(
    sja: Union[torch.Tensor, numpy.ndarray],
    R_t: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_AA_TO_SJA,
    R_t_inv: Union[torch.Tensor, numpy.ndarray] = TRANSFORMATION_SJA_TO_AA
) -> Union[torch.Tensor, numpy.ndarray]:
    """Convert standard joint angles to axis angles.

    Each joint's 'XYZ' euler rotation ``R_sja`` is mapped back as
    ``R_t_inv @ R_sja @ R_t`` and converted to an axis angle via a
    quaternion.

    Args:
        sja (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3). ndim of input is unlimited.
        R_t (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            original axis-angle coordinate system to
            standard joint angle coordinate system
        R_t_inv (Union[torch.Tensor, numpy.ndarray]): input shape
            should be (..., 21, 3, 3). Transformation matrices from
            standard joint angle coordinate system to
            original axis-angle coordinate system

    Returns:
        Union[torch.Tensor, numpy.ndarray]: shape would be (..., 21, 3).
            A NumPy ``sja`` is computed as a float32 tensor and the result
            is returned as a NumPy array.

    Raises:
        ValueError: if ``sja``, ``R_t`` or ``R_t_inv`` has the wrong
            trailing shape.
    """

    def _sja_to_aa(sja, R_t, R_t_inv):
        # Cast the transformation matrices to the input's dtype/device so
        # NumPy matrices (allowed above) and matrices living on a different
        # device or dtype than the input can be multiplied with it.
        R_t = torch.as_tensor(R_t, dtype=sja.dtype, device=sja.device)
        R_t_inv = torch.as_tensor(R_t_inv, dtype=sja.dtype, device=sja.device)
        R_sja = euler_angles_to_matrix(sja, convention='XYZ')
        R_aa = R_t_inv @ R_sja @ R_t
        aa = quaternion_to_axis_angle(matrix_to_quaternion(R_aa))
        return aa

    if sja.shape[-2:] != (21, 3):
        raise ValueError(
            f'Invalid input standard joint angles shape {sja.shape}.')
    if R_t.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t shape {R_t.shape}.')
    if R_t_inv.shape[-3:] != (21, 3, 3):
        raise ValueError(f'Invalid input R_t_inv shape {R_t_inv.shape}.')
    t = Compose([_sja_to_aa])
    return t(sja, R_t=R_t, R_t_inv=R_t_inv)
def make_homegeneous_rotmat_batch(input: torch.Tensor) -> torch.Tensor:
    """Appends a row of [0,0,0,1] to a batch size x 3 x 4 Tensor.

    Parameters
    ----------
    :param input: A tensor of dimensions batch size x 3 x 4
    :return: A tensor batch size x 4 x 4 (appended with 0,0,0,1), with the
        same dtype and device as ``input``.
    """
    batch_size = input.shape[0]
    # Build the padding row on the input's device and dtype; a fixed CPU
    # float32 row makes torch.cat fail for CUDA inputs.
    # new_tensor returns a tensor with requires_grad=False by default.
    row_append = input.new_tensor([0.0, 0.0, 0.0, 1.0])
    padded_tensor = torch.cat(
        [input, row_append.view(1, 1, 4).expand(batch_size, 1, 4)], dim=1)
    return padded_tensor
def make_homegeneous_rotmat(input: torch.Tensor) -> torch.Tensor:
    """Appends a row of [0,0,0,1] to a 3 x 4 Tensor.

    Parameters
    ----------
    :param input: A tensor of dimensions 3 x 4
    :return: A tensor of dimensions 4 x 4 (appended with 0,0,0,1), with the
        same dtype and device as ``input``.
    """
    # new_tensor matches the input's dtype/device; requires_grad is False.
    row_append = input.new_tensor([0.0, 0.0, 0.0, 1.0])
    # torch.cat takes a sequence of tensors: stack the 1 x 4 row under the
    # 3 x 4 input along dim 0 to obtain 4 x 4.
    padded_tensor = torch.cat([input, row_append.view(1, 4)], dim=0)
    return padded_tensor