|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, fields |
|
|
|
|
|
class Transform:
    """Base class providing batching of datastruct samples.

    ``collate`` builds its result through ``self.Datastruct``, which is not
    defined in this class; it has to be supplied elsewhere (e.g. by a
    subclass).
    """

    def collate(self, lst_datastruct):
        """Merge a list of datastructs into a single batched datastruct.

        The first element of ``lst_datastruct`` acts as the template: its
        ``datakeys`` decide which keys are processed, and a key whose value
        is ``None`` in that first element stays ``None`` in the result
        (the other elements are not inspected for that key).

        Args:
            lst_datastruct: Non-empty sequence of objects supporting
                ``obj[key]`` and exposing ``datakeys``.

        Returns:
            ``self.Datastruct(**kwargs)`` where each non-``None`` key maps to
            the output of ``collate_tensor_with_padding`` applied to the list
            of per-sample values for that key.

        Raises:
            IndexError: If ``lst_datastruct`` is empty.
        """
        # Imported at call time, exactly as in the original implementation.
        from ..tools import collate_tensor_with_padding

        template = lst_datastruct[0]

        batched = {}
        for name in template.datakeys:
            if template[name] is None:
                batched[name] = None
                continue
            per_sample = [sample[name] for sample in lst_datastruct]
            batched[name] = collate_tensor_with_padding(per_sample)

        return self.Datastruct(**batched)
|
|
|
|
|
|
|
|
|
@dataclass
class Datastruct:
    """Dataclass base offering dict-style access to its fields.

    This class declares no fields itself. The methods ``to``, ``device`` and
    ``detach`` additionally read ``self.datakeys`` (an iterable of attribute
    names holding data) and ``detach`` reads ``self.transforms.Datastruct``;
    neither attribute is defined here, so they must be provided elsewhere
    (e.g. by subclasses).
    """

    def __getitem__(self, key):
        """Return the attribute named ``key`` (``AttributeError`` if absent)."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` directly in the instance ``__dict__``."""
        self.__dict__[key] = value

    def get(self, key, default=None):
        """Return the attribute named ``key``, or ``default`` if it is missing."""
        return getattr(self, key, default)

    def __iter__(self):
        """Iterate over the dataclass field names, like a mapping."""
        return self.keys()

    def keys(self):
        """Return an iterator over the dataclass field names."""
        return iter([spec.name for spec in fields(self)])

    def values(self):
        """Return an iterator over the dataclass field values."""
        return iter([getattr(self, spec.name) for spec in fields(self)])

    def items(self):
        """Return an iterator over ``(field name, field value)`` pairs."""
        return iter([(spec.name, getattr(self, spec.name)) for spec in fields(self)])

    def to(self, *args, **kwargs):
        """Call ``.to(*args, **kwargs)`` on every non-``None`` data entry.

        Each entry listed in ``self.datakeys`` is replaced in place by the
        result of its own ``.to(...)`` call; ``None`` entries are skipped.

        Returns:
            ``self``, to allow chaining.
        """
        for key in self.datakeys:
            current = self[key]
            if current is not None:
                self[key] = current.to(*args, **kwargs)
        return self

    @property
    def device(self):
        """Device of the first data entry that is not ``None``.

        ``to`` and ``detach`` both accept ``None`` entries, so the first key
        of ``datakeys`` may legitimately be empty; the keys are therefore
        scanned in order instead of blindly reading the first one.

        Raises:
            AttributeError: If no entry listed in ``datakeys`` holds data.
        """
        for key in self.datakeys:
            current = self[key]
            if current is not None:
                return current.device
        raise AttributeError(
            "cannot determine device: every entry listed in datakeys is None"
        )

    def detach(self):
        """Build a new datastruct from the detached data entries.

        Every non-``None`` entry listed in ``self.datakeys`` contributes the
        result of its ``.detach()`` call; ``None`` entries stay ``None``.

        Returns:
            ``self.transforms.Datastruct(**kwargs)`` built from those values.
        """
        detached = {}
        for key in self.datakeys:
            current = self[key]
            detached[key] = None if current is None else current.detach()
        return self.transforms.Datastruct(**detached)
|
|