# EdgeTA: utils/dl/common/model.py
# Uploaded by LINC-BIT ("Upload 1912 files", commit b84549f, verified).
import enum
import io
import os
import time
import warnings
from typing import List, Tuple, Type

import thop
import torch

from ...common.others import get_cur_time_str
class ModelSaveMethod(enum.Enum):
    """How :func:`save_model` persists a model to disk.

    - WEIGHT: only the state dict, written with `torch.save(model.state_dict(), ...)`
    - FULL: the whole module object, written with `torch.save(model, ...)`
    - JIT: a traced TorchScript module, written with `torch.jit.save(jit_model, ...)`
    """
    WEIGHT = 0  # state_dict only
    FULL = 1    # entire nn.Module object
    JIT = 2     # torch.jit.trace result
def save_model(model: torch.nn.Module,
               model_file_path: str,
               save_method: ModelSaveMethod,
               model_input_size: Tuple[int]=None):
    """Save a PyTorch model.

    The model is put in eval mode while it is being saved (this matters for the
    JIT path, which traces a forward pass). Its previous train/eval mode is
    restored afterwards, so saving does not silently change how the caller's
    model behaves.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_file_path (str): Target model file path.
        save_method (ModelSaveMethod): The method to save model.
        model_input_size (Tuple[int], optional): \
            This is required if :attr:`save_method` is :attr:`ModelSaveMethod.JIT`. \
            Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`. \
            Defaults to None.

    Raises:
        ValueError: If :attr:`save_method` is JIT and :attr:`model_input_size` is None,
            or if :attr:`save_method` is not a supported :class:`ModelSaveMethod`.
    """
    if save_method == ModelSaveMethod.JIT and model_input_size is None:
        raise ValueError('model_input_size is required when save_method is ModelSaveMethod.JIT')

    was_training = model.training
    model.eval()
    try:
        if save_method == ModelSaveMethod.WEIGHT:
            torch.save(model.state_dict(), model_file_path)
        elif save_method == ModelSaveMethod.FULL:
            # pickling a whole module can emit warnings (e.g. source-change notices); keep output quiet
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                torch.save(model, model_file_path)
        elif save_method == ModelSaveMethod.JIT:
            dummy_input = torch.ones(model_input_size, device=get_model_device(model))
            new_model = torch.jit.trace(model, dummy_input, check_trace=False)
            torch.jit.save(new_model, model_file_path)
        else:
            # previously an unknown method silently saved nothing
            raise ValueError(f'unsupported save_method: {save_method!r}')
    finally:
        model.train(was_training)
def get_model_size(model: torch.nn.Module, return_MB=False):
    """Get size of a PyTorch model (default in Byte).

    The size is the number of bytes produced by `torch.save(model.state_dict(), ...)`.
    Serialization goes into an in-memory buffer. Nothing is written to the
    working directory, so there is no leftover temporary file if saving fails,
    and the model's train/eval mode is left untouched.

    Args:
        model (torch.nn.Module): A PyTorch model.
        return_MB (bool, optional): Return result in MB (/= 1024**2). Defaults to False.

    Returns:
        Union[int, float]: Model size in bytes (int), or in MB (float) if :attr:`return_MB`.
    """
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    # seek() returns the new absolute position, i.e. the total number of bytes written
    model_size = buffer.seek(0, io.SEEK_END)
    if return_MB:
        model_size /= 1024**2
    return model_size
def get_model_device(model: torch.nn.Module):
    """Get device of a PyTorch model.

    The device of the model's first parameter is returned. If the model has no
    parameters, the first buffer is used instead. A model with neither (e.g. a
    bare activation module) is reported as CPU.

    Args:
        model (torch.nn.Module): A PyTorch model.

    Returns:
        torch.device: The device of :attr:`model` (e.g. `device('cpu')` or `device('cuda:0')`).
    """
    # next() avoids materializing every parameter just to read the first one
    first_param = next(model.parameters(), None)
    if first_param is not None:
        return first_param.device
    first_buffer = next(model.buffers(), None)
    if first_buffer is not None:
        return first_buffer.device
    return torch.device('cpu')
