# Copyright (c) Facebook, Inc. and its affiliates.
import logging
import math
import copy
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F

from fvcore.nn import sigmoid_focal_loss_jit, giou_loss, smooth_l1_loss
import fvcore.nn.weight_init as weight_init
from detectron2.config import configurable
from detectron2.data.detection_utils import get_fed_loss_cls_weights
from detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple
from detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
from detectron2.structures import Boxes, Instances, BitMasks, pairwise_iou, pairwise_ioa
from detectron2.utils.events import get_event_storage


class TextEmbedClassifier(nn.Module):
    """Classification head that scores RoI features against text embeddings.

    Projects RoI features into the text-embedding space and computes class
    logits as the dot product with a (num_embed_dim x num_classes) text-embedding matrix.
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        zs_weight_dim: int = 1024,
        norm_weight: bool = True,
        norm_temperature: float = 50.0,
    ):
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
        self.norm_weight = norm_weight
        self.norm_temperature = norm_temperature
        # Projects flattened RoI features into the text-embedding space.
        self.linear = nn.Linear(input_size, zs_weight_dim)

    def forward(self, x, text_embed):
        # x: (N, input_size) RoI features; text_embed: (zs_weight_dim, num_classes).
        x = self.linear(x)
        if self.norm_weight:
            # L2-normalize the projected features and scale by the temperature.
            x = self.norm_temperature * F.normalize(x, p=2, dim=1)
        # Similarity logits against the text embeddings: (N, num_classes).
        x = torch.mm(x, text_embed)
        return x


class VLMFastRCNNOutputLayers(nn.Module):
    """Fast R-CNN output layers with an open-vocabulary, text-embedding classifier.

    Predicts per-proposal class scores (via TextEmbedClassifier) and
    class-agnostic box regression deltas.
    """

    def __init__(
        self,
        input_shape: ShapeSpec,
        box2box_transform,
        use_sigmoid_ce: bool = True,
        test_score_thresh: float = 0.0,
        test_nms_thresh: float = 0.5,
        test_topk_per_image: int = 100,
    ):
        super().__init__()
        if isinstance(input_shape, int):  # some backward compatibility
            input_shape = ShapeSpec(channels=input_shape)
        self.box2box_transform = box2box_transform
        self.use_sigmoid_ce = use_sigmoid_ce
        self.test_score_thresh = test_score_thresh
        self.test_nms_thresh = test_nms_thresh
        self.test_topk_per_image = test_topk_per_image
        input_size = input_shape.channels * \
            (input_shape.width or 1) * (input_shape.height or 1)
        # Class-agnostic box regression head: 4 deltas per proposal.
        self.bbox_pred = nn.Sequential(
            nn.Linear(input_size, input_size),
            nn.ReLU(inplace=True),
            nn.Linear(input_size, 4),
        )
        # Classification head: scores RoI features against text embeddings.
        self.cls_score = TextEmbedClassifier(input_shape)

    def forward(self, x, text_embed):
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        cls_scores = self.cls_score(x, text_embed)
        proposal_deltas = self.bbox_pred(x)
        return cls_scores, proposal_deltas

    def predict_boxes(self, predictions, proposals):
        """Decode predicted deltas into boxes, split per image."""
        if not len(proposals):
            return []
        _, proposal_deltas = predictions
        num_prop_per_image = [len(p) for p in proposals]
        proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)
        predict_boxes = self.box2box_transform.apply_deltas(
            proposal_deltas,
            proposal_boxes,
        )  # Nx4 (class-agnostic boxes)
        return predict_boxes.split(num_prop_per_image)

    def predict_probs(self, predictions, proposals):
        """Convert class logits into per-image probability tensors."""
        cls_scores, _ = predictions
        num_inst_per_image = [len(p) for p in proposals]
        cls_scores = cls_scores.split(num_inst_per_image, dim=0)
        final_scores = []
        for cls_score in cls_scores:
            # Sigmoid when trained with sigmoid CE, softmax otherwise.
            final_score = cls_score.sigmoid() if self.use_sigmoid_ce else F.softmax(cls_score, dim=-1)
            final_scores.append(final_score)
        return final_scores
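

if __name__ == "__main__":
    # Minimal shape smoke-test sketch: assumes 1024-d RoI features and a
    # hypothetical 80-class text-embedding matrix (e.g. L2-normalized CLIP
    # text features); the dimensions and dummy proposals are illustrative only.
    num_props, feat_dim, num_classes = 8, 1024, 80
    layers = VLMFastRCNNOutputLayers(
        input_shape=ShapeSpec(channels=feat_dim),
        box2box_transform=Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0)),
    )
    feats = torch.randn(num_props, feat_dim)
    # (zs_weight_dim, num_classes): one embedding column per class.
    text_embed = F.normalize(torch.randn(1024, num_classes), p=2, dim=0)

    cls_scores, deltas = layers(feats, text_embed)
    assert cls_scores.shape == (num_props, num_classes)
    assert deltas.shape == (num_props, 4)

    # Decode boxes / probabilities per image using dummy, well-formed proposals.
    xy = torch.rand(num_props, 2) * 320
    wh = torch.rand(num_props, 2) * 100 + 1
    proposals = Instances((640, 640))
    proposals.proposal_boxes = Boxes(torch.cat([xy, xy + wh], dim=1))
    boxes_per_image = layers.predict_boxes((cls_scores, deltas), [proposals])
    probs_per_image = layers.predict_probs((cls_scores, deltas), [proposals])
    assert boxes_per_image[0].shape == (num_props, 4)
    assert probs_per_image[0].shape == (num_props, num_classes)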