chanfee commited on
Commit
25b63fc
·
verified ·
1 Parent(s): f7c353b

Print logits for debugging.

Browse files
Files changed (1) hide show
  1. utils/model.py +6 -0
utils/model.py CHANGED
@@ -458,6 +458,9 @@ class OwlViTForClassification(nn.Module):
458
  loss_dict["loss_sym_box_label"] = sym_loss_box_label
459
  # ----------------------------------------------------------------------------------------
460
 
 
 
 
461
  # Predict image-level classes (batch_size, num_patches, num_queries)
462
  image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
463
 
@@ -475,6 +478,9 @@ class OwlViTForClassification(nn.Module):
475
  sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
476
  loss_dict["loss_xclip"] = sym_loss
477
 
 
 
 
478
  return pred_logits, part_logits, loss_dict
479
 
480
  def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
 
458
  loss_dict["loss_sym_box_label"] = sym_loss_box_label
459
  # ----------------------------------------------------------------------------------------
460
 
461
+ #DEBUG:
462
+ print(f"im_features size: {image_feats.shape}, text_embeds size: {text_embeds.shape}")
463
+ print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
464
  # Predict image-level classes (batch_size, num_patches, num_queries)
465
  image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
466
 
 
478
  sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
479
  loss_dict["loss_xclip"] = sym_loss
480
 
481
+ #DEBUG:
482
+ print(f"pred_logits size: {part_logits.shape}, pred_logits size: {part_logits.shape}")
483
+ print(f"part_logits sum: {pred_logits.sum().item()}, part_logits sum: {pred_logits.sum().item()}")
484
  return pred_logits, part_logits, loss_dict
485
 
486
  def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor: