|
|
|
import torch |
|
|
|
from annotator.oneformer.detectron2.layers import nonzero_tuple |
|
|
|
# Public names exported by this module (used by `from ... import *`).
__all__ = ["subsample_labels"]
|
|
|
|
|
def subsample_labels(
    labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int
):
    """
    Return `num_samples` (or fewer, if not enough found)
    random samples from `labels` which is a mixture of positives & negatives.
    It will try to return as many positives as possible without
    exceeding `positive_fraction * num_samples`, and then try to
    fill the remaining slots with negatives.

    Args:
        labels (Tensor): (N, ) label vector with values:
            * -1: ignore
            * bg_label: background ("negative") class
            * otherwise: one or more foreground ("positive") classes
        num_samples (int): The maximum total number of indices (positives plus
            negatives) to return. Ignored (-1) entries are never sampled.
            `labels` itself is not modified. A negative value is treated as 0.
        positive_fraction (float): The number of sampled positives (labels that
            are neither -1 nor `bg_label`) is
            `min(num_positives, int(positive_fraction * num_samples))`, clamped
            to `[0, num_samples]`. The number of negatives sampled is
            `min(num_negatives, num_samples - num_positives_sampled)`.
            In other words, if there are not enough positives, the sample is
            filled with negatives. If there are also not enough negatives, then
            as many elements are sampled as possible.
        bg_label (int): label index of background ("negative") class.

    Returns:
        pos_idx, neg_idx (Tensor):
            1D vectors of indices into `labels`. The total length of both is
            `num_samples` or fewer.
    """
    positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0]
    negative = nonzero_tuple(labels == bg_label)[0]

    # Clamp so out-of-range arguments (num_samples < 0, or positive_fraction
    # outside [0, 1]) can never produce a negative count. A negative bound in
    # the `[:k]` slices below would silently select "all but the last |k|"
    # elements instead of none, and could exceed `num_samples` in total.
    num_samples = max(num_samples, 0)
    num_pos = int(num_samples * positive_fraction)
    num_pos = min(max(num_pos, 0), num_samples)
    # protect against not enough positive examples
    num_pos = min(positive.numel(), num_pos)
    num_neg = num_samples - num_pos
    # protect against not enough negative examples
    num_neg = min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx = positive[perm1]
    neg_idx = negative[perm2]
    return pos_idx, neg_idx
|
|