|
from enum import Enum |
|
|
|
import yaml |
|
from easydict import EasyDict as edict |
|
import torch.nn as nn |
|
import torch |
|
|
|
|
|
def load_yaml(path):
    """Parse the YAML file at ``path`` and wrap the result in an ``EasyDict``.

    Args:
        path: path of the YAML file to read (anything ``open`` accepts).

    Returns:
        ``EasyDict`` built from the parsed document. An empty document (for
        which ``yaml.safe_load`` returns ``None``) yields an empty ``EasyDict``.
        NOTE(review): assumes the top-level YAML node is a mapping — confirm.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the content is not valid YAML.
    """
    # Open in binary so PyYAML decodes the stream itself (UTF-8 by default,
    # UTF-16 when a BOM is present) instead of relying on the platform's
    # locale encoding, which garbles or rejects non-ASCII configs on e.g.
    # Windows. safe_load (not load) keeps arbitrary-object construction off.
    with open(path, 'rb') as f:
        data = yaml.safe_load(f)
    # Empty files parse to None; normalise explicitly to an empty mapping.
    return edict(data if data is not None else {})
|
|
|
|
|
def move_to_device(obj, device):
    """Recursively transfer ``obj`` to ``device``.

    Modules and tensors are moved with their own ``.to(device)``. Lists and
    tuples are walked element by element and always come back as a ``list``
    (the tuple type is not preserved); dicts come back as a plain ``dict``
    with the same keys in the same order. Anything else is rejected.

    Args:
        obj: an ``nn.Module``, a tensor, or a (possibly nested) list, tuple
            or dict whose leaves are modules/tensors.
        device: target device, forwarded unchanged to ``.to``.

    Returns:
        The moved module/tensor, or a newly built container of moved items.

    Raises:
        ValueError: if ``obj`` (or any nested leaf) is of an unsupported type.
    """
    # Modules and tensors expose the same ``.to`` API, so one branch covers both.
    if isinstance(obj, nn.Module) or torch.is_tensor(obj):
        return obj.to(device)

    if isinstance(obj, dict):
        return {key: move_to_device(item, device) for key, item in obj.items()}

    if isinstance(obj, (list, tuple)):
        # Sequences are rebuilt as lists regardless of their input type.
        return [move_to_device(item, device) for item in obj]

    raise ValueError(f'Unexpected type {type(obj)}')
|
|
|
|
|
class SmallMode(Enum):
    """String-valued enumeration with two options, ``DROP`` and ``UPSCALE``.

    Members can be looked up by their string value, e.g. ``SmallMode("drop")``
    returns ``SmallMode.DROP``.

    NOTE(review): from the names, this presumably selects how inputs judged
    "too small" are handled (discard them vs. enlarge them). The consuming
    code is not visible here — confirm the exact semantics against callers.
    """

    # presumably: discard small items — confirm with callers
    DROP = "drop"

    # presumably: upscale small items to the required size — confirm with callers
    UPSCALE = "upscale"
|
|