import enum
import os
import time
import warnings
from typing import List, Tuple, Type

import torch
import thop

from ...common.others import get_cur_time_str


class ModelSaveMethod(enum.Enum):
    """
    - WEIGHT: save only the model's weights by `torch.save(model.state_dict(), ...)`
    - FULL: save the whole model object by `torch.save(model, ...)`
    - JIT: trace the model to TorchScript (JIT) format and save it by `torch.jit.save(jit_model, ...)`
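
    Models saved in each format are loaded with the matching PyTorch API, for example:

    - WEIGHT: `model.load_state_dict(torch.load(path))` (the architecture must be constructed first)
    - FULL: `model = torch.load(path)`
    - JIT: `model = torch.jit.load(path)`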
    """
    WEIGHT = 0
    FULL = 1
    JIT = 2


def save_model(model: torch.nn.Module,
               model_file_path: str,
               save_method: ModelSaveMethod,
               model_input_size: Tuple[int] = None):
    """Save a PyTorch model to a file.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_file_path (str): Target model file path.
        save_method (ModelSaveMethod): How to save the model.
        model_input_size (Tuple[int], optional): Required if :attr:`save_method` is
            :attr:`ModelSaveMethod.JIT`, because a dummy input of this size is used for tracing.
            Typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`. Defaults to None.
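
    Example (illustrative; assumes `torchvision` is installed):
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> save_model(model, 'resnet18.pth', ModelSaveMethod.WEIGHT)
        >>> save_model(model, 'resnet18.jit', ModelSaveMethod.JIT, (1, 3, 224, 224))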
    """
    model.eval()

    if save_method == ModelSaveMethod.WEIGHT:
        torch.save(model.state_dict(), model_file_path)

    elif save_method == ModelSaveMethod.FULL:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            torch.save(model, model_file_path)

    elif save_method == ModelSaveMethod.JIT:
        assert model_input_size is not None, 'model_input_size is required for JIT saving'
        dummy_input = torch.ones(model_input_size, device=get_model_device(model))
        jit_model = torch.jit.trace(model, dummy_input, check_trace=False)
        torch.jit.save(jit_model, model_file_path)


def get_model_size(model: torch.nn.Module, return_MB=False):
    """Get the size of a PyTorch model's weights (in bytes by default).

    The model is temporarily saved with :attr:`ModelSaveMethod.WEIGHT` and the size of that file is measured.

    Args:
        model (torch.nn.Module): A PyTorch model.
        return_MB (bool, optional): Return the result in MB (i.e. divided by 1024**2). Defaults to False.

    Returns:
        Union[int, float]: Model size, in bytes (or in MB if :attr:`return_MB` is True).
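
    Example (illustrative; assumes `torchvision` is installed):
        >>> from torchvision.models import resnet18
        >>> size_mb = get_model_size(resnet18(), return_MB=True)  # about 45 MB for ResNet-18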
    """
    pid = os.getpid()
    tmp_model_file_path = './tmp-get-model-size-{}-{}.model'.format(pid, get_cur_time_str())
    save_model(model, tmp_model_file_path, ModelSaveMethod.WEIGHT)

    model_size = os.path.getsize(tmp_model_file_path)
    os.remove(tmp_model_file_path)

    if return_MB:
        model_size /= 1024**2

    return model_size


def get_model_device(model: torch.nn.Module):
    """Get the device of a PyTorch model (i.e. the device of its first parameter).

    Args:
        model (torch.nn.Module): A PyTorch model.

    Returns:
        torch.device: The device of :attr:`model` (e.g. 'cpu' or 'cuda:0').
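
    Example (illustrative; assumes `torchvision` is installed):
        >>> from torchvision.models import resnet18
        >>> get_model_device(resnet18())
        device(type='cpu')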
    """
    return list(model.parameters())[0].device


def get_model_latency(model: torch.nn.Module, model_input_size: Tuple[int], sample_num: int,
                      device: str, warmup_sample_num: int, return_detail=False):
    """Get the latency (inference time) of a PyTorch model, in seconds.

    Reference: https://deci.ai/resources/blog/measure-inference-time-deep-neural-networks/

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
            A pre-built input tensor can also be passed directly instead of a size.
        sample_num (int): How many inferences are timed; their average latency is returned.
        device (str): Typically 'cpu' or 'cuda'.
        warmup_sample_num (int): How many dummy inferences are run beforehand to warm up
            the test environment and reduce measurement error.
        return_detail (bool, optional): Besides the average latency, also return every measured latency. Defaults to False.

    Returns:
        Union[float, Tuple[float, List[float]]]: The average latency of :attr:`model` (and, if requested, all latency samples).
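
    Example (illustrative; the measured value depends entirely on the hardware):
        >>> from torchvision.models import resnet18
        >>> avg_sec = get_model_latency(resnet18(), (1, 3, 224, 224), sample_num=10,
        ...                             device='cpu', warmup_sample_num=2)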
    """
    if isinstance(model_input_size, tuple):
        dummy_input = torch.rand(model_input_size).to(device)
    else:
        dummy_input = model_input_size

    model = model.to(device)
    model.eval()

    # warm up
    with torch.no_grad():
        for _ in range(warmup_sample_num):
            model(dummy_input)

    infer_time_list = []

    if 'cuda' in str(device):
        with torch.no_grad():
            for _ in range(sample_num):
                s, e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                s.record()
                model(dummy_input)
                e.record()
                torch.cuda.synchronize()
                cur_model_infer_time = s.elapsed_time(e) / 1000.
                infer_time_list += [cur_model_infer_time]

    else:
        with torch.no_grad():
            for _ in range(sample_num):
                start = time.time()
                model(dummy_input)
                cur_model_infer_time = time.time() - start
                infer_time_list += [cur_model_infer_time]

    avg_infer_time = sum(infer_time_list) / sample_num

    if return_detail:
        return avg_infer_time, infer_time_list
    return avg_infer_time


def get_model_flops_and_params(model: torch.nn.Module, model_input_size: Tuple[int], return_M=False):
    """Get the FLOPs and the number of parameters of a PyTorch model via `thop`.

    Args:
        model (torch.nn.Module): A PyTorch model.
        model_input_size (Tuple[int]): Typically `(1, 3, 32, 32)` or `(1, 3, 224, 224)`.
        return_M (bool, optional): Return the results in millions (i.e. divided by 1e6). Defaults to False.

    Returns:
        Tuple[float, float]: FLOPs and number of parameters of :attr:`model` (in millions if :attr:`return_M` is True).
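
    Example (illustrative; assumes `torchvision` is installed):
        >>> from torchvision.models import resnet18
        >>> flops, params = get_model_flops_and_params(resnet18(), (1, 3, 224, 224), return_M=True)
        >>> params  # ResNet-18 has roughly 11.7 M parameters
        11.6...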
    """
    device = get_model_device(model)
    ops, param = thop.profile(model, (torch.ones(model_input_size).to(device), ), verbose=False)
    # thop reports MACs; multiply by 2 to get FLOPs
    ops, param = ops * 2, param
    if return_M:
        ops, param = ops / 1e6, param / 1e6
    return ops, param


def get_module(model: torch.nn.Module, module_name: str):
    """Get a module from a PyTorch model by its name.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_module(model, 'layer1.0')
        BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name, e.g. 'layer1.0'.

    Returns:
        torch.nn.Module: The corresponding module, or None if no module has that name.
    """
    for name, module in model.named_modules():
        if name == module_name:
            return module

    return None


def get_parameter(model: torch.nn.Module, param_name: str):
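    """Get a parameter (e.g. a weight or bias tensor) from a PyTorch model by its full name.

    Example (illustrative; assumes `torchvision` is installed):
        >>> from torchvision.models import resnet18
        >>> get_parameter(resnet18(), 'conv1.weight').shape
        torch.Size([64, 3, 7, 7])

    Args:
        model (torch.nn.Module): A PyTorch model.
        param_name (str): Full parameter name, e.g. 'conv1.weight'.

    Returns:
        torch.nn.Parameter: The corresponding parameter.
    """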
    return getattr(
        get_module(model, '.'.join(param_name.split('.')[0: -1])),
        param_name.split('.')[-1]
    )


def get_super_module(model: torch.nn.Module, module_name: str):
    """Get the super (parent) module of a module in a PyTorch model.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> get_super_module(model, 'layer1.0.conv1')
        BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name, e.g. 'layer1.0.conv1'.

    Returns:
        torch.nn.Module: The super module of module :attr:`module_name`.
    """
    super_module_name = '.'.join(module_name.split('.')[0:-1])
    return get_module(model, super_module_name)


def set_module(model: torch.nn.Module, module_name: str, module: torch.nn.Module):
    """Set (replace) a module in a PyTorch model.

    Example:
        >>> from torchvision.models import resnet18
        >>> model = resnet18()
        >>> set_module(model, 'layer1.0', torch.nn.Conv2d(64, 64, 3))
        >>> model
        ResNet(
          (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
          (layer1): Sequential(
        --> (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
            (1): BasicBlock(
              ...
            )
            ...
          )
          ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        module_name (str): Module name.
        module (torch.nn.Module): The target module which will be set into :attr:`model`.
    """
    super_module = get_super_module(model, module_name)
    setattr(super_module, module_name.split('.')[-1], module)


def get_ith_layer(model: torch.nn.Module, i: int):
    """Get the i-th layer (leaf module) in a PyTorch model.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer(model, 5)
        Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of the target layer.

    Returns:
        torch.nn.Module: The i-th layer in :attr:`model`, or None if :attr:`i` is out of range.
    """
    j = 0
    for module in model.modules():
        # only count leaf modules (modules without children)
        if len(list(module.children())) > 0:
            continue
        if j == i:
            return module
        j += 1
    return None


def get_ith_layer_name(model: torch.nn.Module, i: int):
    """Get the name of the i-th layer (leaf module) in a PyTorch model.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_ith_layer_name(model, 5)
        'features.5'

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of the target layer.

    Returns:
        str: The name of the i-th layer in :attr:`model`, or None if :attr:`i` is out of range.
    """
    j = 0
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue
        if j == i:
            return name
        j += 1
    return None


def set_ith_layer(model: torch.nn.Module, i: int, layer: torch.nn.Module):
    """Set (replace) the i-th layer (leaf module) in a PyTorch model.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> model
        VGG(
          (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            ...
          )
          ...
        )
        >>> set_ith_layer(model, 2, torch.nn.Conv2d(64, 128, 3))
        >>> model
        VGG(
          (features): Sequential(
            (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU(inplace=True)
        --> (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
            ...
          )
          ...
        )

    Args:
        model (torch.nn.Module): A PyTorch model.
        i (int): Index of the target layer.
        layer (torch.nn.Module): The layer which will be set into :attr:`model`.
    """
    j = 0
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue
        if j == i:
            set_module(model, name, layer)
            return
        j += 1


def get_all_specific_type_layers_name(model: torch.nn.Module, types: Tuple[Type[torch.nn.Module]]):
    """Get the names of all layers that are of the given types (e.g. `Conv2d`, `Linear`) in a PyTorch model.

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> get_all_specific_type_layers_name(model, (torch.nn.Conv2d,))
        ['features.0', 'features.2', 'features.5', ...]

    Args:
        model (torch.nn.Module): A PyTorch model.
        types (Tuple[Type[torch.nn.Module]]): Target types, e.g. `(torch.nn.Conv2d, torch.nn.Linear)`.

    Returns:
        List[str]: Names of all layers that are of the given types.
    """
    res = []
    for name, m in model.named_modules():
        if isinstance(m, types):
            res += [name]
    return res


class LayerActivation:
    """Collect the input and output of an intermediate module of a PyTorch model during inference.

    "Layer" is a broad concept in this class: any module (e.g. a ResBlock in ResNet) can be regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of the 5th layer in VGG16
        >>> layer_activation = LayerActivation(get_ith_layer(model, 5), True, 'cuda')
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Register a forward hook on the corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Whether to detach the collected tensors from the computation graph.
            device (str): Where the collected data is stored.
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        if isinstance(input, tuple):
            self.input = input[0].to(self.device)
        else:
            self.input = input.to(self.device)

        if isinstance(output, tuple):
            self.output = output[0].to(self.device)
        else:
            self.output = output.to(self.device)

        if self.detach:
            self.input = self.input.detach()
            self.output = self.output.detach()

    def remove(self):
        """Remove the hook from the model to avoid affecting performance.
        Call this after the collected data has been used.
        """
        self.hook.remove()


class LayerActivation2:
    """Collect the input and output of an intermediate module of a PyTorch model during inference.

    Unlike :class:`LayerActivation`, the raw hook arguments are stored as-is (no detaching and no device transfer).

    "Layer" is a broad concept in this class: any module (e.g. a ResBlock in ResNet) can be regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of the 5th layer in VGG16
        >>> layer_activation = LayerActivation2(get_ith_layer(model, 5))
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        (tensor([[...]]),)
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module):
        """Register a forward hook on the corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
        """
        assert layer is not None

        self.hook = layer.register_forward_hook(self._hook_fn)
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        self.input = input
        self.output = output

    def remove(self):
        """Remove the hook from the model to avoid affecting performance.
        Call this after the collected data has been used.
        """
        self.hook.remove()


class LayerActivation3:
    """Collect the input and output of an intermediate module of a PyTorch model during inference.

    Like :class:`LayerActivation2`, the raw hook arguments are stored as-is; the :attr:`detach`
    and :attr:`device` arguments are accepted but not used by the hook.

    "Layer" is a broad concept in this class: any module (e.g. a ResBlock in ResNet) can be regarded as a "layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input and output of the 5th layer in VGG16
        >>> layer_activation = LayerActivation3(get_ith_layer(model, 5), True, 'cuda')
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        (tensor([[...]]),)
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, layer: torch.nn.Module, detach: bool, device: str):
        """Register a forward hook on the corresponding layer.

        Args:
            layer (torch.nn.Module): Target layer.
            detach (bool): Currently unused.
            device (str): Currently unused.
        """
        self.hook = layer.register_forward_hook(self._hook_fn)
        self.detach = detach
        self.device = device
        self.input: torch.Tensor = None
        self.output: torch.Tensor = None
        self.layer = layer

    def __str__(self):
        return '- ' + str(self.layer)

    def _hook_fn(self, module, input, output):
        # store the raw hook arguments without detaching or moving them
        self.input = input
        self.output = output

    def remove(self):
        """Remove the hook from the model to avoid affecting performance.
        Call this after the collected data has been used.
        """
        self.hook.remove()


class LayerActivationWrapper:
    """A wrapper of :class:`LayerActivation` with the same API, but a broader notion of "layer":
    a series of consecutive layers can be regarded as one "hyper-layer".

    Example:
        >>> from torchvision.models import vgg16
        >>> model = vgg16()
        >>> # collect the input of the 5th layer and the output of the 7th layer in VGG16,
        >>> # i.e. regard the 5th~7th layers as a whole module
        >>> # and collect the input and output of this module
        >>> layer_activation = LayerActivationWrapper([
        ...     LayerActivation(get_ith_layer(model, 5), True, 'cuda'),
        ...     LayerActivation(get_ith_layer(model, 6), True, 'cuda'),
        ...     LayerActivation(get_ith_layer(model, 7), True, 'cuda')
        ... ])
        >>> model(torch.rand((1, 3, 224, 224)))
        >>> layer_activation.input
        tensor([[...]])
        >>> layer_activation.output
        tensor([[...]])
        >>> layer_activation.remove()
    """

    def __init__(self, las: List[LayerActivation]):
        """
        Args:
            las (List[LayerActivation]): The layer activations of a series of consecutive layers.
        """
        self.las = las

    def __str__(self):
        return '\n'.join([str(la) for la in self.las])

    @property
    def input(self):
        """Get the collected input data of the first layer.

        Returns:
            torch.Tensor: Collected input data of the first layer.
        """
        return self.las[0].input

    @property
    def output(self):
        """Get the collected output data of the last layer.

        Returns:
            torch.Tensor: Collected output data of the last layer.
        """
        return self.las[-1].output

    def remove(self):
        """Remove all hooks from the model to avoid affecting performance.
        Call this after the collected data has been used.
        """
        [la.remove() for la in self.las]


class TimeProfiler:
    """Measure the inference time of a single layer via forward hooks. (NOT VERIFIED. DON'T USE ME)
    """

    def __init__(self, layer: torch.nn.Module, device):
        self.before_infer_hook = layer.register_forward_pre_hook(self.before_hook_fn)
        self.after_infer_hook = layer.register_forward_hook(self.after_hook_fn)

        self.device = device
        self.infer_time = None
        self._start_time = None

        if self.device != 'cpu':
            self.s, self.e = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

    def before_hook_fn(self, module, input):
        if self.device == 'cpu':
            self._start_time = time.time()
        else:
            self.s.record()

    def after_hook_fn(self, module, input, output):
        if self.device == 'cpu':
            self.infer_time = time.time() - self._start_time
        else:
            self.e.record()
            torch.cuda.synchronize()
            self.infer_time = self.s.elapsed_time(self.e) / 1000.

    def remove(self):
        self.before_infer_hook.remove()
        self.after_infer_hook.remove()


class TimeProfilerWrapper:
    """Aggregate the inference time of several :class:`TimeProfiler` instances. (NOT VERIFIED. DON'T USE ME)
    """

    def __init__(self, tps: List[TimeProfiler]):
        self.tps = tps

    @property
    def infer_time(self):
        return sum([tp.infer_time for tp in self.tps])

    def remove(self):
        [tp.remove() for tp in self.tps]