|
import torch |
|
import torch.nn as nn |
|
from .base import BaseLosses |
|
|
|
|
|
class CommitLoss(nn.Module):
    """Pass-through "loss" that hands back an already-computed value.

    Lets a precomputed scalar (in this file: the VQ commitment loss and the
    language-model loss) be used wherever a ``loss_fn(output, target)``
    callable is expected. Only the first argument is used; the second
    argument and any keyword arguments exist purely so the call signature
    matches a regular loss module, and are ignored.
    """

    def __init__(self, **kwargs):
        # Any constructor options are accepted for interface compatibility
        # and discarded; there is no state to configure.
        super().__init__()

    def forward(self, commit, commit2, **kwargs):
        """Return ``commit`` unchanged; ``commit2`` and ``kwargs`` are unused."""
        del commit2, kwargs  # present only to mirror the (output, target) API
        return commit
|
|
|
|
|
class GPTLosses(BaseLosses):
    """Stage-dependent training losses for the motion VQ-VAE and motion LM.

    The registered losses (and their weights) depend on ``stage``:

    * ``"vae"``: ``recons_feature`` (``cfg.LOSS.LAMBDA_FEATURE``),
      ``recons_velocity`` (``cfg.LOSS.LAMBDA_VELOCITY``) and ``vq_commit``
      (``cfg.LOSS.LAMBDA_COMMIT``).
    * ``"lm_pretrain"`` / ``"lm_instruct"``: ``gpt_loss``
      (``cfg.LOSS.LAMBDA_CLS``).
    * any other stage: no losses are registered.

    Reconstruction losses use the criterion selected by
    ``cfg.LOSS.ABLATION.RECONS_LOSS`` (``"l1"``, ``"l2"`` or ``"l1_smooth"``).
    ``vq_commit`` and ``gpt_loss`` are already-computed values routed through
    ``CommitLoss``, which returns its first argument unchanged.
    """

    # Criterion class for each supported cfg.LOSS.ABLATION.RECONS_LOSS value.
    _RECONS_LOSSES = {
        "l1": nn.L1Loss,
        "l2": nn.MSELoss,
        "l1_smooth": nn.SmoothL1Loss,
    }

    # Second underscore-separated name token -> pass-through (CommitLoss).
    _PASSTHROUGH_KEYS = ('commit', 'loss', 'gpt', 'm2t2m', 't2m2t')

    # Second underscore-separated name token -> nn.CrossEntropyLoss.
    _CE_KEYS = ('cls', 'lm')

    # Feature width -> first channel of the span used by "recons_velocity".
    # NOTE(review): presumably 263 is a HumanML3D-style motion vector and
    # 135 + 263 prepends 135 extra channels to it — confirm.
    _VEL_START = {263: 4, 135 + 263: 135 + 4}

    def __init__(self, cfg, stage, num_joints, **kwargs):
        """Choose losses and weights for ``stage`` and initialise BaseLosses.

        Args:
            cfg: config object; reads ``cfg.LOSS.ABLATION.RECONS_LOSS`` and
                the ``cfg.LOSS.LAMBDA_*`` weight for each registered loss.
            stage: training stage name (``"vae"``, ``"lm_pretrain"``,
                ``"lm_instruct"``; other values register no losses).
            num_joints: number of skeleton joints, forwarded to BaseLosses;
                ``update`` uses ``self.num_joints`` to size the velocity slice.
            **kwargs: forwarded unchanged to ``BaseLosses.__init__``.

        Raises:
            NotImplementedError: if a registered loss has no matching loss
                function, including an unsupported ``RECONS_LOSS`` value.
        """
        # update() dispatches on the stage name.
        self.stage = stage
        recons_loss = cfg.LOSS.ABLATION.RECONS_LOSS

        # Loss names and their weights for the requested stage.
        losses = []
        params = {}
        if stage == "vae":
            losses.append("recons_feature")
            params['recons_feature'] = cfg.LOSS.LAMBDA_FEATURE

            losses.append("recons_velocity")
            params['recons_velocity'] = cfg.LOSS.LAMBDA_VELOCITY

            losses.append("vq_commit")
            params['vq_commit'] = cfg.LOSS.LAMBDA_COMMIT
        elif stage in ("lm_pretrain", "lm_instruct"):
            losses.append("gpt_loss")
            params['gpt_loss'] = cfg.LOSS.LAMBDA_CLS

        # Loss-function class for each name, chosen from the name's tokens.
        losses_func = {}
        for loss in losses:
            tokens = loss.split('_')
            if tokens[0] == 'recons':
                # Fail fast: otherwise this loss would be left without a
                # loss function in losses_func.
                if recons_loss not in self._RECONS_LOSSES:
                    raise NotImplementedError(
                        f"Reconstruction loss {recons_loss} not implemented.")
                losses_func[loss] = self._RECONS_LOSSES[recons_loss]
            elif tokens[1] in self._PASSTHROUGH_KEYS:
                losses_func[loss] = CommitLoss
            elif tokens[1] in self._CE_KEYS:
                losses_func[loss] = nn.CrossEntropyLoss
            else:
                raise NotImplementedError(f"Loss {loss} not implemented.")

        super().__init__(cfg, losses, params, losses_func, num_joints,
                         **kwargs)

    def update(self, rs_set):
        '''Accumulate this step's losses and return their summed value.

        Args:
            rs_set: dict of model outputs. The ``"vae"`` stage reads
                ``'m_rst'`` and ``'m_ref'`` (presumably reconstructed and
                reference features; channels on the last axis) plus
                ``'loss_commit'``. The LM stages read
                ``rs_set['outputs'].loss``.

        Returns:
            Sum of the values returned by ``_update_loss`` (defined in
            BaseLosses; presumably the weighted per-loss values) for the
            active stage.

        Raises:
            NotImplementedError: in the ``"vae"`` stage, if the feature width
                has no known velocity offset and the velocity weight is
                non-zero.
        '''
        # Starts as a float; becomes a tensor once the first loss is added.
        # NOTE(review): for a stage with no registered losses it stays 0.0,
        # and ``total.detach()`` below raises AttributeError.
        total = 0.0

        if self.stage == "vae":
            m_rst, m_ref = rs_set['m_rst'], rs_set['m_ref']
            total += self._update_loss("recons_feature", m_rst, m_ref)

            nfeats = m_rst.shape[-1]
            vel_start = self._VEL_START.get(nfeats)
            if vel_start is not None:
                # Channels [vel_start, vel_start + 3 * (num_joints - 1)).
                # NOTE(review): for HumanML3D-style features this span is
                # presumably root-relative joint positions rather than
                # velocities — confirm which channels are intended.
                vel_end = vel_start + (self.num_joints - 1) * 3
                total += self._update_loss(
                    "recons_velocity",
                    m_rst[..., vel_start:vel_end],
                    m_ref[..., vel_start:vel_end])
            elif self._params['recons_velocity'] != 0.0:
                # Unknown layout: only acceptable if velocity is disabled.
                raise NotImplementedError(
                    f"Velocity not implemented for nfeats = {nfeats}")

            # CommitLoss returns its first argument, so the value is passed
            # as both "output" and "target".
            total += self._update_loss("vq_commit", rs_set['loss_commit'],
                                       rs_set['loss_commit'])

        if self.stage in ("lm_pretrain", "lm_instruct"):
            lm_loss = rs_set['outputs'].loss
            total += self._update_loss("gpt_loss", lm_loss, lm_loss)

        # Running totals (metric state presumably registered by BaseLosses).
        self.total += total.detach()
        self.count += 1

        return total
|
|