from typing import List
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, AutoConfig, AutoModelForCausalLM
from .segment_anything_2.sam2.build_sam import build_sam2, build_sam2_video_predictor
from .unilm.beit3.modeling_utils import BEiT3Wrapper, _get_base_config, _get_large_config
from .configuration_evf import EvfConfig
from .segment_anything_2.sam2.utils.misc import load_video_frames
from collections import OrderedDict


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
    scale=1000,  # 100000.0,
    eps=1e-6,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of shape (N, ...) holding raw (pre-sigmoid)
            predictions. Every dimension after the first is flattened together,
            so each of the N entries gets one dice score.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer; the N per-example losses are summed and divided
            by ``num_masks + 1e-8``.
        scale: Predictions and targets are divided by this before summing
            (presumably to keep the sums small in reduced precision — unconfirmed).
            Note this also changes the relative weight of ``eps``.
        eps: Smoothing term added to both numerator and denominator.

    Returns:
        Scalar loss tensor.
    """
    inputs = inputs.sigmoid()
    # Flatten everything after the batch dim. For 3-D (N, H, W) input this is
    # identical to flatten(1, 2); unlike flatten(1, 2) it is also correct for
    # 2-D and 4-D+ inputs.
    inputs = inputs.flatten(1)
    targets = targets.flatten(1)
    numerator = 2 * (inputs / scale * targets).sum(-1)
    denominator = (inputs / scale).sum(-1) + (targets / scale).sum(-1)
    loss = 1 - (numerator + eps) / (denominator + eps)
    loss = loss.sum() / (num_masks + 1e-8)
    return loss


def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Compute the per-element sigmoid binary cross-entropy, averaged per example.

    Args:
        inputs: A float tensor of shape (N, ...) holding raw (pre-sigmoid)
            predictions. Every dimension after the first is flattened together.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer; the N per-example mean losses are summed and
            divided by ``num_masks + 1e-8``.

    Returns:
        Scalar loss tensor.
    """
    loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # flatten(1) == flatten(1, 2) for 3-D input, and also handles other ranks >= 2.
    loss = loss.flatten(1).mean(1).sum() / (num_masks + 1e-8)
    return loss


