# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from math import ceil
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from mmengine.config import Config
from mmengine.dataset import default_collate
from mmengine.fileio import get_file_backend
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint

from mmpretrain.structures import DataSample
from mmpretrain.utils import track

from .model import get_model, list_models

# Accepted forms for the ``model`` argument: a built model, a model name
# registered in the metafiles, or a config object.
ModelType = Union[BaseModel, str, Config]
# Accepted forms for a single inference input (path/URL string or raw array).
InputType = Union[str, np.ndarray, list]
class BaseInferencer:
    """Base inferencer for various tasks.

    The BaseInferencer provides the standard workflow for inference as
    follows:

    1. Preprocess the input data by :meth:`preprocess`.
    2. Forward the data to the model by :meth:`forward`. ``BaseInferencer``
       assumes the model inherits from :class:`mmengine.models.BaseModel` and
       will call `model.test_step` in :meth:`forward` by default.
    3. Visualize the results by :meth:`visualize`.
    4. Postprocess and return the results by :meth:`postprocess`.

    When we call the subclasses inherited from BaseInferencer (not overriding
    ``__call__``), the workflow will be executed in order.

    All subclasses of BaseInferencer could define the following class
    attributes for customization:

    - ``preprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`preprocess`.
    - ``forward_kwargs``: The keys of the kwargs that will be passed to
      :meth:`forward`
    - ``visualize_kwargs``: The keys of the kwargs that will be passed to
      :meth:`visualize`
    - ``postprocess_kwargs``: The keys of the kwargs that will be passed to
      :meth:`postprocess`

    All attributes mentioned above should be a ``set`` of keys (strings),
    and each key should not be duplicated. Actually, :meth:`__call__` will
    dispatch all the arguments to the corresponding methods according to the
    ``xxx_kwargs`` mentioned above.

    Subclasses inherited from ``BaseInferencer`` should implement
    :meth:`_init_pipeline`, :meth:`visualize` and :meth:`postprocess`:

    - _init_pipeline: Return a callable object to preprocess the input data.
    - visualize: Visualize the results returned by :meth:`forward`.
    - postprocess: Postprocess the results returned by :meth:`forward` and
      :meth:`visualize`.

    Args:
        model (BaseModel | str | Config): A model name or a path to the config
            file, or a :obj:`BaseModel` object. The model name can be found
            by ``cls.list_models()`` and you can also query it in
            :doc:`/modelzoo_statistics`.
        pretrained (str, optional): Path to the checkpoint. If None, it will
            try to find a pre-defined weight from the model you specified
            (only work if the ``model`` is a model name). Defaults to None.
        device (str | torch.device | None): Transfer the model to the target
            device. Defaults to None.
        device_map (str | dict | None): A map that specifies where each
            submodule should go. It doesn't need to be refined to each
            parameter/buffer name, once a given module name is inside, every
            submodule of it will be sent to the same device. You can use
            `device_map="auto"` to automatically generate the device map.
            Defaults to None.
        offload_folder (str | None): If the `device_map` contains any value
            `"disk"`, the folder where we will offload weights.
        **kwargs: Other keyword arguments to initialize the model (only work if
            the ``model`` is a model name).
    """

    # Per-stage kwarg key sets; subclasses override these so that
    # ``__call__`` can route ``**kwargs`` via ``_dispatch_kwargs``.
    preprocess_kwargs: set = set()
    forward_kwargs: set = set()
    visualize_kwargs: set = set()
    postprocess_kwargs: set = set()
    def __init__(self,
                 model: ModelType,
                 pretrained: Union[bool, str] = True,
                 device: Union[str, torch.device, None] = None,
                 device_map=None,
                 offload_folder=None,
                 **kwargs) -> None:
        """Build (or adopt) the model, then prepare the inference pipeline."""
        if isinstance(model, BaseModel):
            # A ready-built model was passed in; optionally load weights from
            # a checkpoint path before any device placement.
            if isinstance(pretrained, str):
                load_checkpoint(model, pretrained, map_location='cpu')
            if device_map is not None:
                # Shard/offload submodules across devices as described by
                # ``device_map`` (takes precedence over a plain ``device``).
                from .utils import dispatch_model
                model = dispatch_model(
                    model,
                    device_map=device_map,
                    offload_folder=offload_folder)
            elif device is not None:
                model.to(device)
        else:
            # ``model`` is a model name or a config; build it (and place it
            # on the requested device) through the model registry.
            model = get_model(
                model,
                pretrained,
                device=device,
                device_map=device_map,
                offload_folder=offload_folder,
                **kwargs)

        model.eval()

        # NOTE(review): ``_config`` is presumably attached by ``get_model``;
        # confirm it is also present when a prebuilt ``BaseModel`` is passed
        # in directly.
        self.config = model._config
        self.model = model
        self.pipeline = self._init_pipeline(self.config)
        self.visualizer = None
def __call__( | |
self, | |
inputs, | |
return_datasamples: bool = False, | |
batch_size: int = 1, | |
**kwargs, | |
) -> dict: | |
"""Call the inferencer. | |
Args: | |
inputs (InputsType): Inputs for the inferencer. | |
return_datasamples (bool): Whether to return results as | |
:obj:`BaseDataElement`. Defaults to False. | |
batch_size (int): Batch size. Defaults to 1. | |
**kwargs: Key words arguments passed to :meth:`preprocess`, | |
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`. | |
Each key in kwargs should be in the corresponding set of | |
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs`` | |
and ``postprocess_kwargs``. | |
Returns: | |
dict: Inference and visualization results. | |
""" | |
( | |
preprocess_kwargs, | |
forward_kwargs, | |
visualize_kwargs, | |
postprocess_kwargs, | |
) = self._dispatch_kwargs(**kwargs) | |
ori_inputs = self._inputs_to_list(inputs) | |
inputs = self.preprocess( | |
ori_inputs, batch_size=batch_size, **preprocess_kwargs) | |
preds = [] | |
for data in track( | |
inputs, 'Inference', total=ceil(len(ori_inputs) / batch_size)): | |
preds.extend(self.forward(data, **forward_kwargs)) | |
visualization = self.visualize(ori_inputs, preds, **visualize_kwargs) | |
results = self.postprocess(preds, visualization, return_datasamples, | |
**postprocess_kwargs) | |
return results | |
def _inputs_to_list(self, inputs: InputType) -> list: | |
"""Preprocess the inputs to a list. | |
Cast the input data to a list of data. | |
- list or tuple: return inputs | |
- str: | |
- Directory path: return all files in the directory | |
- other cases: return a list containing the string. The string | |
could be a path to file, a url or other types of string according | |
to the task. | |
- other: return a list with one item. | |
Args: | |
inputs (str | array | list): Inputs for the inferencer. | |
Returns: | |
list: List of input for the :meth:`preprocess`. | |
""" | |
if isinstance(inputs, str): | |
backend = get_file_backend(inputs) | |
if hasattr(backend, 'isdir') and backend.isdir(inputs): | |
# Backends like HttpsBackend do not implement `isdir`, so only | |
# those backends that implement `isdir` could accept the inputs | |
# as a directory | |
file_list = backend.list_dir_or_file(inputs, list_dir=False) | |
inputs = [ | |
backend.join_path(inputs, file) for file in file_list | |
] | |
if not isinstance(inputs, (list, tuple)): | |
inputs = [inputs] | |
return list(inputs) | |
def preprocess(self, inputs: InputType, batch_size: int = 1, **kwargs): | |
"""Process the inputs into a model-feedable format. | |
Customize your preprocess by overriding this method. Preprocess should | |
return an iterable object, of which each item will be used as the | |
input of ``model.test_step``. | |
``BaseInferencer.preprocess`` will return an iterable chunked data, | |
which will be used in __call__ like this: | |
.. code-block:: python | |
def __call__(self, inputs, batch_size=1, **kwargs): | |
chunked_data = self.preprocess(inputs, batch_size, **kwargs) | |
for batch in chunked_data: | |
preds = self.forward(batch, **kwargs) | |
Args: | |
inputs (InputsType): Inputs given by user. | |
batch_size (int): batch size. Defaults to 1. | |
Yields: | |
Any: Data processed by the ``pipeline`` and ``default_collate``. | |
""" | |
chunked_data = self._get_chunk_data( | |
map(self.pipeline, inputs), batch_size) | |
yield from map(default_collate, chunked_data) | |
def forward(self, inputs: Union[dict, tuple], **kwargs): | |
"""Feed the inputs to the model.""" | |
return self.model.test_step(inputs) | |
def visualize(self, | |
inputs: list, | |
preds: List[DataSample], | |
show: bool = False, | |
**kwargs) -> List[np.ndarray]: | |
"""Visualize predictions. | |
Customize your visualization by overriding this method. visualize | |
should return visualization results, which could be np.ndarray or any | |
other objects. | |
Args: | |
inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`. | |
preds (Any): Predictions of the model. | |
show (bool): Whether to display the image in a popup window. | |
Defaults to False. | |
Returns: | |
List[np.ndarray]: Visualization results. | |
""" | |
if show: | |
raise NotImplementedError( | |
f'The `visualize` method of {self.__class__.__name__} ' | |
'is not implemented.') | |
    def postprocess(
        self,
        preds: List[DataSample],
        visualization: List[np.ndarray],
        return_datasample=False,
        **kwargs,
    ) -> dict:
        """Process the predictions and visualization results from ``forward``
        and ``visualize``.

        This method should be responsible for the following tasks:

        1. Convert datasamples into a json-serializable dict if needed.
        2. Pack the predictions and visualization results and return them.
        3. Dump or log the predictions.

        Customize your postprocess by overriding this method. Make sure
        ``postprocess`` will return a dict with visualization results and
        inference results.

        NOTE(review): the parameter is named ``return_datasample`` here but
        ``return_datasamples`` in :meth:`__call__`; kept as-is for interface
        compatibility — confirm against subclasses before renaming.

        Args:
            preds (List[Dict]): Predictions of the model.
            visualization (np.ndarray): Visualized predictions.
            return_datasample (bool): Whether to return results as datasamples.
                Defaults to False.

        Returns:
            dict: Inference and visualization results with key ``predictions``
            and ``visualization``

            - ``visualization (Any)``: Returned by :meth:`visualize`
            - ``predictions`` (dict or DataSample): Returned by
              :meth:`forward` and processed in :meth:`postprocess`.
              If ``return_datasample=False``, it usually should be a
              json-serializable dict containing only basic data elements such
              as strings and numbers.
        """
        # Intentionally a stub: subclasses implement the actual packing.
def _init_pipeline(self, cfg: Config) -> Callable: | |
"""Initialize the test pipeline. | |
Return a pipeline to handle various input data, such as ``str``, | |
``np.ndarray``. It is an abstract method in BaseInferencer, and should | |
be implemented in subclasses. | |
The returned pipeline will be used to process a single data. | |
It will be used in :meth:`preprocess` like this: | |
.. code-block:: python | |
def preprocess(self, inputs, batch_size, **kwargs): | |
... | |
dataset = map(self.pipeline, dataset) | |
... | |
""" | |
def _get_chunk_data(self, inputs: Iterable, chunk_size: int): | |
"""Get batch data from dataset. | |
Args: | |
inputs (Iterable): An iterable dataset. | |
chunk_size (int): Equivalent to batch size. | |
Yields: | |
list: batch data. | |
""" | |
inputs_iter = iter(inputs) | |
while True: | |
try: | |
chunk_data = [] | |
for _ in range(chunk_size): | |
processed_data = next(inputs_iter) | |
chunk_data.append(processed_data) | |
yield chunk_data | |
except StopIteration: | |
if chunk_data: | |
yield chunk_data | |
break | |
def _dispatch_kwargs(self, **kwargs) -> Tuple[dict, dict, dict, dict]: | |
"""Dispatch kwargs to preprocess(), forward(), visualize() and | |
postprocess() according to the actual demands. | |
Returns: | |
Tuple[Dict, Dict, Dict, Dict]: kwargs passed to preprocess, | |
forward, visualize and postprocess respectively. | |
""" | |
# Ensure each argument only matches one function | |
method_kwargs = self.preprocess_kwargs | self.forward_kwargs | \ | |
self.visualize_kwargs | self.postprocess_kwargs | |
union_kwargs = method_kwargs | set(kwargs.keys()) | |
if union_kwargs != method_kwargs: | |
unknown_kwargs = union_kwargs - method_kwargs | |
raise ValueError( | |
f'unknown argument {unknown_kwargs} for `preprocess`, ' | |
'`forward`, `visualize` and `postprocess`') | |
preprocess_kwargs = {} | |
forward_kwargs = {} | |
visualize_kwargs = {} | |
postprocess_kwargs = {} | |
for key, value in kwargs.items(): | |
if key in self.preprocess_kwargs: | |
preprocess_kwargs[key] = value | |
if key in self.forward_kwargs: | |
forward_kwargs[key] = value | |
if key in self.visualize_kwargs: | |
visualize_kwargs[key] = value | |
if key in self.postprocess_kwargs: | |
postprocess_kwargs[key] = value | |
return ( | |
preprocess_kwargs, | |
forward_kwargs, | |
visualize_kwargs, | |
postprocess_kwargs, | |
) | |
def list_models(pattern: Optional[str] = None): | |
"""List models defined in metafile of corresponding packages. | |
Args: | |
pattern (str | None): A wildcard pattern to match model names. | |
Returns: | |
List[str]: a list of model names. | |
""" | |
return list_models(pattern=pattern) | |