atatakun's picture
Duplicate from atatakun/testapp2
18dd6ad
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Callable, List, Optional, Union
import torch
from ..dist_utils import master_only
from .hook import HOOKS, Hook
@HOOKS.register_module()
class ProfilerHook(Hook):
    """Profiler to analyze performance during training.

    PyTorch Profiler is a tool that collects performance metrics during
    training. More details on Profiler can be found at
    https://pytorch.org/docs/1.8.1/profiler.html#torch.profiler.profile

    The profiler is started in ``before_run`` and stopped exactly once after
    the first ``profile_iters`` epochs or iterations. If training resumes
    past that point, profiling stops after the first epoch/iteration. Once
    stopped, ``profiler.step()`` is no longer called, so a cyclic
    ``schedule`` cannot restart tracing for the rest of training.

    Args:
        by_epoch (bool): Profile performance by epoch or by iteration.
            Default: True.
        profile_iters (int): Number of epochs (``by_epoch=True``) or
            iterations (``by_epoch=False``) to profile, counted from the
            beginning of training. Must be >= 1. Default: 1.
        activities (list[str]): List of activity groups to use in profiling.
            Each entry must be ``'cpu'`` or ``'cuda'`` (case-insensitive).
            Default: ['cpu', 'cuda'].
        schedule (dict, optional): Keyword arguments for
            ``torch.profiler.schedule``. If schedule is None, the profiler
            will not add step markers into the trace and table view.
            Default: None.
        on_trace_ready (callable, dict, optional): Either a handler that is
            passed to ``torch.profiler.profile`` as-is, or a config dict used
            to build one. The dict's ``type`` key selects the handler and the
            remaining keys are forwarded to it:

            - ``'log_trace'``: prints
              ``prof.key_averages().table(**remaining_keys)``.
            - ``'tb_trace'``: uses
              ``torch.profiler.tensorboard_trace_handler(**remaining_keys)``
              and requires the ``torch_tb_profiler`` package.

            Default: None.
        record_shapes (bool): Save information about operator's input shapes.
            Default: False.
        profile_memory (bool): Track tensor memory allocation/deallocation.
            Default: False.
        with_stack (bool): Record source information (file and line number)
            for the ops. Default: False.
        with_flops (bool): Use formula to estimate the FLOPS of specific
            operators (matrix multiplication and 2D convolution).
            Default: False.
        json_trace_path (str, optional): If given, the collected trace is
            exported in Chrome JSON format to this path when profiling
            stops. Default: None.

    Example:
        >>> runner = ...  # instantiate a Runner
        >>> # tensorboard trace
        >>> trace_config = dict(type='tb_trace', dir_name='work_dir')
        >>> profiler_config = dict(on_trace_ready=trace_config)
        >>> runner.register_profiler_hook(profiler_config)
        >>> runner.run(data_loaders=[trainloader], workflow=[('train', 1)])
    """

    def __init__(self,
                 by_epoch: bool = True,
                 profile_iters: int = 1,
                 activities: List[str] = ['cpu', 'cuda'],  # never mutated
                 schedule: Optional[dict] = None,
                 on_trace_ready: Optional[Union[Callable, dict]] = None,
                 record_shapes: bool = False,
                 profile_memory: bool = False,
                 with_stack: bool = False,
                 with_flops: bool = False,
                 json_trace_path: Optional[str] = None) -> None:
        try:
            from torch import profiler  # torch version >= 1.8.1
        except ImportError as err:
            raise ImportError('profiler is the new feature of torch1.8.1, '
                              f'but your version is {torch.__version__}'
                              ) from err

        assert isinstance(by_epoch, bool), '``by_epoch`` should be a boolean.'
        self.by_epoch = by_epoch

        if profile_iters < 1:
            raise ValueError('profile_iters should be greater than 0, but got '
                             f'{profile_iters}')
        self.profile_iters = profile_iters

        if not isinstance(activities, list):
            raise ValueError(
                f'activities should be list, but got {type(activities)}')
        known_activities = {
            'cpu': profiler.ProfilerActivity.CPU,
            'cuda': profiler.ProfilerActivity.CUDA,
        }
        self.activities = []
        for activity in activities:
            activity = activity.lower()
            if activity not in known_activities:
                raise ValueError(
                    f'activity should be "cpu" or "cuda", but got {activity}')
            self.activities.append(known_activities[activity])

        self.schedule = (
            profiler.schedule(**schedule) if schedule is not None else None)

        self.on_trace_ready = on_trace_ready
        self.record_shapes = record_shapes
        self.profile_memory = profile_memory
        self.with_stack = with_stack
        self.with_flops = with_flops
        self.json_trace_path = json_trace_path

        # True only between entering the profiler in ``before_run`` and
        # exiting it in ``_stop_profiling``; prevents stepping or exiting a
        # profiler that is not running.
        self._profiling = False

    def _build_on_trace_ready(self) -> Optional[Callable]:
        """Resolve ``self.on_trace_ready`` into a handler (or None).

        Returns:
            callable | None: The handler to pass to ``torch.profiler.profile``.

        Raises:
            ImportError: If ``type='tb_trace'`` and ``torch_tb_profiler`` is
                not installed.
            ValueError: If the dict ``type`` is unknown, or
                ``on_trace_ready`` is not a callable, dict or None.
        """
        if callable(self.on_trace_ready):  # user-supplied handler
            return self.on_trace_ready
        if self.on_trace_ready is None:
            return None
        if not isinstance(self.on_trace_ready, dict):
            raise ValueError('on_trace_ready should be handler, dict or None, '
                             f'but got {type(self.on_trace_ready)}')

        # Copy so popping ``type`` does not mutate the user's config.
        trace_cfg = self.on_trace_ready.copy()
        trace_type = trace_cfg.pop('type')
        if trace_type == 'log_trace':

            def _log_handler(prof):
                print(prof.key_averages().table(**trace_cfg))

            return _log_handler
        if trace_type == 'tb_trace':  # tensorboard_trace handler
            try:
                import torch_tb_profiler  # noqa: F401
            except ImportError:
                raise ImportError('please run "pip install '
                                  'torch-tb-profiler" to install '
                                  'torch_tb_profiler')
            return torch.profiler.tensorboard_trace_handler(**trace_cfg)
        raise ValueError('trace_type should be "log_trace" or '
                         f'"tb_trace", but got {trace_type}')

    @master_only
    def before_run(self, runner):
        """Validate the profiling range, then create and enter the profiler.

        Raises:
            ValueError: If ``profile_iters`` exceeds the runner's maximum
                epochs (``by_epoch=True``) or iterations (``by_epoch=False``),
                or if ``on_trace_ready`` is invalid.
        """
        if self.by_epoch and runner.max_epochs < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_epochs}')
        if not self.by_epoch and runner.max_iters < self.profile_iters:
            raise ValueError('self.profile_iters should not be greater than '
                             f'{runner.max_iters}')

        _on_trace_ready = self._build_on_trace_ready()

        # ``runner.max_epochs`` may be None (NOTE(review): presumably for
        # iteration-based runners — confirm); ``None > 1`` raises TypeError.
        if runner.max_epochs is not None and runner.max_epochs > 1:
            warnings.warn(f'profiler will profile {runner.max_epochs} epochs '
                          'instead of 1 epoch. Since profiler will slow down '
                          'the training, it is recommended to train 1 epoch '
                          'with ProfilerHook and adjust your setting according'
                          ' to the profiler summary. During normal training '
                          '(epoch > 1), you may disable the ProfilerHook.')

        self.profiler = torch.profiler.profile(
            activities=self.activities,
            schedule=self.schedule,
            on_trace_ready=_on_trace_ready,
            record_shapes=self.record_shapes,
            profile_memory=self.profile_memory,
            with_stack=self.with_stack,
            with_flops=self.with_flops)
        self.profiler.__enter__()
        self._profiling = True
        runner.logger.info('profiler is profiling...')

    def _stop_profiling(self, runner):
        """Exit the profiler once and export the Chrome trace if requested."""
        if not self._profiling:
            return
        # Clear the flag first so a failing ``__exit__`` is not retried on
        # every subsequent epoch/iteration.
        self._profiling = False
        runner.logger.info('profiler may take a few minutes...')
        self.profiler.__exit__(None, None, None)
        if self.json_trace_path is not None:
            self.profiler.export_chrome_trace(self.json_trace_path)

    @master_only
    def after_train_epoch(self, runner):
        """Stop profiling once ``profile_iters`` epochs have run."""
        # ``>=`` (not ``==``) so a run resumed past the target still stops.
        if self.by_epoch and runner.epoch >= self.profile_iters - 1:
            self._stop_profiling(runner)

    @master_only
    def after_train_iter(self, runner):
        """Step the active profiler; stop after ``profile_iters`` iterations."""
        if not self._profiling:
            # Stepping a stopped profiler with a cyclic schedule would
            # restart tracing, so do nothing once profiling has ended.
            return
        self.profiler.step()
        # ``>=`` (not ``==``) so a run resumed past the target still stops.
        if not self.by_epoch and runner.iter >= self.profile_iters - 1:
            self._stop_profiling(runner)