Spaces:
Runtime error
Runtime error
File size: 5,129 Bytes
f549064 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence, Tuple
import torch
import torch.nn as nn
from mmengine.model import ModuleDict
from mmcls.registry import MODELS
from mmcls.structures import MultiTaskDataSample
from .base_head import BaseHead
def loss_convertor(loss_func, task_name):
    """Wrap a sub-head loss function so it only receives one task's samples.

    Args:
        loss_func (Callable): The loss function to wrap. It is called as
            ``loss_func(inputs, data_samples, **kwargs)`` and must return a
            dict, to which a ``'mask_size'`` entry is added.
        task_name (str): The task key looked up in every data sample.

    Returns:
        Callable: ``wrapped(inputs, data_samples, **kwargs)``, which filters
        both the inputs and the data samples down to the samples containing
        ``task_name`` before delegating to ``loss_func``.
    """

    def first_tensor(inputs):
        """Return the first tensor found in possibly nested inputs, or None."""
        if isinstance(inputs, torch.Tensor):
            return inputs
        # ``str``/``bytes`` are Sequences of themselves; skip them so the
        # recursion always terminates.
        if isinstance(inputs, Sequence) and \
                not isinstance(inputs, (str, bytes)):
            for item in inputs:
                found = first_tensor(item)
                if found is not None:
                    return found
        return None

    def mask_inputs(inputs, mask):
        """Apply the boolean sample mask to every tensor in the inputs."""
        if isinstance(inputs, Sequence):
            return type(inputs)(
                [mask_inputs(input, mask) for input in inputs])
        elif isinstance(inputs, torch.Tensor):
            return inputs[mask]

    def wrapped(inputs, data_samples, **kwargs):
        """Compute the task loss on the samples that contain the task.

        Returns:
            dict: The output of ``loss_func`` plus ``'mask_size'``, the
            number of samples containing the task as a float tensor. If no
            sample contains the task, ``{'loss': 0, 'mask_size': 0}``.
        """
        mask = torch.empty(len(data_samples), dtype=torch.bool)
        task_data_samples = []
        for i, data_sample in enumerate(data_samples):
            assert isinstance(data_sample, MultiTaskDataSample)
            sample_mask = task_name in data_sample
            mask[i] = sample_mask
            if sample_mask:
                task_data_samples.append(data_sample.get(task_name))

        if not task_data_samples:
            # Derive the zero loss from the inputs instead of using a fresh
            # constant: the result stays on the inputs' device and attached
            # to the autograd graph, so ``backward()`` still works for a
            # batch in which no sample has this task.
            anchor = first_tensor(inputs)
            if anchor is None:
                loss = torch.tensor(0.)
            else:
                loss = (anchor * 0).sum()
            return {'loss': loss, 'mask_size': torch.tensor(0.)}

        # Mask the inputs of the task
        masked_inputs = mask_inputs(inputs, mask)
        loss_output = loss_func(masked_inputs, task_data_samples, **kwargs)
        loss_output['mask_size'] = mask.sum().to(torch.float)
        return loss_output

    return wrapped
@MODELS.register_module()
class MultiTaskHead(BaseHead):
    """Multi task head.

    Holds one sub head per task. Every sub head's ``loss`` method is wrapped
    with ``loss_convertor`` so that it is only applied to the samples that
    contain the corresponding task.

    Args:
        task_heads (dict): Sub heads to use, the key will be use to rename the
            loss components. Each value is either an ``nn.Module`` or a
            config that is built through the ``MODELS`` registry.
        init_cfg (dict, optional): The extra initialization settings.
            Defaults to None.
        **kwargs: The common settings for all heads, passed as
            ``default_args`` when a sub head is built from a config.
    """

    def __init__(self, task_heads, init_cfg=None, **kwargs):
        super().__init__(init_cfg=init_cfg)

        # The message pieces are concatenated verbatim, so each piece needs
        # its own separating space.
        assert isinstance(task_heads, dict), 'The `task_heads` argument ' \
            "should be a dict, which's keys are task names and values are " \
            'configs of head for the task.'

        self.task_heads = ModuleDict()

        for task_name, sub_head in task_heads.items():
            if not isinstance(sub_head, nn.Module):
                sub_head = MODELS.build(sub_head, default_args=kwargs)
            # Replace the loss method on the instance so it only sees the
            # samples that carry this task.
            sub_head.loss = loss_convertor(sub_head.loss, task_name)
            self.task_heads[task_name] = sub_head

    def forward(self, feats):
        """The forward process.

        Args:
            feats: The features handed to every sub head.

        Returns:
            dict: The output of every sub head, keyed by task name.
        """
        return {
            task_name: head(feats)
            for task_name, head in self.task_heads.items()
        }

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[MultiTaskDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components, each task loss
            key will be prefixed by the task_name like "task1_loss"
        """
        losses = {}
        for task_name, head in self.task_heads.items():
            head_loss = head.loss(feats, data_samples, **kwargs)
            for k, v in head_loss.items():
                losses[f'{task_name}_{k}'] = v
        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: List[MultiTaskDataSample] = None
    ) -> List[MultiTaskDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[MultiTaskDataSample], optional): The annotation
                data of every samples. If not None, the predictions of every
                task are written into these data samples. Defaults to None.

        Returns:
            List[MultiTaskDataSample]: A list of data samples which contains
            the predicted results.
        """
        predictions_dict = {}
        # Stays 0 when there is no sub head, so the fallback below yields an
        # empty list instead of failing on an unbound name.
        batch_size = 0

        for task_name, head in self.task_heads.items():
            task_samples = head.predict(feats)
            batch_size = len(task_samples)
            predictions_dict[task_name] = task_samples

        if data_samples is None:
            data_samples = [MultiTaskDataSample() for _ in range(batch_size)]

        for task_name, task_samples in predictions_dict.items():
            for data_sample, task_sample in zip(data_samples, task_samples):
                has_task = task_name in data_sample.tasks
                # Record in the prediction's metainfo whether the input
                # sample already contained this task.
                task_sample.set_field(
                    has_task, 'eval_mask', field_type='metainfo')

                if has_task:
                    # Merge the prediction into the existing task sample.
                    data_sample.get(task_name).update(task_sample)
                else:
                    data_sample.set_field(task_sample, task_name)

        return data_samples
|