# RapSAM video model config: ResNet-50 backbone -> YOSO neck ->
# RapSAMVideoHead, with MaskFormer-style panoptic fusion at inference.
# NOTE: every module-level name below becomes a config field, so avoid
# introducing extra top-level helper variables here.
from mmdet.models import ResNet, MaskFormerFusionHead, CrossEntropyLoss, DiceLoss

from app.models.detectors import YOSOVideoSam
from app.models.heads import RapSAMVideoHead
from app.models.necks import YOSONeck

# Label space: 80 "thing" plus 53 "stuff" categories (133 in total).
num_things_classes = 80
num_stuff_classes = 53

# Open-vocabulary classifier identity; the head is given the two names
# joined by an underscore as its `ov_classifier_name`.
ov_model_name = 'convnext_large_d_320'
ov_datasets_name = 'CocoPanopticOVDataset'

num_classes = num_things_classes + num_stuff_classes

model = dict(
    type=YOSOVideoSam,
    # NOTE(review): presumably provided by a dataset/base config that is
    # merged with this one — confirm before relying on None here.
    data_preprocessor=None,
    backbone=dict(
        type=ResNet,
        depth=50,
        num_stages=4,
        # Emit the feature maps of all four stages.
        out_indices=tuple(range(4)),
        frozen_stages=-1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        init_cfg=None,
    ),
    neck=dict(
        type=YOSONeck,
        agg_dim=128,
        hidden_dim=256,
        # Channel widths of the four backbone outputs listed above.
        backbone_shape=[256, 512, 1024, 2048],
    ),
    panoptic_head=dict(
        type=RapSAMVideoHead,
        prompt_with_kernel_updator=False,
        panoptic_with_kernel_updator=True,
        use_adaptor=True,
        use_kernel_updator=True,
        sphere_cls=True,
        ov_classifier_name='_'.join((ov_model_name, ov_datasets_name)),
        num_stages=3,
        feat_channels=256,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        num_queries=100,
        loss_cls=dict(
            type=CrossEntropyLoss,
            use_sigmoid=False,
            loss_weight=2.0,
            reduction='mean',
            # One weight per class plus a trailing, down-weighted entry —
            # presumably the "no object" slot; confirm against the head.
            class_weight=[1.0] * num_classes + [0.1],
        ),
        loss_mask=dict(
            type=CrossEntropyLoss,
            use_sigmoid=True,
            reduction='mean',
            loss_weight=5.0,
        ),
        loss_dice=dict(
            type=DiceLoss,
            use_sigmoid=True,
            activate=True,
            reduction='mean',
            naive_dice=True,
            eps=1.0,
            loss_weight=5.0,
        ),
    ),
    panoptic_fusion_head=dict(
        type=MaskFormerFusionHead,
        num_things_classes=num_things_classes,
        num_stuff_classes=num_stuff_classes,
        loss_panoptic=None,
        init_cfg=None,
    ),
    train_cfg=None,
    test_cfg=dict(
        panoptic_on=True,
        # The dataset currently cannot evaluate the semantic
        # segmentation metric, so semantic output is disabled.
        semantic_on=False,
        instance_on=True,
        # Only affects instance segmentation results.
        max_per_image=100,
        iou_thr=0.8,
        # Mask2Former-style panoptic post-processing drops mask areas
        # whose score is below 0.5.
        filter_low_score=True,
    ),
    # Whole-model weights are loaded from this local checkpoint.
    init_cfg=dict(
        type='Pretrained',
        checkpoint='models/rapsam_r50_12e.pth',
    ),
)