# ------------------------------------------------------------------------
# HOTR official code : hotr/models/detr.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
DETR & HOTR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
from hotr.util.misc import (NestedTensor, nested_tensor_from_tensor_list)
from .backbone import build_backbone
from .detr_matcher import build_matcher
from .hotr_matcher import build_hoi_matcher
from .transformer import build_transformer, build_hoi_transformer
from .criterion import SetCriterion
from .post_process import PostProcess
from .feed_forward import MLP
from .hotr import HOTR
from .hotr_v1 import HOTR_V1


class DETR(nn.Module):
    """ This is the DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape = [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": the normalized box coordinates for all queries, represented as
                               (center_x, center_y, width, height). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
            (An illustrative usage sketch follows this class definition.)
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]

        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
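

# ------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file): shows how the
# DETR wrapper above is driven and what comes back, following the forward()
# docstring. The helper below is hypothetical and never called; `backbone`
# and `transformer` are assumed to come from build_backbone(args) and
# build_transformer(args), and num_classes/num_queries are example values.
# ------------------------------------------------------------------------
def _example_detr_forward(backbone, transformer):
    model = DETR(backbone, transformer, num_classes=91, num_queries=100, aux_loss=True)
    # A plain list of 3xHxW tensors of varying sizes is accepted; forward()
    # pads and masks them internally via nested_tensor_from_tensor_list.
    images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    out = model(images)
    # out['pred_logits']: [2, 100, 92]  (batch_size x num_queries x (num_classes + 1))
    # out['pred_boxes'] : [2, 100, 4]   normalized (cx, cy, w, h) in [0, 1]
    return out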


def build(args):
    device = torch.device(args.device)

    backbone = build_backbone(args)
    transformer = build_transformer(args)

    model = DETR(
        backbone,
        transformer,
        num_classes=args.num_classes,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
    )

    # Hungarian matcher and loss weights for the object detection (DETR) head.
    matcher = build_matcher(args)
    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef

    # TODO this is a hack
    if args.aux_loss:
        # Replicate every loss weight for each intermediate decoder layer,
        # suffixing the key with the layer index (e.g. 'loss_ce_0').
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    losses = ['labels', 'boxes', 'cardinality'] if args.frozen_weights is None else []
    if args.HOIDet:
        # HOI detection: matcher, losses, and loss weights for the interaction head.
        hoi_matcher = build_hoi_matcher(args)
        hoi_losses = []
        hoi_losses.append('pair_labels')
        hoi_losses.append('pair_actions')
        if args.dataset_file == 'hico-det': hoi_losses.append('pair_targets')

        hoi_weight_dict = {}
        hoi_weight_dict['loss_hidx'] = args.hoi_idx_loss_coef
        hoi_weight_dict['loss_oidx'] = args.hoi_idx_loss_coef
        hoi_weight_dict['loss_h_consistency'] = args.hoi_idx_consistency_loss_coef
        hoi_weight_dict['loss_o_consistency'] = args.hoi_idx_consistency_loss_coef
        hoi_weight_dict['loss_act'] = args.hoi_act_loss_coef
        hoi_weight_dict['loss_act_consistency'] = args.hoi_act_consistency_loss_coef
        if args.dataset_file == 'hico-det':
            hoi_weight_dict['loss_tgt'] = args.hoi_tgt_loss_coef
            hoi_weight_dict['loss_tgt_consistency'] = args.hoi_tgt_consistency_loss_coef

        if args.hoi_aux_loss:
            # Same per-decoder-layer replication as above, for the HOI loss weights.
            hoi_aux_weight_dict = {}
            for i in range(args.hoi_dec_layers):
                hoi_aux_weight_dict.update({k + f'_{i}': v for k, v in hoi_weight_dict.items()})
            hoi_weight_dict.update(hoi_aux_weight_dict)

        criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=hoi_weight_dict,
                                 eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions,
                                 HOI_losses=hoi_losses, HOI_matcher=hoi_matcher, args=args)
        interaction_transformer = build_hoi_transformer(args) # if (args.share_enc and args.pretrained_dec) else None

        kwargs = {}
        if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids

        # Wrap the DETR detector with the HOTR interaction decoder; HOTR_V1 is
        # selected when the encoder forward pass is run separately (sep_enc_forward).
        if args.sep_enc_forward:
            model = HOTR_V1(
                detr=model,
                num_hoi_queries=args.num_hoi_queries,
                num_actions=args.num_actions,
                interaction_transformer=interaction_transformer,
                augpath_name=args.augpath_name,
                share_dec_param=args.share_dec_param,
                stop_grad_stage=args.stop_grad_stage,
                freeze_detr=(args.frozen_weights is not None),
                share_enc=args.share_enc,
                pretrained_dec=args.pretrained_dec,
                temperature=args.temperature,
                hoi_aux_loss=args.hoi_aux_loss,
                **kwargs # only return verb class for HICO-DET dataset
            )
        else:
            model = HOTR(
                detr=model,
                num_hoi_queries=args.num_hoi_queries,
                num_actions=args.num_actions,
                interaction_transformer=interaction_transformer,
                augpath_name=args.augpath_name,
                share_dec_param=args.share_dec_param,
                stop_grad_stage=args.stop_grad_stage,
                freeze_detr=(args.frozen_weights is not None),
                share_enc=args.share_enc,
                pretrained_dec=args.pretrained_dec,
                temperature=args.temperature,
                hoi_aux_loss=args.hoi_aux_loss,
                **kwargs # only return verb class for HICO-DET dataset
            )
        postprocessors = {'hoi': PostProcess(args.HOIDet)}
    else:
        criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
                                 eos_coef=args.eos_coef, losses=losses)
        postprocessors = {'bbox': PostProcess(args.HOIDet)}
    criterion.to(device)

    return model, criterion, postprocessors
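

# ------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how the triple
# returned by build() is typically consumed in a DETR-style training step.
# The names below are hypothetical and the criterion call follows the
# upstream DETR convention, `criterion(outputs, targets)`; the exact
# signature of this repo's SetCriterion may differ.
# ------------------------------------------------------------------------
#   model, criterion, postprocessors = build(args)
#   outputs = model(samples)                    # samples/targets from the data loader
#   loss_dict = criterion(outputs, targets)
#   losses = sum(loss_dict[k] * criterion.weight_dict[k]
#                for k in loss_dict if k in criterion.weight_dict)
#   losses.backward()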