Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np | |
from mmengine.hooks import Hook | |
from mmengine.model import is_model_wrapper | |
from mmcls.models.heads import ArcFaceClsHead | |
from mmcls.registry import HOOKS | |
# NOTE(review): ``HOOKS`` is imported in this file but the class is not
# registered in the visible source; confirm whether
# ``@HOOKS.register_module()`` was intended here.
class SetAdaptiveMarginsHook(Hook):
    r"""Set adaptive-margins in ArcFaceClsHead based on the power of
    category-wise count.

    A PyTorch implementation of paper `Google Landmark Recognition 2020
    Competition Third Place Solution <https://arxiv.org/abs/2010.05350>`_.
    The margins will be
    :math:`\text{f}(n) = (marginMax - marginMin) \cdot norm(n^p) + marginMin`.
    The `n` indicates the number of occurrences of a category, and `norm`
    is a min-max normalization over all categories.

    Categories that never occur in the training set are treated as if they
    occurred once. If every category ends up with the same count, the
    normalization is degenerate and all margins are set to ``margin_min``.

    Args:
        margin_min (float): Lower bound of margins. Defaults to 0.05.
        margin_max (float): Upper bound of margins. Defaults to 0.5.
        power (float): The power of category frequency. Defaults to -0.25.
    """

    def __init__(self, margin_min=0.05, margin_max=0.5, power=-0.25) -> None:
        self.margin_min = margin_min
        self.margin_max = margin_max
        self.margin_range = margin_max - margin_min
        self.p = power

    def before_train(self, runner):
        """Change the margins in ArcFaceClsHead.

        Args:
            runner (obj: `Runner`): Runner.

        Raises:
            ValueError: If the (unwrapped) model has no ``head`` attribute,
                if its head is not an ``ArcFaceClsHead``, or if the labels
                of the training dataset do not fit into
                ``head.num_classes`` categories.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module

        # A missing head is just as unusable as a head of the wrong type,
        # so both cases are rejected with the same explicit error.
        head = getattr(model, 'head', None)
        if not isinstance(head, ArcFaceClsHead):
            raise ValueError(
                'Hook ``SetAdaptiveMarginsHook`` could only be used '
                f'for ``ArcFaceClsHead``, but get {type(head)}')

        # Generate margins based on the training dataset. The head of the
        # unwrapped model is used throughout: a model wrapper does not
        # necessarily expose ``head`` itself.
        gt_labels = runner.train_dataloader.dataset.get_gt_labels()
        margins = self._compute_margins(gt_labels, head.num_classes)
        head.set_margins(margins)

    def _compute_margins(self, gt_labels, num_classes):
        """Map category-wise sample counts to one margin per category.

        Args:
            gt_labels (array-like): Integer category label of every
                training sample.
            num_classes (int): Number of categories of the head.

        Returns:
            np.ndarray: Float array of length ``num_classes`` with values
            between ``margin_min`` and ``margin_max``.

        Raises:
            ValueError: If a label is not smaller than ``num_classes``.
        """
        # ``minlength`` keeps an entry for trailing categories that have no
        # sample at all; otherwise the result would be too short.
        label_count = np.bincount(gt_labels, minlength=num_classes)
        if len(label_count) != num_classes:
            raise ValueError(
                f'The training labels cover {len(label_count)} categories, '
                f'but the head only has {num_classes} classes.')

        # At least one occurrence. The cast to float makes ``np.power``
        # accept negative integer exponents as well.
        label_count = np.maximum(label_count, 1).astype(np.float64)
        pow_freq = np.power(label_count, self.p)

        min_f, max_f = pow_freq.min(), pow_freq.max()
        freq_range = max_f - min_f
        if freq_range > 0:
            normalized_pow_freq = (pow_freq - min_f) / freq_range
        else:
            # All categories are equally frequent: avoid 0 / 0 (NaN margins)
            # and fall back to the lower bound for every category.
            normalized_pow_freq = np.zeros_like(pow_freq)

        return normalized_pow_freq * self.margin_range + self.margin_min