# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model, used here as a visual feature extractor (no detection heads or criterion).
"""
import torch
import torch.nn.functional as F
from torch import nn

from utils.misc import NestedTensor, nested_tensor_from_tensor_list

from .backbone import build_backbone
from .transformer import build_transformer


class DETR(nn.Module):
    """ This is the DETR module that extracts visual features for object detection """

    def __init__(self, backbone, transformer, num_queries, train_backbone, train_transformer, aux_loss=False):
        """ Initializes the model.

        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture (may be None,
                in which case raw backbone features are returned). See transformer.py
            num_queries: number of object queries, i.e. detection slots. This is the
                maximal number of objects DETR can detect in a single image.
                For COCO, we recommend 100 queries.
            train_backbone: if False, the backbone parameters are frozen.
            train_transformer: if False, the transformer and input projection
                parameters are frozen.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer)
                are to be used. Unused in this adapted module; kept for API
                compatibility with the original DETR.
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.backbone = backbone

        if self.transformer is not None:
            hidden_dim = transformer.d_model
            # 1x1 conv projecting backbone channels down to the transformer width
            self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        else:
            hidden_dim = backbone.num_channels

        if not train_backbone:
            for p in self.backbone.parameters():
                p.requires_grad_(False)

        if self.transformer is not None and not train_transformer:
            for m in [self.transformer, self.input_proj]:
                for p in m.parameters():
                    p.requires_grad_(False)

        self.num_channels = hidden_dim

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W],
                 containing 1 on padded pixels

            It returns the visual features together with the corresponding
            padding mask:
               - with a transformer: the output of the transformer applied to the
                 projected backbone features (called with query_embed=None).
               - without a transformer: a list [mask, features], where mask has
                 shape [batch_size x HW] and features has shape [HW x batch_size x C].
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None

        if self.transformer is not None:
            out = self.transformer(self.input_proj(src), mask, pos[-1], query_embed=None)
        else:
            # Flatten spatial dims: mask -> (B, HW); features -> (HW, B, C)
            out = [mask.flatten(1), src.flatten(2).permute(2, 0, 1)]

        return out


def build_detr(args):
    backbone = build_backbone(args)
    train_backbone = args.lr_visu_cnn > 0
    train_transformer = args.lr_visu_tra > 0

    if args.detr_enc_num > 0:
        transformer = build_transformer(args)
    else:
        transformer = None

    model = DETR(
        backbone,
        transformer,
        num_queries=args.num_queries,
        train_backbone=train_backbone,
        train_transformer=train_transformer,
    )
    return model
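

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file): builds the model
# and runs a forward pass on dummy images. Note that build_backbone and
# build_transformer typically read additional argparse fields (e.g. the
# backbone name and position embedding type) defined elsewhere in the repo;
# the Namespace below only lists the fields consumed directly by build_detr.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        lr_visu_cnn=1e-5,   # > 0 keeps the backbone trainable
        lr_visu_tra=1e-4,   # > 0 keeps the transformer (and input_proj) trainable
        detr_enc_num=6,     # > 0 builds the transformer; 0 returns raw features
        num_queries=100,
        # ... plus whatever fields build_backbone/build_transformer expect
    )
    model = build_detr(args)
    model.eval()

    images = torch.randn(2, 3, 640, 640)  # dummy batched RGB images
    with torch.no_grad():
        out = model(images)  # raw tensors are wrapped into a NestedTensor internally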