def get_model_latency(model: torch.nn.Module, model_input_size: Tuple[int], sample_num: int,
                      device: str, warmup_sample_num: int, return_detail=False):
    """Get the latency (inference time) of a PyTorch model.

    Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/

    The model is moved to :attr:`device` and switched to eval mode (both persist
    after the call). On CUDA devices each run is timed with CUDA events. On other
    devices the monotonic high-resolution clock `time.perf_counter()` is used.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
            A tuple is treated as an input shape and filled with random data; any
            other value (e.g. a ready-made tensor) is passed to the model as-is.
        sample_num (int): How many inputs of size :attr:`model_input_size` are run;
            the average latency over them is the result. Must be >= 1.
        device (str): Typically be 'cpu' or 'cuda'.
        warmup_sample_num (int): Let model perform some dummy inference to warm up the test environment to avoid measurement loss.
        return_detail (bool, optional): Beside the average latency, return all result measured. Defaults to False.

    Returns:
        Union[float, Tuple[float, List[float]]]: The average latency in seconds
        (and all measured latencies) of :attr:`model`.

    Raises:
        ValueError: If :attr:`sample_num` is smaller than 1.
    """
    if sample_num < 1:
        raise ValueError(f'sample_num must be >= 1, got {sample_num}')

    if isinstance(model_input_size, tuple):
        dummy_input = torch.rand(model_input_size).to(device)
    else:
        dummy_input = model_input_size

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        for _ in range(warmup_sample_num):
            model(dummy_input)

    use_cuda_timer = 'cuda' in str(device)
    infer_time_list = []
    with torch.no_grad():
        for _ in range(sample_num):
            if use_cuda_timer:
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                model(dummy_input)
                end_event.record()
                # kernels run asynchronously; wait before reading the events
                torch.cuda.synchronize()
                # elapsed_time() is in milliseconds; convert to seconds
                infer_time_list.append(start_event.elapsed_time(end_event) / 1000.)
            else:
                start = time.perf_counter()
                model(dummy_input)
                infer_time_list.append(time.perf_counter() - start)

    avg_infer_time = sum(infer_time_list) / sample_num
    if return_detail:
        return avg_infer_time, infer_time_list
    return avg_infer_time
def get_model_flops_and_params(model: torch.nn.Module, model_input_size: Tuple[int], return_M=False):
    """Estimate FLOPs and the number of parameters of a PyTorch model with `thop`.

    An all-ones input of shape :attr:`model_input_size` is created on the model's
    device and profiled with `thop.profile`. The operation count reported by thop
    is doubled. Presumably this is because thop counts multiply-accumulates
    (MACs); confirm against the thop version in use.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically be `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        return_M (bool, optional): Divide both results by 1e6. Defaults to False.

    Returns:
        Tuple[float, float]: FLOPs and number of parameters of :attr:`model`.
    """
    target_device = get_model_device(model)
    probe = torch.ones(model_input_size).to(target_device)
    op_count, num_params = thop.profile(model, (probe, ), verbose=False)
    flops = op_count * 2
    if not return_M:
        return flops, num_params
    return flops / 1e6, num_params / 1e6
def get_module(model: torch.nn.Module, module_name: str):
    """Get a module from a PyTorch model by its dotted name.

    The name is resolved by following registered submodules one path component
    at a time, using the same names that `model.named_modules()` produces. This
    takes time proportional to the path depth rather than the model size. A
    module registered under several names can be fetched by any of them, whereas
    `named_modules()` would list it only once. An empty name returns
    :attr:`model` itself.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_module(model, 'layer1.0')
        BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.

    Returns:
        torch.nn.Module: Corresponding module, or None if no module is registered
        under :attr:`module_name`.
    """
    if module_name == '':
        return model
    module = model
    for part in module_name.split('.'):
        # _modules maps child names to registered submodules; a missing or None
        # entry means the path does not name a module
        module = module._modules.get(part)
        if module is None:
            return None
    return module
def get_parameter(model: torch.nn.Module, param_name: str):
    """Get an attribute (typically a parameter) of a submodule by its dotted name.

    Example: `get_parameter(model, 'layer1.0.conv1.weight')` returns the `weight`
    attribute of module `layer1.0.conv1`. A name without dots is looked up
    directly on :attr:`model`.

    Args:
        model (torch.nn.Module): A PyTorch model.
        param_name (str): Dotted name: module path followed by the attribute name.

    Returns:
        The attribute found on the owning module.

    Raises:
        AttributeError: If the module path does not exist in :attr:`model`, or the
            owning module has no such attribute.
    """
    module_name, _, attr_name = param_name.rpartition('.')
    owner = get_module(model, module_name)
    if owner is None:
        # previously surfaced as a confusing "'NoneType' object has no attribute ..." error
        raise AttributeError(f"no module named '{module_name}' in model (looking up '{param_name}')")
    return getattr(owner, attr_name)
