AiOS / detrsmpl /utils /misc.py
ttxskk
update
d7e58f0
raw
history blame
869 Bytes
from functools import partial
import torch
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` element-wise over parallel argument lists.

    Note:
        ``func`` is called once per position across the iterables in
        ``args`` (like the built-in ``map``). Each call is expected to return
        a sequence of outputs. These per-call outputs are then transposed, so
        the i-th list in the result collects the i-th output of every call.

    Args:
        func (Function): Callable applied to each group of arguments.
        *args: Iterables supplying the positional arguments of ``func``;
            the k-th call receives the k-th element of each iterable.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            call of ``func``.

    Returns:
        tuple(list): One list per output of ``func``, each holding that
        output for every input. If ``func`` is never called (empty
        iterables), the result is an empty tuple.
    """
    if kwargs:
        # Bind the shared keyword arguments once instead of on every call.
        bound_func = partial(func, **kwargs)
    else:
        bound_func = func

    per_call_outputs = map(bound_func, *args)
    # zip(*...) turns "one tuple per call" into "one group per output slot".
    regrouped = zip(*per_call_outputs)
    return tuple(list(group) for group in regrouped)
def torch_to_numpy(x):
    """Convert a ``torch.Tensor`` to a NumPy array.

    The tensor is detached from the autograd graph and moved to host (CPU)
    memory before conversion. Tensors that require grad or live on an
    accelerator device are therefore accepted.

    Note:
        When ``x`` already lives on the CPU, neither ``detach`` nor ``cpu``
        copies data. The returned array may then share memory with ``x``.

    Args:
        x (torch.Tensor): Tensor to convert.

    Returns:
        numpy.ndarray: Array holding the values of ``x``.

    Raises:
        TypeError: If ``x`` is not a ``torch.Tensor``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let non-tensors through to fail later with
    # an unrelated AttributeError.
    if not isinstance(x, torch.Tensor):
        raise TypeError(f'Expected a torch.Tensor, got {type(x).__name__}')
    return x.detach().cpu().numpy()