Spaces:
Sleeping
Sleeping
import torch | |
import torch.optim | |
import torch.nn.functional as F | |
import copy | |
def update_generation_losses(losses, nums, micro, macro, bs, length, loss): | |
# Update Losses | |
losses[micro] += \ | |
[copy.deepcopy(losses[micro][-1])] | |
losses[macro] += \ | |
[copy.deepcopy(losses[macro][-1])] | |
losses[micro][-1] *= nums[micro] | |
losses[macro][-1] *= nums[macro] | |
nums[macro] += bs | |
if isinstance(length, int): | |
update_indiv_generation_losses( | |
losses, nums, micro, macro, bs, length, loss) | |
else: | |
update_tensor_generation_losses( | |
losses, nums, micro, macro, bs, length, loss) | |
def update_indiv_generation_losses(losses, nums, micro, | |
macro, bs, length, loss): | |
nums[micro] += (bs * length) | |
batch_loss = loss * bs | |
losses[micro][-1] += batch_loss | |
losses[micro][-1] /= nums[micro] | |
losses[macro][-1] += batch_loss / length | |
losses[macro][-1] /= nums[macro] | |
def update_tensor_generation_losses(losses, nums, micro, | |
macro, bs, length, loss): | |
nums[micro] += length.sum().item() | |
losses[micro][-1] += loss.sum().item() | |
losses[micro][-1] /= nums[micro] | |
losses[macro][-1] += (loss / length.float()).sum().item() | |
losses[macro][-1] /= nums[macro] | |
def modify_output_for_loss_fn(loss_fn, output, dim): | |
if loss_fn == "ce": | |
return output | |
if loss_fn == "mse": | |
return F.softmax(output, dim=dim) | |
if loss_fn == "nll": | |
return F.log_softmax(output, dim=dim) | |
if loss_fn in ["bce", "wbce", "wbce1"]: | |
return torch.sigmoid(output) | |