def get_super_module(model: torch.nn.Module, module_name: str):
    """Return the parent module that directly holds the module named :attr:`module_name`.

    The last dotted component is stripped and the remaining path is looked up with
    :func:`get_module`. For a top-level name (no dot) the parent is :attr:`model` itself.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_super_module(model, 'layer1.0.conv1')
        BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.

    Returns:
        torch.nn.Module: Parent of module :attr:`module_name` (None if the parent path is not found).
    """
    parent_path, _, _ = module_name.rpartition('.')
    return get_module(model, parent_path)
def set_module(model: torch.nn.Module, module_name: str, module: torch.nn.Module):
    """Replace (or add) the submodule registered under a dotted name.

    The parent is found with :func:`get_super_module`, and the new module is
    assigned to it under the last component of :attr:`module_name`.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> set_module(model, 'layer1.0', torch.nn.Conv2d(64, 64, 3, padding=1, bias=False))
        >>> model
        ResNet(
            (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
            (layer1): Sequential(
        -->     (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
                (1): BasicBlock(
                    ...
                )
                ...
            )
            ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.
        module (torch.nn.Module): Target module which will be set into :attr:`model`.
    """
    parent = get_super_module(model, module_name)
    _, _, child_name = module_name.rpartition('.')
    setattr(parent, child_name, module)
def get_ith_layer(model: torch.nn.Module, i: int):
    """Return the i-th leaf module (a module without registered children) of a model.

    Leaves are counted in the order `model.modules()` yields them, starting at 0.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer(model, 5)
        Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.

    Returns:
        torch.nn.Module: i-th layer in :attr:`model`, or None if there is no such leaf.
    """
    leaves = (m for m in model.modules() if next(m.children(), None) is None)
    for position, leaf in enumerate(leaves):
        if position == i:
            return leaf
    return None
def get_ith_layer_name(model: torch.nn.Module, i: int):
    """Return the dotted name of the i-th leaf module (a module without registered children).

    Leaves are counted in the order `model.named_modules()` yields them, starting at 0.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer_name(model, 5)
        'features.5'

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.

    Returns:
        str: The name of i-th layer in :attr:`model`, or None if there is no such leaf.
    """
    leaf_names = (
        name for name, module in model.named_modules()
        if next(module.children(), None) is None
    )
    for position, name in enumerate(leaf_names):
        if position == i:
            return name
    return None
def set_ith_layer(model: torch.nn.Module, i: int, layer: torch.nn.Module):
    """Replace the i-th leaf module (a module without registered children) of a model in place.

    Leaves are counted in the order `model.named_modules()` yields them, starting at 0.
    If there is no i-th leaf, the model is left unchanged. Returns None.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> model
        VGG(
            (features): Sequential(
                (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): ReLU(inplace=True)
                (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                ...
            )
            ...
        )
        >>> set_ith_layer(model, 2, torch.nn.Conv2d(64, 128, 3, padding=1))
        >>> model
        VGG(
            (features): Sequential(
                (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): ReLU(inplace=True)
        -->     (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                ...
            )
            ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of target layer.
        layer (torch.nn.Module): The layer which will be set into :attr:`model`.
    """
    leaf_names = (
        name for name, module in model.named_modules()
        if next(module.children(), None) is None
    )
    for position, name in enumerate(leaf_names):
        if position == i:
            # stop right after the swap: the module tree has just been mutated
            set_module(model, name, layer)
            return
def get_all_specific_type_layers_name(model: torch.nn.Module, types: Tuple[Type[torch.nn.Module]]):
    """List the names of all modules that are instances of the given type(s) (e.g. `Conv2d`, `Linear`).

    Names follow the order of `model.named_modules()`. The model itself (name `''`)
    is included if it matches. :attr:`types` is passed to `isinstance`, so a single
    class works as well as a tuple of classes.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_all_specific_type_layers_name(model, (torch.nn.Conv2d,))
        ['features.0', 'features.2', 'features.5', ...]

    Args:
        model (torch.nn.Module): A PyTorch model.
        types (Tuple[Type[torch.nn.Module]]): Target types, e.g. `(torch.nn.Conv2d, torch.nn.Linear)`.

    Returns:
        List[str]: Names of all modules of the given types.
    """
    return [name for name, module in model.named_modules() if isinstance(module, types)]
