File size: 3,174 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn.functional as F
from mmengine.structures import LabelData

if not hasattr(torch, 'tensor_split'):

    def tensor_split(input: torch.Tensor, indices: list):
        """Minimal stand-in used when ``torch`` has no ``tensor_split``.

        Args:
            input (torch.Tensor): The tensor to slice along its first
                dimension.
            indices (list): The positions at which to cut ``input``.

        Returns:
            List[torch.Tensor]: The consecutive slices of ``input``; there is
            one more slice than there are indices.
        """
        # Pair every cut position with the next one, bracketing the whole
        # sequence with 0 and the size of the first dimension.
        starts = [0] + indices
        stops = indices + [input.size(0)]
        return [input[lo:hi] for lo, hi in zip(starts, stops)]
else:
    tensor_split = torch.tensor_split


def cat_batch_labels(elements: List[LabelData], device=None):
    """Concatenate the ``label`` of every :obj:`LabelData` in a batch.

    Args:
        elements (List[LabelData]): A batch of :obj:`LabelData`.
        device (torch.device, optional): The device to move the concatenated
            label to. Defaults to None, which leaves the device unchanged.

    Returns:
        Tuple[torch.Tensor, List[int]]: The concatenated label tensor and the
        split indices that mark where each sample's labels end (the end of
        the last sample is omitted). ``(None, None)`` is returned when the
        first element has no ``label`` field.
    """
    # Only the first element is inspected to decide whether labels exist.
    if 'label' not in elements[0]._data_fields:
        return None, None

    labels = [element.label for element in elements]

    # Accumulate the end offset of every sample inside the concatenation.
    boundaries = []
    offset = 0
    for label in labels:
        offset += label.size(0)
        boundaries.append(offset)

    batch_label = torch.cat(labels)
    if device is not None:
        batch_label = batch_label.to(device=device)
    # The final boundary equals the total length, so it is not a split point.
    return batch_label, boundaries[:-1]


def batch_label_to_onehot(batch_label, split_indices, num_classes):
    """Turn a concatenated label tensor into one onehot row per sample.

    Args:
        batch_label (torch.Tensor): The labels of multiple samples
            concatenated into a single tensor.
        split_indices (List[int]): The positions in ``batch_label`` where one
            sample's labels end and the next sample's begin.
        num_classes (int): The number of classes.

    Returns:
        torch.Tensor: The stacked onehot format labels, one row per sample.

    Examples:
        >>> import torch
        >>> from mmcls.structures import batch_label_to_onehot
        >>> # Assume a concated label from 3 samples.
        >>> # label 1: [0, 1], label 2: [0, 2, 4], label 3: [3, 1]
        >>> batch_label = torch.tensor([0, 1, 0, 2, 4, 3, 1])
        >>> split_indices = [2, 5]
        >>> batch_label_to_onehot(batch_label, split_indices, num_classes=5)
        tensor([[1, 1, 0, 0, 0],
                [1, 0, 1, 0, 1],
                [0, 1, 0, 1, 0]])
    """
    # One onehot row per individual label, still in concatenated order.
    per_label_onehot = F.one_hot(batch_label, num_classes)

    # Summing the rows that belong to one sample merges them into one vector.
    rows = []
    for sample_chunk in tensor_split(per_label_onehot, split_indices):
        rows.append(sample_chunk.sum(0))
    return torch.stack(rows)


def stack_batch_scores(elements, device=None):
    """Stack the ``score`` of every :obj:`LabelData` in a batch.

    Args:
        elements (List[LabelData]): A batch of :obj:`LabelData`.
        device (torch.device, optional): The device to move the stacked
            score to. Defaults to None, which leaves the device unchanged.

    Returns:
        torch.Tensor: The stacked score tensor, or None when the first
        element has no ``score`` field.
    """
    # Only the first element is inspected to decide whether scores exist.
    if 'score' not in elements[0]._data_fields:
        return None

    scores = []
    for element in elements:
        scores.append(element.score)
    batch_score = torch.stack(scores)

    return batch_score if device is None else batch_score.to(device)