MotionLCM / mld/utils/temos_utils.py
from typing import Optional

import torch


def lengths_to_mask(lengths: list[int],
                    device: torch.device,
                    max_len: Optional[int] = None) -> torch.Tensor:
    # Build a boolean padding mask of shape (batch, max_len):
    # entry (i, j) is True iff j < lengths[i].
    lengths = torch.tensor(lengths, device=device)
    max_len = max_len if max_len else max(lengths)
    mask = torch.arange(max_len, device=device).expand(
        len(lengths), max_len) < lengths.unsqueeze(1)
    return mask
def remove_padding(tensors: torch.Tensor,
                   lengths: list[int]) -> list[torch.Tensor]:
    # Trim each padded sequence in the batch back to its true length.
    return [
        tensor[:tensor_length]
        for tensor, tensor_length in zip(tensors, lengths)
    ]
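

# Minimal usage sketch (not part of the original file): it assumes a
# zero-padded batch of feature sequences; the shapes and names below
# are illustrative only.
if __name__ == "__main__":
    device = torch.device("cpu")
    lengths = [4, 2, 3]
    # Batch of 3 sequences, padded to length 4, with 8 features each.
    batch = torch.randn(len(lengths), max(lengths), 8, device=device)

    mask = lengths_to_mask(lengths, device)      # bool tensor, shape (3, 4)
    batch = batch * mask.unsqueeze(-1)           # zero out the padded positions

    sequences = remove_padding(batch, lengths)   # shapes (4, 8), (2, 8), (3, 8)
    print(mask)
    print([seq.shape for seq in sequences])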