|
import torch |
|
import torch.nn as nn |
|
|
|
class BaseLosses(nn.Module):
    """Base module that accumulates per-loss running sums and reports averages.

    One scalar buffer is registered per loss name (plus ``"total"`` and a
    ``count`` buffer), so the accumulators follow the module across devices
    and dtypes. Subclasses are expected to call :meth:`_update_loss` during
    the forward pass and to increment ``self.count`` once per batch; the
    averaged, per-split log dictionary is then produced by :meth:`compute`.

    Args:
        cfg: Experiment configuration (kept for subclass use; unused here).
        losses: Names of the individual losses, e.g. ``["loss_joints"]``.
            ``"total"`` is tracked automatically.
        params: Mapping from loss name to its scalar weight.
        losses_func: Mapping from loss name to a criterion *class*
            (e.g. ``nn.MSELoss``); each is instantiated with
            ``reduction='mean'``.
        num_joints: Number of joints, stored for subclass use.
    """

    def __init__(self, cfg, losses, params, losses_func, num_joints, **kwargs):
        super().__init__()

        self.num_joints = num_joints
        self._params = params

        # Copy before extending: the original code appended "total" to the
        # caller's list in place, mutating an argument the caller still owns.
        losses = list(losses)
        if "total" not in losses:
            losses.append("total")

        # One scalar accumulator buffer per loss, plus a batch counter.
        for loss in losses:
            self.register_buffer(loss, torch.tensor(0.0))
        self.register_buffer("count", torch.tensor(0.0))
        self.losses = losses

        # "total" is an aggregate with no criterion of its own — skip it by
        # name rather than assuming it sits at the end of the list (the old
        # `losses[:-1]` slice broke when "total" was passed in mid-list).
        self._losses_func = {}
        for loss in losses:
            if loss != "total":
                self._losses_func[loss] = losses_func[loss](reduction='mean')

    def _update_loss(self, loss: str, outputs, inputs):
        '''Update the running sum for *loss* and return its weighted value.

        The unweighted (detached) value is accumulated into the buffer; the
        returned tensor is weighted by ``self._params[loss]`` and keeps its
        graph so it can be backpropagated.
        '''
        val = self._losses_func[loss](outputs, inputs)

        # Accumulate detached so logging never holds onto the graph.
        getattr(self, loss).add_(val.detach())

        weighted_loss = self._params[loss] * val
        return weighted_loss

    def reset(self):
        '''Reset all accumulators (and the batch counter) to 0 in place.'''
        for loss in self.losses:
            # zero_() mutates the registered buffer directly, so device and
            # dtype are preserved without rebuilding tensors.
            getattr(self, loss).zero_()
        self.count.zero_()

    def compute(self, split):
        '''Average the accumulated losses, reset, and return a log dict.

        NaN entries (e.g. a loss that was never updated when count is 0)
        are dropped from the returned dictionary.
        '''
        count = self.count

        loss_dict = {loss: getattr(self, loss) / count for loss in self.losses}

        log_dict = {self.loss2logname(loss, split): value.item()
                    for loss, value in loss_dict.items() if not torch.isnan(value)}

        self.reset()
        return log_dict

    def loss2logname(self, loss: str, split: str):
        '''Convert a loss name to a log name, e.g. "loss_joints" -> "loss/joints/train".'''
        if loss == "total":
            log_name = f"{loss}/{split}"
        else:
            # maxsplit=1 so names with extra underscores ("loss_joints_3d")
            # do not blow up the unpacking.
            loss_type, name = loss.split("_", 1)
            log_name = f"{loss_type}/{name}/{split}"
        return log_name
|
|