Spaces:
Running
on
Zero
Running
on
Zero
Print logits for debugging.
Browse files- utils/model.py +6 -0
utils/model.py
CHANGED
@@ -458,6 +458,9 @@ class OwlViTForClassification(nn.Module):
|
|
458 |
loss_dict["loss_sym_box_label"] = sym_loss_box_label
|
459 |
# ----------------------------------------------------------------------------------------
|
460 |
|
|
|
|
|
|
|
461 |
# Predict image-level classes (batch_size, num_patches, num_queries)
|
462 |
image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
|
463 |
|
@@ -475,6 +478,9 @@ class OwlViTForClassification(nn.Module):
|
|
475 |
sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
|
476 |
loss_dict["loss_xclip"] = sym_loss
|
477 |
|
|
|
|
|
|
|
478 |
return pred_logits, part_logits, loss_dict
|
479 |
|
480 |
def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
|
|
|
458 |
loss_dict["loss_sym_box_label"] = sym_loss_box_label
|
459 |
# ----------------------------------------------------------------------------------------
|
460 |
|
461 |
+
#DEBUG:
|
462 |
+
print(f"im_features size: {image_feats.shape}, text_embeds size: {text_embeds.shape}")
|
463 |
+
print(f"im_features sum: {image_feats.sum().item()}, text_embeds sum: {text_embeds.sum().item()}")
|
464 |
# Predict image-level classes (batch_size, num_patches, num_queries)
|
465 |
image_text_logits, pred_logits, part_logits = self.cls_head(image_feats, text_embeds, topk_idxs)
|
466 |
|
|
|
478 |
sym_loss = self.loss_symmetric(logits_per_image, logits_per_text, targets_cls)
|
479 |
loss_dict["loss_xclip"] = sym_loss
|
480 |
|
481 |
+
#DEBUG:
|
482 |
+
print(f"pred_logits size: {part_logits.shape}, pred_logits size: {part_logits.shape}")
|
483 |
+
print(f"part_logits sum: {pred_logits.sum().item()}, part_logits sum: {pred_logits.sum().item()}")
|
484 |
return pred_logits, part_logits, loss_dict
|
485 |
|
486 |
def loss_symmetric(self, text_logits: torch.Tensor, image_logits: torch.Tensor, targets: torch.Tensor, box_labels: torch.Tensor = None) -> torch.Tensor:
|