# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Union

import torch
from mmengine.model import BaseModel
from mmengine.structures import BaseDataElement
from torch.utils.data import DataLoader
class BaseRetriever(BaseModel, metaclass=ABCMeta):
    """Base class for retriever.

    This class only stores the prototype and declares the interface;
    every processing method is left to sub-classes.

    Args:
        prototype (Union[DataLoader, dict, str, torch.Tensor], optional):
            Database to be retrieved. The following four types are
            supported.

            - DataLoader: The original dataloader serves as the prototype.
            - dict: The configuration to construct Dataloader.
            - str: The path of the saved vector.
            - torch.Tensor: The saved tensor whose dimension should be dim.

            Defaults to None.
        data_preprocessor (dict, optional): The config for preprocessing
            input data. If None, it will use "BaseDataPreprocessor" as type,
            see :class:`mmengine.model.BaseDataPreprocessor` for more
            details. Defaults to None.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Attributes:
        prototype (Union[DataLoader, dict, str, torch.Tensor], optional):
            The database to be retrieved, stored exactly as passed to the
            constructor (see ``Args`` for the supported types).
        prototype_inited (bool): Set to ``False`` on construction.
            NOTE(review): presumably sub-classes flip it to ``True`` once
            :meth:`prepare_prototype` has run -- confirm against the
            concrete retrievers.
        data_preprocessor (:obj:`mmengine.model.BaseDataPreprocessor`): An
            extra data pre-processing module, which processes data from
            dataloader to the format accepted by :meth:`forward`.
    """

    def __init__(
        self,
        prototype: Optional[Union[DataLoader, dict, str,
                                  torch.Tensor]] = None,
        data_preprocessor: Optional[dict] = None,
        init_cfg: Optional[dict] = None,
    ):
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        # The prototype is kept in its raw form here; nothing is built or
        # loaded at construction time.
        self.prototype = prototype
        self.prototype_inited = False

    @abstractmethod
    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[BaseDataElement]] = None,
                mode: str = 'loss'):
        """The unified entry for a forward process in both training and test.

        Sub-classes must override this method. It should accept three
        modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor without any
          post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ClsDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (torch.Tensor, tuple): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every samples. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'loss'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor.
            - If ``mode="predict"``, return a list of
              :obj:`mmcls.structures.ClsDataSample`.
            - If ``mode="loss"``, return a dict of tensor.
        """

    def extract_feat(self, inputs: torch.Tensor):
        """Extract features from the input tensor with shape (N, C, ...).

        The sub-classes are recommended to implement this method to extract
        features from backbone and neck.

        Args:
            inputs (Tensor): A batch of inputs. The shape of it should be
                ``(num_samples, num_channels, *img_shape)``.

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError

    def loss(self, inputs: torch.Tensor,
             data_samples: List[BaseDataElement]) -> dict:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, ...) in general.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.

        Returns:
            dict[str, Tensor]: a dictionary of loss components

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError

    def predict(self,
                inputs: tuple,
                data_samples: Optional[List[BaseDataElement]] = None,
                **kwargs) -> List[BaseDataElement]:
        """Predict results from the extracted features.

        Args:
            inputs (tuple): The features extracted from the backbone.
            data_samples (List[BaseDataElement], optional): The annotation
                data of every samples. Defaults to None.
            **kwargs: Other keyword arguments accepted by the ``predict``
                method of :attr:`head`.

        Returns:
            List[BaseDataElement]: The prediction results, as defined by
            the overriding sub-class.

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError

    def matching(self, inputs: torch.Tensor):
        """Compare the prototype and calculate the similarity.

        Args:
            inputs (torch.Tensor): The input tensor with shape (N, C).

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError

    def prepare_prototype(self):
        """Preprocessing the prototype before predict.

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError

    def dump_prototype(self, path):
        """Save the features extracted from the prototype to the specific path.

        Args:
            path (str): Path to save feature.

        Raises:
            NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError