# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn

from flash_attn.ops.triton.cross_entropy import cross_entropy_loss

class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        inplace_backward=False,
        process_group=None,
        return_z_loss=False,
    ):
""" | |
Arguments: | |
ignore_index: int. If labels == ignore_index, the loss is set to 0.0. | |
label_smoothing: float | |
lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. | |
This is also referred to as "z-loss". | |
inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. | |
This saves memory. | |
process_group: if not None, we're doing Tensor Parallel: each process is responsible for | |
one part of the vocab. The loss will be aggregated across processes. | |
return_z_loss: bool. If True, we return the component of the loss contributed by | |
the lse_square_scale value. This value is only for logging and does not support | |
backprop. | |
""" | |
        super().__init__()
        if reduction not in ["mean", "none", "sum"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none' or 'sum'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.logit_scale = logit_scale
        self.lse_square_scale = lse_square_scale
        self.inplace_backward = inplace_backward
        self.process_group = process_group
        self.return_z_loss = return_z_loss

    def forward(self, input, target):
        """
        Arguments:
            input: (batch, vocab_size)
            target: (batch,)
        Returns:
            losses: (batch,) if reduction is 'none', else (1,), dtype float
            z_loss: (batch,) if reduction is 'none', else (1,), dtype float (if self.return_z_loss)
        """
        assert input.is_cuda and target.is_cuda, "Only support CUDA tensors"
        loss, z_loss = cross_entropy_loss(
            input,
            target,
            label_smoothing=self.label_smoothing,
            logit_scale=self.logit_scale,
            lse_square_scale=self.lse_square_scale,
            ignore_index=self.ignore_index,
            inplace_backward=self.inplace_backward,
            process_group=self.process_group,
        )
        if self.reduction == "mean":
            # Average over the non-ignored positions only.
            loss = loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            loss = loss.sum()
        if not self.return_z_loss:
            return loss
        # Apply the same reduction to the z-loss (logging only; no backprop).
        if self.reduction == "mean":
            z_loss = z_loss.sum() / (target != self.ignore_index).sum()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()
        return loss, z_loss
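

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). A minimal, hedged example of
# how this module might be exercised; it assumes flash_attn is installed, a
# CUDA device is available, and the shapes/hyperparameters are illustrative.
# `reference_loss` is a hypothetical pure-PyTorch cross-check of the math the
# docstring describes (scaled logits, label-smoothed cross entropy, plus the
# z-loss term); it is a sketch under those assumptions, not the library's
# implementation, and it ignores tensor parallelism.
# ---------------------------------------------------------------------------
def reference_loss(logits, labels, label_smoothing=0.0, logit_scale=1.0,
                   lse_square_scale=0.0, ignore_index=-100):
    """Per-position loss: label-smoothed CE on scaled logits, plus z-loss."""
    logits = logits.float() * logit_scale
    ce = nn.functional.cross_entropy(
        logits, labels, label_smoothing=label_smoothing,
        ignore_index=ignore_index, reduction="none",
    )
    # z-loss term: lse_square_scale * logsumexp(logits)^2, zeroed where ignored.
    z = lse_square_scale * torch.logsumexp(logits, dim=-1).square()
    z = torch.where(labels == ignore_index, torch.zeros_like(z), z)
    return ce + z


if __name__ == "__main__":
    batch, vocab_size = 4, 32000
    logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
    labels = torch.randint(0, vocab_size, (batch,), device="cuda")
    labels[0] = -100  # this position is ignored and contributes 0.0

    criterion = CrossEntropyLoss(
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.1,
        lse_square_scale=1e-4,  # enables the z-loss regularizer
        return_z_loss=True,     # also report the z-loss component for logging
    )
    loss, z_loss = criterion(logits, labels)
    loss.backward()  # z_loss is logging-only and does not support backprop
    print(f"fused loss = {loss.item():.4f}, z-loss component = {z_loss.item():.4f}")

    # Sanity check against the pure-PyTorch sketch (should roughly agree).
    ref = reference_loss(logits.detach(), labels,
                         label_smoothing=0.1, lse_square_scale=1e-4)
    ref_mean = ref.sum() / (labels != -100).sum()
    print(f"reference loss = {ref_mean.item():.4f}")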