Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from multiprocessing.reduction import ForkingPickler | |
from numbers import Number | |
from typing import Sequence, Union | |
import numpy as np | |
import torch | |
from mmengine.structures import BaseDataElement, LabelData | |
from mmengine.utils import is_str | |
def format_label(
        value: Union[torch.Tensor, np.ndarray, Sequence, int]) -> torch.Tensor:
    """Convert various python types to label-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` (including NumPy integer scalars such as
    ``numpy.int64``).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.

    Returns:
        :obj:`torch.Tensor`: The formatted label tensor. Arrays, sequences
        and single integers are converted to a 1-D ``torch.long`` tensor; a
        :class:`torch.Tensor` with one or more dims is returned as-is.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the resulting tensor is not 1-dimensional.
    """
    # Handle single number: 0-dim tensors/arrays and NumPy integer scalars
    # are unwrapped to a plain ``int`` so they take the ``int`` branch below.
    if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
        value = int(value.item())
    elif isinstance(value, np.integer):
        value = int(value)

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).to(torch.long)
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).to(torch.long)
    elif isinstance(value, int):
        value = torch.LongTensor([value])
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available label type.')

    # Raised explicitly (instead of a bare ``assert``) so the check is still
    # enforced under ``python -O``; the exception type stays AssertionError.
    if value.ndim != 1:
        raise AssertionError(
            f'The dims of value should be 1, but got {value.ndim}.')
    return value
def format_score(
        value: Union[torch.Tensor, np.ndarray, Sequence]) -> torch.Tensor:
    """Convert various python types to score-format tensor.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`.

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence): Score values.

    Returns:
        :obj:`torch.Tensor`: The formatted score tensor. Arrays and sequences
        are converted to a float tensor; a :class:`torch.Tensor` is returned
        as-is.

    Raises:
        TypeError: If ``value`` is not one of the supported types.
        AssertionError: If the resulting tensor is not 1-dimensional.
    """
    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value).float()
    elif isinstance(value, Sequence) and not is_str(value):
        value = torch.tensor(value).float()
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available score type.')

    # Raised explicitly (instead of a bare ``assert``) so the check is still
    # enforced under ``python -O``; the exception type stays AssertionError.
    if value.ndim != 1:
        raise AssertionError(
            f'The dims of value should be 1, but got {value.ndim}.')
    return value
class ClsDataSample(BaseDataElement):
    """A data structure interface of classification task.

    It's used as interfaces between different components.

    Meta fields:
        img_shape (Tuple): The shape of the corresponding input image.
            Used for visualization.
        ori_shape (Tuple): The original shape of the corresponding image.
            Used for visualization.
        num_classes (int): The number of all categories.
            Used for label format conversion.

    Data fields:
        gt_label (:obj:`~mmengine.structures.LabelData`): The ground truth
            label.
        pred_label (:obj:`~mmengine.structures.LabelData`): The predicted
            label.
        scores (torch.Tensor): The outputs of model.
        logits (torch.Tensor): The outputs of model without softmax nor
            sigmoid.

    Examples:
        >>> import torch
        >>> from mmcls.structures import ClsDataSample
        >>>
        >>> img_meta = dict(img_shape=(960, 720), num_classes=5)
        >>> data_sample = ClsDataSample(metainfo=img_meta)
        >>> data_sample.set_gt_label(3)
        >>> print(data_sample)
        <ClsDataSample(
            META INFORMATION
            num_classes = 5
            img_shape = (960, 720)
            DATA FIELDS
            gt_label: <LabelData(
                    META INFORMATION
                    num_classes: 5
                    DATA FIELDS
                    label: tensor([3])
                ) at 0x7f21fb1b9190>
        ) at 0x7f21fb1b9880>
        >>> # For multi-label data
        >>> data_sample.set_gt_label([0, 1, 4])
        >>> print(data_sample.gt_label)
        <LabelData(
            META INFORMATION
            num_classes: 5
            DATA FIELDS
            label: tensor([0, 1, 4])
        ) at 0x7fd7d1b41970>
        >>> # Set one-hot format score
        >>> score = torch.tensor([0.1, 0.1, 0.6, 0.1, 0.1])
        >>> data_sample.set_pred_score(score)
        >>> print(data_sample.pred_label)
        <LabelData(
            META INFORMATION
            num_classes: 5
            DATA FIELDS
            score: tensor([0.1, 0.1, 0.6, 0.1, 0.1])
        ) at 0x7fd7d1b41970>
    """

    def set_gt_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ClsDataSample':
        """Set label of ``gt_label``.

        Args:
            value: Label value, converted with :func:`format_label`.

        Returns:
            ClsDataSample: ``self``, to allow chained calls.
        """
        # Reuse the existing LabelData (if any) so a previously set score
        # is kept alongside the new label.
        label_data = getattr(self, '_gt_label', LabelData())
        label_data.label = format_label(value)
        self.gt_label = label_data
        return self

    def set_gt_score(self, value: torch.Tensor) -> 'ClsDataSample':
        """Set score of ``gt_label``.

        Args:
            value: Score values, converted with :func:`format_score`.

        Returns:
            ClsDataSample: ``self``, to allow chained calls.

        Raises:
            AssertionError: If ``num_classes`` is already set and differs
                from the length of the score.
        """
        label_data = getattr(self, '_gt_label', LabelData())
        label_data.score = format_score(value)
        self._sync_num_classes(label_data.score)
        self.gt_label = label_data
        return self

    def set_pred_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ClsDataSample':
        """Set label of ``pred_label``.

        Args:
            value: Label value, converted with :func:`format_label`.

        Returns:
            ClsDataSample: ``self``, to allow chained calls.
        """
        label_data = getattr(self, '_pred_label', LabelData())
        label_data.label = format_label(value)
        self.pred_label = label_data
        return self

    def set_pred_score(self, value: torch.Tensor) -> 'ClsDataSample':
        """Set score of ``pred_label``.

        Args:
            value: Score values, converted with :func:`format_score`.

        Returns:
            ClsDataSample: ``self``, to allow chained calls.

        Raises:
            AssertionError: If ``num_classes`` is already set and differs
                from the length of the score.
        """
        label_data = getattr(self, '_pred_label', LabelData())
        label_data.score = format_score(value)
        self._sync_num_classes(label_data.score)
        self.pred_label = label_data
        return self

    def _sync_num_classes(self, score: torch.Tensor) -> None:
        """Check ``score`` against ``num_classes``, or record it if unset."""
        if hasattr(self, 'num_classes'):
            # Raised explicitly so the check also runs under ``python -O``.
            if len(score) != self.num_classes:
                raise AssertionError(
                    f'The length of score {len(score)} should be '
                    f'equal to the num_classes {self.num_classes}.')
        else:
            self.set_field(
                name='num_classes', value=len(score), field_type='metainfo')

    # ``gt_label`` / ``pred_label`` are properties backed by the private
    # ``_gt_label`` / ``_pred_label`` fields. Without the decorators the three
    # same-named methods would simply shadow one another and the setters above
    # would never store the private field.
    @property
    def gt_label(self):
        """LabelData: The ground truth label."""
        return self._gt_label

    @gt_label.setter
    def gt_label(self, value: LabelData):
        self.set_field(value, '_gt_label', dtype=LabelData)

    @gt_label.deleter
    def gt_label(self):
        del self._gt_label

    @property
    def pred_label(self):
        """LabelData: The predicted label."""
        return self._pred_label

    @pred_label.setter
    def pred_label(self, value: LabelData):
        self.set_field(value, '_pred_label', dtype=LabelData)

    @pred_label.deleter
    def pred_label(self):
        del self._pred_label
def _reduce_cls_datasample(data_sample):
    """Reduce a ClsDataSample for pickling.

    Every :class:`LabelData` attribute is replaced by the result of its
    ``numpy()`` call, and the names of the converted attributes are recorded
    so that :func:`_rebuild_cls_datasample` can convert them back.

    Args:
        data_sample (ClsDataSample): The data sample to pickle.

    Returns:
        tuple: ``(_rebuild_cls_datasample, (attr_dict, convert_keys))`` in
        the form expected from a reduce function.
    """
    # Work on a shallow copy: writing into ``data_sample.__dict__`` directly
    # would silently turn the live sample's LabelData fields into their NumPy
    # versions as a side effect of pickling it.
    attr_dict = dict(data_sample.__dict__)
    convert_keys = [
        k for k, v in attr_dict.items() if isinstance(v, LabelData)
    ]
    for k in convert_keys:
        attr_dict[k] = attr_dict[k].numpy()
    return _rebuild_cls_datasample, (attr_dict, convert_keys)
def _rebuild_cls_datasample(attr_dict, convert_keys):
    """Rebuild a ClsDataSample from its reduced form.

    Args:
        attr_dict (dict): Attribute dict produced by the reduce function.
        convert_keys (list): Names of the attributes in ``attr_dict`` that
            must be converted back with ``to_tensor()``.

    Returns:
        ClsDataSample: A new sample whose ``__dict__`` is ``attr_dict``.
    """
    rebuilt = ClsDataSample()
    # Restore the tensor form of every entry that was converted for pickling.
    for key in convert_keys:
        as_numpy = attr_dict[key]
        attr_dict[key] = as_numpy.to_tensor()
    rebuilt.__dict__ = attr_dict
    return rebuilt
# PyTorch's multi-processing strategy can make ClsDataSample consume many file
# descriptors, because each sample holds multiple LabelData objects containing
# tensors. To avoid this, the reduce function that ForkingPickler uses for
# ClsDataSample is overridden so that these tensors are converted to np.ndarray
# during pickling. This may affect dataloader performance, but only slightly,
# since the tensors stored in LabelData are very small.
ForkingPickler.register(ClsDataSample, _reduce_cls_datasample)