Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List | |
import torch | |
import torch.nn.functional as F | |
from mmengine.structures import LabelData | |
if hasattr(torch, 'tensor_split'):
    tensor_split = torch.tensor_split
else:
    # A simple implementation of `tensor_split` for PyTorch versions that do
    # not provide ``torch.tensor_split``. Only splitting along dim 0 at
    # explicit indices is supported (not the integer "sections" form).
    def tensor_split(input: torch.Tensor, indices):
        """Split ``input`` along dim 0 at the given ``indices``.

        Args:
            input (torch.Tensor): The tensor to split.
            indices (Sequence[int]): Split positions along dim 0. A list or
                a tuple is accepted, like ``torch.tensor_split``.

        Returns:
            Tuple[torch.Tensor]: ``len(indices) + 1`` slices of ``input``,
            returned as a tuple to match ``torch.tensor_split``.
        """
        # Normalise to a list so tuples (or other sequences) can be
        # concatenated with the implicit first/last boundaries below.
        bounds = list(indices)
        starts = [0] + bounds
        ends = bounds + [input.size(0)]
        return tuple(input[start:end] for start, end in zip(starts, ends))
def cat_batch_labels(elements: List[LabelData], device=None):
    """Concatenate the ``label`` of a batch of :obj:`LabelData` to a tensor.

    Args:
        elements (List[LabelData]): A batch of :obj:`LabelData`. Only the
            first element is inspected to decide whether the batch carries a
            ``label`` field.
        device (torch.device, optional): The output device of the batch label.
            Defaults to None, which keeps the device produced by ``torch.cat``.

    Returns:
        Tuple[torch.Tensor, List[int]]: The first item is the concatenated
        label tensor, and the second item is the split indices of every
        sample, i.e. the offsets in the concatenated tensor where each
        sample after the first begins. Both items are ``None`` if
        ``elements`` is empty or its first element has no ``label`` field.
    """
    # An empty batch has no labels to concatenate; report it the same way
    # as a batch without a ``label`` field instead of raising IndexError.
    if not elements or 'label' not in elements[0]._data_fields:
        return None, None

    labels = [element.label for element in elements]

    # Running end offset of every sample except the last one: these are the
    # boundaries between consecutive samples in the concatenated tensor.
    split_indices = []
    offset = 0
    for label in labels[:-1]:
        offset += label.size(0)
        split_indices.append(offset)

    batch_label = torch.cat(labels)
    if device is not None:
        batch_label = batch_label.to(device=device)
    return batch_label, split_indices
def batch_label_to_onehot(batch_label, split_indices, num_classes):
    """Convert a concatenated label tensor to onehot format.

    Args:
        batch_label (torch.Tensor): A concatenated label tensor from multiple
            samples. Any integer dtype is accepted; it is cast to ``int64``
            because :func:`torch.nn.functional.one_hot` requires it.
        split_indices (List[int]): The split indices of every sample.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The onehot format label tensor of shape
        ``(len(split_indices) + 1, num_classes)``. Entries are 0 or 1, even
        if a sample lists the same class more than once.

    Examples:
        >>> import torch
        >>> from mmcls.structures import batch_label_to_onehot
        >>> # Assume a concatenated label from 3 samples.
        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
        >>> split_indices = [2, 5]
        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
        tensor([[1, 1, 0, 0, 0],
                [1, 0, 1, 0, 1],
                [0, 1, 0, 1, 0]])
    """
    # ``F.one_hot`` only accepts int64 index tensors.
    sparse_onehot_list = F.one_hot(batch_label.long(), num_classes)
    onehot_list = [
        sparse_onehot.sum(0)
        for sparse_onehot in tensor_split(sparse_onehot_list, split_indices)
    ]
    # A class repeated within one sample would otherwise sum to a count > 1.
    return torch.stack(onehot_list).clamp(max=1)
def stack_batch_scores(elements, device=None):
    """Stack the ``score`` of a batch of :obj:`LabelData` to a tensor.

    Args:
        elements (List[LabelData]): A batch of :obj:`LabelData`. Only the
            first element is inspected to decide whether the batch carries a
            ``score`` field.
        device (torch.device, optional): The output device of the batch score.
            Defaults to None, which keeps the device produced by
            ``torch.stack``.

    Returns:
        torch.Tensor: The stacked score tensor, with one row per element.
        ``None`` if ``elements`` is empty or its first element has no
        ``score`` field.
    """
    # An empty batch has no scores to stack; report it the same way as a
    # batch without a ``score`` field instead of raising IndexError.
    if not elements or 'score' not in elements[0]._data_fields:
        return None

    batch_score = torch.stack([element.score for element in elements])
    if device is not None:
        batch_score = batch_score.to(device=device)
    return batch_score