class LayerActivation:
    """Collect the input and output of a middle module of a PyTorch model during inference.

    Layer is a wide concept in this class. A module (e.g. ResBlock in ResNet) can also be regarded as a "layer".
    When the layer's input or output is a tuple, only its first element is kept.
    An empty tuple (e.g. the layer was called with keyword arguments only) is stored as None.
    Collected tensors are optionally detached and then moved to :attr:`device`.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of 5th layer in VGG16
        >>> layer_activation = LayerActivation(get_ith_layer(model, 5), detach=True, device='cpu')
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Register forward hook on corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Detach the collected tensors from the autograd graph.
            device (str): Where the collected data is located.
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        # latest collected activations; None until the layer has run forward
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _collect(self, data):
        """Take the first element of a tuple, then optionally detach and move to :attr:`device`."""
        if isinstance(data, tuple):
            if len(data) == 0:
                # no positional inputs (kwargs-only call); indexing would crash the forward pass
                return None
            data = data[0]
        if self.detach:
            # detach first so the device copy is not recorded in the autograd graph
            data = data.detach()
        return data.to(self.device)

    def _hook_fn(self, module, inputs, outputs):
        # TODO: only the first element of a tuple input/output is kept
        self.input = self._collect(inputs)
        self.output = self._collect(outputs)

    def remove(self):
        """Remove the hook in the model to avoid performance effect.

        Use this after using the collected data.
        """
        self.hook.remove()
class LayerActivation2:
    """Record the raw inputs and output of a module as handed to its forward hook.

    Nothing is unpacked, detached or moved. After a forward pass, `input` is the
    tuple of positional arguments the layer received and `output` is exactly
    what the layer returned.

    Example:
        >>> layer = torch.nn.Linear(3, 2)
        >>> layer_activation = LayerActivation2(layer)
        >>> y = layer(torch.rand(1, 3))
        >>> layer_activation.input   # tuple of positional inputs
        (tensor([[...]]),)
        >>> layer_activation.output
        tensor([[...]], grad_fn=...)
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module):
        """Attach a forward hook to :attr:`layer`.

        Args:
            layer (torch.nn.Module): Target layer; must not be None (checked with `assert`).
        """
        assert layer is not None
        self.hook = layer.register_forward_hook(self._hook_fn)
        # raw hook arguments from the most recent forward pass (None before the first one)
        self.input = None
        self.output = None
        self.layer = layer

    def __str__(self):
        return f'- {self.layer}'

    def _hook_fn(self, module, input, output):
        self.input, self.output = input, output

    def remove(self):
        """Detach the hook from the layer so it no longer costs anything per forward pass.

        Call this once the recorded data is no longer needed.
        """
        self.hook.remove()