class EvfSam2Model(PreTrainedModel):
    """EVF-SAM2: a BEiT-3 multimodal encoder whose first output token is
    projected by ``text_hidden_fcs`` and fed to a SAM2 video predictor as a
    text prompt."""

    config_class = EvfConfig

    def __init__(self, config, **kwargs):
        """Build the model from ``config``.

        Recognized keyword arguments:
            vision_pretrained: Checkpoint passed to ``build_sam2_video_predictor``.
            encoder_pretrained: Path to a BEiT-3 checkpoint whose ``"model"``
                entry is loaded non-strictly into ``mm_extractor``.
            dice_loss_weight / bce_loss_weight: Loss weights read by the
                (currently commented-out) training ``forward``.
            train_mask_decoder: If True, unfreeze the SAM2 mask decoder.
            train_prompt_encoder: If True, unfreeze the prompt encoder's
                ``no_mask_embed``.
        """
        super(EvfSam2Model, self).__init__(config)
        self.config = config
        self.vision_pretrained = kwargs.get("vision_pretrained", None)
        self.encoder_pretrained = kwargs.get("encoder_pretrained", None)
        self.dice_loss_weight = kwargs.get("dice_loss_weight", None)
        self.bce_loss_weight = kwargs.get("bce_loss_weight", None)
        self.train_mask_decoder = kwargs.get("train_mask_decoder", False)
        self.train_prompt_encoder = kwargs.get("train_prompt_encoder", False)
        self.initialize_evf_modules(config)
        # (H, W) of the SAM2 backbone feature levels, highest resolution first;
        # used by the commented-out training forward to reshape backbone features.
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    def initialize_evf_modules(self, config):
        """Create the SAM2 predictor, BEiT-3 encoder and projection head.

        Raises:
            NotImplementedError: if ``config.sam_scale`` is not "large" or "tiny".
            AttributeError: if ``config.mm_extractor_scale`` is not "base" or "large".
        """
        # SAM2 video predictor; all of its parameters start frozen.
        if config.sam_scale == "large":
            self.visual_model = build_sam2_video_predictor(
                "sam2_hiera_l.yaml", self.vision_pretrained, device=None)
        elif config.sam_scale == "tiny":
            self.visual_model = build_sam2_video_predictor(
                "sam2_hiera_t.yaml", self.vision_pretrained, device=None)
        else:
            raise NotImplementedError(
                f"unsupported sam_scale {config.sam_scale!r}; expected 'large' or 'tiny'"
            )
        for param in self.visual_model.parameters():
            param.requires_grad = False
        if self.train_mask_decoder:
            self.visual_model.sam_mask_decoder.train()
            for param in self.visual_model.sam_mask_decoder.parameters():
                param.requires_grad = True
        if self.train_prompt_encoder:
            # Only the "no mask" embedding is unfrozen, not the whole prompt encoder.
            self.visual_model.sam_prompt_encoder.no_mask_embed.requires_grad_(True)

        # BEiT-3 multimodal encoder (trainable).
        if self.config.mm_extractor_scale == "base":
            beit_config = _get_base_config()
        elif self.config.mm_extractor_scale == "large":
            beit_config = _get_large_config()
        else:
            raise AttributeError(
                "model config should contain key 'mm_extractor_scale', with value 'base' or 'large'."
            )
        self.mm_extractor = BEiT3Wrapper(beit_config)
        if self.encoder_pretrained is not None:
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host;
            # load_state_dict copies the values into the existing parameters anyway.
            beit_state_dict = torch.load(self.encoder_pretrained, map_location="cpu")["model"]
            self.mm_extractor.load_state_dict(beit_state_dict, strict=False)
        for param in self.mm_extractor.parameters():
            param.requires_grad = True

        # Projection from the encoder width to the SAM2 prompt width.
        in_dim = config.hidden_size
        assert in_dim == beit_config.encoder_embed_dim, \
            f"projection layer dim {in_dim} mismatch with mm_extractor dim {beit_config.encoder_embed_dim}"
        out_dim = config.out_dim
        text_fc = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim),
        ]
        self.text_hidden_fcs = nn.ModuleList([nn.Sequential(*text_fc)])
        self.text_hidden_fcs.train()
        for param in self.text_hidden_fcs.parameters():
            param.requires_grad = True

    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform PostProcessing on output masks: cast to float32 and resize to
        ``orig_hw`` with bilinear interpolation.

        ``masks`` must be 4-D (N, C, H, W), since bilinear ``F.interpolate``
        only accepts 4-D input.
        """
        masks = masks.float()
        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
        return masks

    # NOTE: training forward, disabled but kept for reference. It is the only
    # user of dice_loss / sigmoid_ce_loss and of self._bb_feat_sizes.
    # def forward(
    #     self,
    #     images: torch.FloatTensor,
    #     images_evf: torch.FloatTensor,
    #     input_ids: torch.LongTensor,
    #     attention_masks: torch.LongTensor,
    #     offset: torch.LongTensor,
    #     masks_list: List[torch.FloatTensor],
    #     label_list: List[torch.Tensor],
    #     resize_list: List[tuple],
    #     inference: bool = False,
    #     **kwargs,
    # ):
    #     # image_embeddings = self.get_visual_embs(images)
    #     backbone_out = self.visual_model.forward_image(images)
    #     # dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
    #     _, image_embeddings, _, _ = self.visual_model._prepare_backbone_features(backbone_out)
    #     image_embeddings = [_.to(images.dtype) for _ in image_embeddings]
    #     batch_size = images.shape[0]
    #     if self.visual_model.directly_add_no_mem_embed:
    #         image_embeddings[-1] = image_embeddings[-1] + self.visual_model.no_mem_embed
    #
    #     feats = [
    #         feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
    #         for feat, feat_size in zip(image_embeddings[::-1], self._bb_feat_sizes[::-1])
    #     ][::-1]
    #     _features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
    #
    #     assert batch_size == len(offset) - 1
    #
    #     images_evf_list = []
    #     for i in range(len(offset) - 1):
    #         start_i, end_i = offset[i], offset[i + 1]
    #         images_evf_i = (
    #             images_evf[i]
    #             .unsqueeze(0)
    #             .expand(end_i - start_i, -1, -1, -1)
    #             .contiguous()
    #         )
    #         images_evf_list.append(images_evf_i)
    #     images_evf = torch.cat(images_evf_list, dim=0)
    #
    #     multimask_output = False
    #     output = self.mm_extractor.beit3(
    #         visual_tokens=images_evf,
    #         textual_tokens=input_ids,
    #         text_padding_position=~attention_masks
    #     )
    #
    #     feat = output["encoder_out"][:, :1, ...]
    #     feat = self.text_hidden_fcs[0](feat)
    #     feat = torch.split(feat, [offset[i+1] - offset[i] for i in range(len(offset)-1)])
    #
    #     pred_masks = []
    #     for i in range(len(feat)):
    #         (
    #             sparse_embeddings,
    #             dense_embeddings,
    #         ) = self.visual_model.sam_prompt_encoder(
    #             points=None,
    #             boxes=None,
    #             masks=None,
    #             text_embeds=feat[i],
    #         )
    #         sparse_embeddings = sparse_embeddings.to(feat[i].dtype)
    #         high_res_features = [
    #             feat_level[i].unsqueeze(0)
    #             for feat_level in _features["high_res_feats"]
    #         ]
    #         low_res_masks, iou_predictions, _, _ = self.visual_model.sam_mask_decoder(
    #             image_embeddings=_features["image_embed"][i].unsqueeze(0),
    #             image_pe=self.visual_model.sam_prompt_encoder.get_dense_pe(),
    #             sparse_prompt_embeddings=sparse_embeddings,
    #             dense_prompt_embeddings=dense_embeddings,
    #             multimask_output=multimask_output,
    #             repeat_image = True,
    #             high_res_features=high_res_features,
    #         )
    #         if multimask_output:
    #             sorted_ids = torch.argsort(iou_predictions, dim=-1, descending=True)
    #             low_res_masks = torch.take_along_dim(low_res_masks, sorted_ids[..., None, None], dim=1)[:, :1]
    #
    #         pred_mask = self.postprocess_masks(
    #             low_res_masks,
    #             orig_hw=label_list[i].shape,
    #         )
    #         pred_masks.append(pred_mask[:, 0])
    #
    #     gt_masks = masks_list
    #
    #     if inference:
    #         return {
    #             "pred_masks": pred_masks,
    #             "gt_masks": gt_masks,
    #         }
    #
    #     mask_bce_loss = 0
    #     mask_dice_loss = 0
    #     num_masks = 0
    #     for batch_idx in range(len(pred_masks)):
    #         gt_mask = gt_masks[batch_idx]
    #         pred_mask = pred_masks[batch_idx]
    #
    #         assert (
    #             gt_mask.shape[0] == pred_mask.shape[0]
    #         ), "gt_mask.shape: {}, pred_mask.shape: {}".format(
    #             gt_mask.shape, pred_mask.shape
    #         )
    #         mask_bce_loss += (
    #             sigmoid_ce_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         mask_dice_loss += (
    #             dice_loss(pred_mask, gt_mask, num_masks=gt_mask.shape[0])
    #             * gt_mask.shape[0]
    #         )
    #         num_masks += gt_mask.shape[0]
    #
    #     mask_bce_loss = self.bce_loss_weight * mask_bce_loss / (num_masks + 1e-8)
    #     mask_dice_loss = self.dice_loss_weight * mask_dice_loss / (num_masks + 1e-8)
    #     mask_loss = mask_bce_loss + mask_dice_loss
    #
    #     loss = mask_loss
    #
    #     return {
    #         "loss": loss,
    #         "mask_bce_loss": mask_bce_loss,
    #         "mask_dice_loss": mask_dice_loss,
    #         "mask_loss": mask_loss,
    #     }

    def inference(
        self,
        video_path,
        images_evf,
        input_ids,
        # original_size_list,
        multimask_output=False,
    ):
        """Segment the object described by ``input_ids`` across a video.

        BEiT-3 encodes ``images_evf`` together with ``input_ids``. The first
        output token is projected by ``text_hidden_fcs`` and added to the SAM2
        predictor as a text prompt on frame 0. That prompt is then propagated
        through the whole video.

        Args:
            video_path: Passed to ``predictor.init_state(video_path=...)``.
            images_evf: Visual input for the BEiT-3 encoder (``visual_tokens``).
            input_ids: Token ids of the referring expression (``textual_tokens``).
                No position is treated as padding.
            multimask_output: Currently unused; kept for API compatibility.

        Returns:
            dict mapping frame index -> {object id -> numpy mask}, where each
            mask is the boolean result of ``mask_logits > 0``.
        """
        predictor = self.visual_model
        inference_state = predictor.init_state(video_path=video_path)
        predictor.reset_state(inference_state)

        output = self.mm_extractor.beit3(
            visual_tokens=images_evf,
            textual_tokens=input_ids,
            # All-zeros padding mask: no text position is marked as padding.
            text_padding_position=torch.zeros_like(input_ids),
        )
        # Keep only the first encoder token as the prompt embedding.
        feat = output["encoder_out"][:, :1, ...]
        feat = self.text_hidden_fcs[0](feat)

        ann_frame_idx = 0  # the frame index we interact with
        ann_obj_id = 1  # unique id for the prompted object (any integer works)
        # The outputs for the prompted frame are not used here; per-frame
        # results are collected from propagate_in_video below.
        _, _obj_ids, _mask_logits = predictor.add_new_text(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=ann_obj_id,
            text=feat,
        )

        # Run propagation throughout the video and collect per-frame results.
        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
                inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
        return video_segments


AutoConfig.register("evf", EvfConfig)
AutoModelForCausalLM.register(EvfConfig, EvfSam2Model)