Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch import Tensor | |
from mmdet.registry import MODELS | |
from .utils import weighted_loss | |
def knowledge_distillation_kl_div_loss(pred: Tensor,
                                       soft_label: Tensor,
                                       T: int,
                                       detach_target: bool = True) -> Tensor:
    r"""Per-sample KL-divergence loss for knowledge distillation.

    Both inputs are divided by the temperature ``T`` before being
    normalised over ``dim=1``: ``pred`` with ``log_softmax`` and
    ``soft_label`` with ``softmax``. The element-wise KL divergence is then
    averaged over ``dim=1`` and multiplied by ``T * T``.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1). Must have
            exactly the same size as ``pred``.
        T (int): Temperature for distillation.
        detach_target (bool): If True, the softened target distribution is
            detached so no gradient flows back into ``soft_label``.
            Defaults to True.

    Returns:
        Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()

    # Softened target distribution; optionally cut out of the autograd graph.
    teacher_prob = F.softmax(soft_label / T, dim=1)
    if detach_target:
        teacher_prob = teacher_prob.detach()

    # F.kl_div expects its first argument as log-probabilities.
    student_log_prob = F.log_softmax(pred / T, dim=1)
    elementwise_kl = F.kl_div(
        student_log_prob, teacher_prob, reduction='none')

    # Average over dim=1, then rescale by T^2.
    return elementwise_kl.mean(1) * (T * T)
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation. Must be >= 1.
    """

    # NOTE(review): ``MODELS`` is imported by this file but this class is not
    # registered with it here - confirm whether ``@MODELS.register_module()``
    # is intended.

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0,
                 T: int = 10) -> None:
        super().__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred: Tensor,
                soft_label: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[int] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.

        Returns:
            Tensor: Loss tensor, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        # ``knowledge_distillation_kl_div_loss`` itself only accepts
        # ``(pred, soft_label, T, detach_target)``. Calling it directly with
        # ``weight`` as the third positional argument plus ``T=...`` raises
        # ``TypeError`` (multiple values for ``T``), and it has no
        # ``reduction`` / ``avg_factor`` parameters at all. Wrapping it with
        # ``weighted_loss`` is expected to supply the ``weight``,
        # ``reduction`` and ``avg_factor`` handling while forwarding ``T`` to
        # the element-wise loss.
        # NOTE(review): confirm against ``.utils.weighted_loss``; if the free
        # function is ever decorated with ``@weighted_loss`` directly, drop
        # this wrapping to avoid reducing twice.
        weighted_kd_loss = weighted_loss(knowledge_distillation_kl_div_loss)
        loss_kd = self.loss_weight * weighted_kd_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)
        return loss_kd