# Spaces: Runtime error  -- NOTE(review): appears to be residue from the web page this file was copied from; not Python source
# -*- coding: utf-8 -*-
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
from torch.nn.parameter import Parameter | |
# Legacy tensor-type classes used in isinstance() checks by the fp32 <-> fp16
# converters below. Only the CPU and CUDA variants are listed.
FLOAT_TYPES = (
    torch.FloatTensor,
    torch.cuda.FloatTensor,
)
HALF_TYPES = (
    torch.HalfTensor,
    torch.cuda.HalfTensor,
)
def conversion_helper(val, conversion):
    """Apply ``conversion`` to ``val``, recursing into nested tuples/lists.

    Any object that is not a tuple or list is treated as a leaf and passed to
    ``conversion``. Tuples and lists are rebuilt from their converted
    elements:

    * lists come back as plain lists;
    * plain tuples come back as tuples;
    * namedtuples keep their concrete type, so attribute access
      (``out.field``) on converted structures keeps working.

    Args:
        val: a leaf object, or an arbitrarily nested structure of
            tuples/lists.
        conversion: callable applied to every leaf.

    Returns:
        ``conversion(val)`` for a leaf, otherwise a new container with the
        same nesting shape.
    """
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        # namedtuples expose ``_fields`` and are constructed from positional
        # arguments; rebuild them with their own type instead of flattening
        # them into a plain tuple (which would drop field access).
        if hasattr(val, "_fields"):
            return type(val)(*rtn)
        return tuple(rtn)
    return rtn
def fp32_to_fp16(val):
    """Cast every fp32 tensor inside ``val`` to fp16.

    ``val`` may be a single object or an arbitrarily nested tuple/list; the
    nesting is rebuilt by ``conversion_helper``. Objects that are not fp32
    tensors (per ``FLOAT_TYPES``) are passed through unchanged.
    """
    def _to_half(item):
        # For a Parameter/Variable the type check looks at its ``.data``
        # tensor, but ``.half()`` is still called on the object itself.
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        return item.half() if isinstance(probe, FLOAT_TYPES) else item

    return conversion_helper(val, _to_half)
def fp16_to_fp32(val):
    """Cast every fp16 tensor inside ``val`` back to fp32.

    ``val`` may be a single object or an arbitrarily nested tuple/list whose
    structure ``conversion_helper`` rebuilds. Anything that is not an fp16
    tensor (per ``HALF_TYPES``) is returned untouched.
    """
    def _to_float(item):
        # For a Parameter/Variable, inspect the type of its ``.data`` tensor,
        # but call ``.float()`` on the original object.
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        if not isinstance(probe, HALF_TYPES):
            return item
        return item.float()

    return conversion_helper(val, _to_float)
class FP16Module(nn.Module):
    """Wrap a module so it runs in fp16 while callers exchange fp32 values.

    The given module is converted with ``module.half()`` and registered as
    the ``module`` submodule. ``nn.Module.half`` converts in place, so the
    caller's module object is converted too.

    ``forward`` casts fp32 tensors in the positional inputs to fp16 and casts
    fp16 tensors in the output back to fp32. ``state_dict`` and
    ``load_state_dict`` delegate to the wrapped module, so checkpoint keys do
    not carry a ``module.`` prefix.
    """

    def __init__(self, module):
        super().__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on fp16-cast inputs and return fp32 outputs.

        Only the positional ``inputs`` (including nested tuples/lists) are
        cast. Keyword arguments are forwarded unchanged.
        """
        # NOTE(review): fp32 tensors passed via kwargs are NOT cast to fp16;
        # confirm callers only pass non-tensor options as keywords.
        return fp16_to_fp32(self.module(*fp32_to_fp16(inputs), **kwargs))

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped module's state dict (keys without ``module.``)."""
        # Pass by keyword: positional arguments to nn.Module.state_dict are
        # deprecated and trigger a warning on recent PyTorch releases.
        return self.module.state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped module.

        Returns:
            Whatever the wrapped module's ``load_state_dict`` returns. For a
            standard ``nn.Module`` this is a named tuple of
            ``missing_keys`` / ``unexpected_keys``.
        """
        return self.module.load_state_dict(state_dict, strict=strict)

    def get_param(self, item):
        """Delegate ``get_param`` to the wrapped module."""
        # Assumes the wrapped module defines get_param (nn.Module does not).
        return self.module.get_param(item)

    def to(self, device, *args, **kwargs):
        """Move/cast the wrapped module and this wrapper; return ``self``."""
        # Only ``device`` is forwarded to the wrapped module's own ``to``.
        # Presumably this exists so that a custom ``to`` override on the
        # wrapped module runs (TODO confirm).
        self.module.to(device)
        return super().to(device, *args, **kwargs)