File size: 689 Bytes
d380b77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from enum import Enum
import yaml
from easydict import EasyDict as edict
import torch.nn as nn
import torch
def load_yaml(path):
    """Load a YAML file and return its contents wrapped in an ``edict``.

    The document is parsed with ``yaml.safe_load``, so only standard YAML
    tags are constructed (no arbitrary Python objects).

    Args:
        path: Path of the YAML file to read; anything ``open`` accepts.

    Returns:
        ``edict(...)`` built from the parsed document.
        NOTE(review): this assumes the top-level YAML node is a mapping;
        confirm the behaviour callers expect for other top-level nodes.
    """
    # Open in binary mode so PyYAML detects the stream encoding itself
    # (UTF-8 by default, UTF-16 when a BOM is present) instead of decoding
    # with the platform's locale encoding, which can mangle or reject UTF-8
    # files on non-UTF-8 locales (e.g. Windows cp1252).
    with open(path, 'rb') as f:
        return edict(yaml.safe_load(f))
def move_to_device(obj, device):
    """Recursively send every module/tensor inside ``obj`` to ``device``.

    Handled inputs:
      * ``nn.Module`` or tensor: the result of ``obj.to(device)``.
      * ``list`` or ``tuple``: a new ``list`` whose elements were moved
        recursively (tuples are NOT preserved as tuples).
      * ``dict``: a new plain ``dict`` with the same keys, in the same
        order, whose values were moved recursively.

    Args:
        obj: The (possibly nested) object to move.
        device: Target passed straight through to ``.to(...)``.

    Raises:
        ValueError: if ``obj``, or anything nested inside it, is of any other
            type (e.g. ``int``, ``str`` or ``None``).
    """

    def _move(node):
        # Leaves: modules and tensors both relocate via ``.to(device)``.
        if isinstance(node, nn.Module) or torch.is_tensor(node):
            return node.to(device)
        # Sequences: recurse element-wise; the result is always a list.
        if isinstance(node, (tuple, list)):
            return [_move(child) for child in node]
        # Mappings: keep keys, recurse into values.
        if isinstance(node, dict):
            return {key: _move(value) for key, value in node.items()}
        raise ValueError(f'Unexpected type {type(node)}')

    return _move(obj)
class SmallMode(Enum):
    """String-valued enumeration with the members ``DROP`` and ``UPSCALE``.

    Members can be looked up by value, e.g. ``SmallMode('drop')``.
    NOTE(review): judging only by the names, this presumably selects how
    undersized inputs are handled (dropped vs. upscaled); confirm against
    the code that consumes it.
    """

    DROP = 'drop'
    UPSCALE = 'upscale'
|