|
|
|
import datetime |
|
import logging |
|
import time |
|
from collections import OrderedDict, abc |
|
from contextlib import ExitStack, contextmanager |
|
from typing import List, Union |
|
import torch |
|
from torch import nn |
|
|
|
from annotator.oneformer.detectron2.utils.comm import get_world_size, is_main_process |
|
from annotator.oneformer.detectron2.utils.logger import log_every_n_seconds |
|
|
|
|
|
class DatasetEvaluator:
    """
    Base class for a dataset evaluator.

    :func:`inference_on_dataset` runs the model over every sample of a dataset
    and hands each (inputs, outputs) pair to a DatasetEvaluator.

    Subclasses gather whatever they need from those pairs in :meth:`process`
    and compute the final numbers in :meth:`evaluate`. In this base class every
    hook is a no-op that returns None.
    """

    def reset(self):
        """
        Get ready for a new round of evaluation.

        Call this before a round of evaluation begins.
        """
        return None

    def process(self, inputs, outputs):
        """
        Consume one pair of model inputs and outputs.

        When they hold batches, the pairs can be handled one at a time with `zip`:

        .. code-block:: python

            for input_, output in zip(inputs, outputs):
                # evaluate a single input/output pair
                ...

        Args:
            inputs (list): the inputs that were passed to the model.
            outputs (list): the value returned by `model(inputs)`.
        """
        return None

    def evaluate(self):
        """
        Summarize performance once every input/output pair has been processed.

        Returns:
            dict:
                A subclass may return a dict in any format, provided the caller
                knows how to interpret it. Our train_net.py expects this layout:

                * key: the name of the task (e.g., bbox)
                * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
        """
        return None
|
|
|
|
|
class DatasetEvaluators(DatasetEvaluator):
    """
    Composite :class:`DatasetEvaluator` that bundles several evaluators into one.

    Each of ``reset``, ``process`` and ``evaluate`` is forwarded, in order, to
    every wrapped :class:`DatasetEvaluator`.
    """

    def __init__(self, evaluators):
        """
        Args:
            evaluators (list): the evaluators to bundle together.
        """
        super().__init__()
        self._evaluators = evaluators

    def reset(self):
        for child in self._evaluators:
            child.reset()

    def process(self, inputs, outputs):
        for child in self._evaluators:
            child.process(inputs, outputs)

    def evaluate(self):
        """
        Run every wrapped evaluator and merge their result dicts.

        Every child's ``evaluate()`` is always called. Its results are merged
        only when ``is_main_process()`` is true and the child returned a
        non-None dict. Keys must not collide across evaluators.
        """
        merged = OrderedDict()
        for child in self._evaluators:
            partial = child.evaluate()
            # Same check order as `is_main_process() and result is not None`.
            if not is_main_process() or partial is None:
                continue
            for key, value in partial.items():
                assert (
                    key not in merged
                ), "Different evaluators produce results with the same key {}".format(key)
                merged[key] = value
        return merged
|
|
|
|
|
def inference_on_dataset(
    model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
):
    """
    Run model on the data_loader and evaluate the metrics with evaluator.
    Also benchmark the inference speed of `model.__call__` accurately.
    The model will be used in eval mode.

    Args:
        model (callable): a callable which takes an object from
            `data_loader` and returns some outputs.

            If it's an nn.Module, it will be temporarily set to `eval` mode.
            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
            but don't want to do any evaluation. A list (or any other non-string
            sequence, such as a tuple) of evaluators is combined into one
            :class:`DatasetEvaluators`.

    Returns:
        The return value of `evaluator.evaluate()`, or an empty dict if that
        value is None.
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    # Take the length once; it drives the log message, warm-up size and ETA.
    total = len(data_loader)
    logger.info("Start inference on {} batches".format(total))

    if evaluator is None:
        # Benchmark-only mode: an empty wrapper turns every evaluator call into a no-op.
        evaluator = DatasetEvaluators([])
    if isinstance(evaluator, abc.Sequence) and not isinstance(evaluator, (str, bytes)):
        # Accept tuples (and other sequences) as well as lists of evaluators;
        # previously only mutable sequences were wrapped.
        evaluator = DatasetEvaluators(evaluator)
    evaluator.reset()

    # Timings of the first `num_warmup` iterations are discarded: all counters
    # are reset when idx == num_warmup. For a non-empty loader this leaves at
    # least one timed iteration.
    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_data_time = 0
    total_compute_time = 0
    total_eval_time = 0
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())

        start_data_time = time.perf_counter()
        for idx, inputs in enumerate(data_loader):
            total_data_time += time.perf_counter() - start_data_time
            if idx == num_warmup:
                # Warm-up is over: restart every timer from this iteration.
                start_time = time.perf_counter()
                total_data_time = 0
                total_compute_time = 0
                total_eval_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                # Wait for queued CUDA work so it is counted as compute time.
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time

            start_eval_time = time.perf_counter()
            evaluator.process(inputs, outputs)
            total_eval_time += time.perf_counter() - start_eval_time

            # Number of iterations covered by the (possibly reset) timers.
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            data_seconds_per_iter = total_data_time / iters_after_start
            compute_seconds_per_iter = total_compute_time / iters_after_start
            eval_seconds_per_iter = total_eval_time / iters_after_start
            total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
            # Log only once the estimates have settled, or right away for very slow models.
            if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
                eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    (
                        f"Inference done {idx + 1}/{total}. "
                        f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
                        f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
                        f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
                        f"Total: {total_seconds_per_iter:.4f} s/iter. "
                        f"ETA={eta}"
                    ),
                    n=5,
                )
            start_data_time = time.perf_counter()

    # Wall-clock time measured locally, excluding the warm-up iterations.
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info(
        "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None (the base DatasetEvaluator.evaluate does);
    # normalize to an empty dict so downstream code can always treat it as a mapping.
    if results is None:
        results = {}
    return results
|
|
|
|
|
@contextmanager
def inference_context(model):
    """
    A context where the model is temporarily changed to eval mode,
    and restored to previous mode afterwards.

    The previous mode is restored even when the body of the ``with`` block
    raises an exception.

    Args:
        model: a torch Module
    """
    training_mode = model.training
    model.eval()
    try:
        yield
    finally:
        # Without try/finally, an exception raised inside the `with` body would
        # skip this line and leave the model stuck in eval mode.
        # Note: model.train(mode) applies `mode` to all submodules as well.
        model.train(training_mode)
|
|