import torch


def lengths_to_mask(lengths, max_len):
    # Boolean mask of shape (batch_size, max_len): True at frame positions inside
    # each sequence's true length, False at padded positions.
    mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
    return mask
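
# Illustrative example (not part of the original module): with lengths [2, 3] and
# max_len 4,
#
#     >>> lengths_to_mask(torch.tensor([2, 3]), max_len=4)
#     tensor([[ True,  True, False, False],
#             [ True,  True,  True, False]])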


def collate_tensors(batch):
    # Zero-pad every tensor in the batch to the per-dimension maximum size and
    # stack them into one tensor of shape (len(batch), *max_size).
    dims = batch[0].dim()
    max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
    size = (len(batch),) + tuple(max_size)
    canvas = batch[0].new_zeros(size=size)
    for i, b in enumerate(batch):
        # Narrow a view of this sample's slot down to the sample's shape, then copy it in.
        sub_tensor = canvas[i]
        for d in range(dims):
            sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
        sub_tensor.add_(b)
    return canvas
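
# Illustrative example (not from the original file): tensors of shapes (3, 2) and
# (5, 2) are zero-padded and stacked into a batch of shape (2, 5, 2):
#
#     >>> collate_tensors([torch.ones(3, 2), torch.ones(5, 2)]).shape
#     torch.Size([2, 5, 2])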


def collate(batch):
    # Drop failed samples, pad the motions into one tensor and gather the
    # conditioning signals under cond['y'].
    notnone_batches = [b for b in batch if b is not None]
    databatch = [b['inp'] for b in notnone_batches]
    if 'lengths' in notnone_batches[0]:
        lenbatch = [b['lengths'] for b in notnone_batches]
    else:
        # Infer the length from the last axis of 'inp' (expects [n_feats, 1, n_frames]).
        lenbatch = [len(b['inp'][0][0]) for b in notnone_batches]

    databatchTensor = collate_tensors(databatch)
    lenbatchTensor = torch.as_tensor(lenbatch)
    # [bs, max_frames] -> [bs, 1, 1, max_frames] so the mask broadcasts over the motion tensor.
    maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1)

    motion = databatchTensor
    cond = {'y': {'mask': maskbatchTensor, 'lengths': lenbatchTensor}}

    if 'text' in notnone_batches[0]:
        textbatch = [b['text'] for b in notnone_batches]
        cond['y'].update({'text': textbatch})

    if 'tokens' in notnone_batches[0]:
        textbatch = [b['tokens'] for b in notnone_batches]
        cond['y'].update({'tokens': textbatch})

    if 'action' in notnone_batches[0]:
        actionbatch = [b['action'] for b in notnone_batches]
        cond['y'].update({'action': torch.as_tensor(actionbatch).unsqueeze(1)})

    if 'action_text' in notnone_batches[0]:
        action_text = [b['action_text'] for b in notnone_batches]
        cond['y'].update({'action_text': action_text})

    return motion, cond
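
# Typical usage sketch (an assumption, not from the original file): pass collate as
# the collate_fn of a torch.utils.data.DataLoader whose items are dicts with at
# least an 'inp' tensor of shape [n_feats, 1, n_frames], as built in t2m_collate below:
#
#     >>> loader = DataLoader(dataset, batch_size=32, collate_fn=collate)
#     >>> motion, cond = next(iter(loader))
#     >>> motion.shape              # [bs, n_feats, 1, max_frames]
#     >>> cond['y']['mask'].shape   # [bs, 1, 1, max_frames]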


def t2m_collate(batch):
    # Adapt T2M dataset tuples to the dict format expected by collate():
    # b[4] motion [seq_len, n_feats] -> 'inp' [n_feats, 1, seq_len],
    # b[2] caption, b[6] tokens, b[5] true length.
    adapted_batch = [{
        'inp': torch.tensor(b[4].T).float().unsqueeze(1),
        'text': b[2],
        'tokens': b[6],
        'lengths': b[5],
    } for b in batch]
    return collate(adapted_batch)
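

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the original module).
    # It assumes HumanML3D-style items in which index 2 holds the caption,
    # 4 the motion array [seq_len, n_feats], 5 the true length and 6 the tokens.
    import numpy as np

    def _fake_item(seq_len, n_feats=4):
        # Hypothetical stand-in for a dataset item, used only for this sketch.
        motion = np.random.randn(seq_len, n_feats).astype(np.float32)
        return (None, None, 'a person walks', None, motion, seq_len, 'a/DET_person/NOUN_walks/VERB')

    motion, cond = t2m_collate([_fake_item(50), _fake_item(60)])
    print(motion.shape)             # torch.Size([2, 4, 1, 60])
    print(cond['y']['mask'].shape)  # torch.Size([2, 1, 1, 60])
    print(cond['y']['lengths'])     # tensor([50, 60])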