class LayerActivation3:
    """Record the raw inputs and output of a module as handed to its forward hook.

    `input` is the tuple of positional arguments the layer received and `output`
    is exactly what the layer returned. The constructor also accepts `detach` and
    `device`; both are stored on the instance but are NOT applied to the recorded
    data (the detach step is currently disabled).

    Example:
        >>> layer = torch.nn.Linear(3, 2)
        >>> layer_activation = LayerActivation3(layer, detach=True, device='cpu')
        >>> y = layer(torch.rand(1, 3))
        >>> layer_activation.input   # tuple of positional inputs
        (tensor([[...]]),)
        >>> layer_activation.output
        tensor([[...]], grad_fn=...)
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Attach a forward hook to :attr:`layer`.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Stored as `self.detach`; not applied to the recorded data.
            device (str): Stored as `self.device`; the recorded data is not moved.
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        # raw hook arguments from the most recent forward pass (None before the first one)
        self.input = None
        self.output = None
        self.layer = layer

    def __str__(self):
        return f'- {self.layer}'

    def _hook_fn(self, module, input, output):
        # TODO: input/output are kept raw (input is a tuple); detach/device are not applied
        self.input, self.output = input, output

    def remove(self):
        """Detach the hook from the layer so it no longer costs anything per forward pass.

        Call this once the recorded data is no longer needed.
        """
        self.hook.remove()
class LayerActivationWrapper:
    """A wrapper of :class:`LayerActivation` with the same API that broadens the concept "layer".

    A series of layers can be regarded as one "hyper-layer": `input` comes from the
    first wrapped activation and `output` from the last one. Any objects exposing
    `input`, `output` and `remove()` can be wrapped.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input of 5th layer and the output of 7th layer in VGG16,
        >>> # i.e. regard layers 5~7 as a whole module
        >>> layer_activation = LayerActivationWrapper([
        ...     LayerActivation(get_ith_layer(model, 5), detach=True, device='cpu'),
        ...     LayerActivation(get_ith_layer(model, 6), detach=True, device='cpu'),
        ...     LayerActivation(get_ith_layer(model, 7), detach=True, device='cpu'),
        ... ])
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, las: 'List[LayerActivation]'):
        """
        Args:
            las (List[LayerActivation]): The layer activations of a series of layers, in forward order.
        """
        self.las = las

    def __str__(self):
        return '\n'.join(str(la) for la in self.las)

    @property
    def input(self):
        """Get the collected input data of the first layer.

        Returns:
            torch.Tensor: Collected input data of the first layer.
        """
        return self.las[0].input

    @property
    def output(self):
        """Get the collected output data of the last layer.

        Returns:
            torch.Tensor: Collected output data of the last layer.
        """
        return self.las[-1].output

    def remove(self):
        """Remove all hooks in the model to avoid performance effect.

        Use this after using the collected data.
        """
        for la in self.las:
            la.remove()
class TimeProfiler:
    """Time a single forward pass of a module with forward hooks. (NOT VERIFIED. DON'T USE ME)

    A forward pre-hook starts the timer and a forward hook stops it. After each
    forward pass of the layer, `infer_time` holds that pass's duration in seconds.
    On CPU the monotonic clock `time.perf_counter()` is used. On any other device
    a pair of CUDA events is recorded, and `torch.cuda.synchronize()` is called
    before reading the elapsed time.
    """

    def __init__(self, layer: torch.nn.Module, device):
        """
        Args:
            layer (torch.nn.Module): Module to time.
            device: 'cpu' (or torch.device('cpu')) for host-clock timing; anything else uses CUDA events.
        """
        self.device = device
        self.infer_time = None
        self._start_time = None
        if not self._is_cpu():
            self.s, self.e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
        # register hooks last so a failure above does not leave dangling hooks on the layer
        self.before_infer_hook = layer.register_forward_pre_hook(self.before_hook_fn)
        self.after_infer_hook = layer.register_forward_hook(self.after_hook_fn)

    def _is_cpu(self):
        # str() lets torch.device('cpu') take the CPU path just like the string 'cpu'
        return str(self.device) == 'cpu'

    def before_hook_fn(self, module, input):
        if self._is_cpu():
            self._start_time = time.perf_counter()
        else:
            self.s.record()

    def after_hook_fn(self, module, input, output):
        if self._is_cpu():
            self.infer_time = time.perf_counter() - self._start_time
        else:
            self.e.record()
            torch.cuda.synchronize()
            # elapsed_time() is in milliseconds; store seconds
            self.infer_time = self.s.elapsed_time(self.e) / 1000.

    def remove(self):
        """Remove both timing hooks from the layer."""
        self.before_infer_hook.remove()
        self.after_infer_hook.remove()
class TimeProfilerWrapper:
    """Aggregate several :class:`TimeProfiler` instances. (NOT VERIFIED. DON'T USE ME)

    `infer_time` is the sum of the latest `infer_time` of every wrapped profiler.
    """

    def __init__(self, tps: 'List[TimeProfiler]'):
        """
        Args:
            tps (List[TimeProfiler]): Profilers whose measured times are summed.
        """
        self.tps = tps

    @property
    def infer_time(self):
        """Total of the wrapped profilers' latest `infer_time` values (seconds)."""
        return sum(tp.infer_time for tp in self.tps)

    def remove(self):
        """Remove the hooks of every wrapped profiler."""
        for tp in self.tps:
            tp.remove()