import torch
import torch.nn as nn
import torch.nn.functional as F
from fastNLP.core.losses import LossBase

# reduction strategies applied to the per-position loss and its ignore-mask
reduce_func = {
    'none': lambda x, mask: x * mask,
    'sum': lambda x, mask: (x * mask).sum(),
    'mean': lambda x, mask: (x * mask).sum() / mask.sum(),
}


class LabelSmoothCrossEntropy(nn.Module):
    """Cross entropy with label smoothing and an ignore_index, analogous to nn.CrossEntropyLoss."""

    def __init__(self, smoothing=0.1, ignore_index=-100, reduction='mean'):
        super().__init__()
        if smoothing < 0 or smoothing > 1:
            raise ValueError('invalid smoothing value: {}'.format(smoothing))
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        if reduction not in reduce_func:
            raise ValueError('invalid reduce type: {}'.format(reduction))
        self.reduce_func = reduce_func[reduction]

    def forward(self, input, target):
        # input: raw logits of shape [N, C, ...]; target: class indices of shape [N, ...]
        input = F.log_softmax(input, dim=1)
        smooth_val = self.smoothing / input.size(1)
        # smoothed target distribution: smoothing/C everywhere, 1 - smoothing on the gold class
        target_logit = input.new_full(input.size(), fill_value=smooth_val)  # [N, C, ...]
        # clamp so that a negative ignore_index (e.g. the default -100) cannot crash scatter_;
        # ignored positions are zeroed out by the mask below anyway
        index = target.clamp(min=0)
        target_logit.scatter_(1, index[:, None], 1 - self.smoothing)
        result = -(target_logit * input).sum(1)  # [N, ...]
        mask = (target != self.ignore_index).float()
        return self.reduce_func(result, mask)


class SmoothCE(LossBase):
    """fastNLP loss wrapper so the smoothed cross entropy can be plugged into a Trainer."""

    def __init__(self, pred=None, target=None, **kwargs):
        super().__init__()
        self.loss_fn = LabelSmoothCrossEntropy(**kwargs)
        self._init_param_map(pred=pred, target=target)

    def get_loss(self, pred, target):
        return self.loss_fn(pred, target)


if __name__ == '__main__':
    # with smoothing=0 the smoothed loss should match nn.CrossEntropyLoss exactly
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    sm_loss_fn = LabelSmoothCrossEntropy(smoothing=0, ignore_index=0)
    predict = torch.tensor([[0, 0.2, 0.7, 0.1, 0],
                            [0, 0.9, 0.2, 0.1, 0],
                            [1, 0.2, 0.7, 0.1, 0]])
    target = torch.tensor([2, 1, 0])
    loss = loss_fn(predict, target)
    sm_loss = sm_loss_fn(predict, target)
    print(loss, sm_loss)
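
# --- usage sketch (not executed) -------------------------------------------
# A rough illustration of how SmoothCE might be handed to a fastNLP Trainer.
# The field names 'pred'/'target', the model, dataset, and Trainer keyword
# arguments below are assumptions for illustration; exact names depend on the
# fastNLP version and on the keys the model's forward() returns:
#
#   from fastNLP import Trainer
#   loss = SmoothCE(pred='pred', target='target', smoothing=0.1, ignore_index=0)
#   trainer = Trainer(train_data=train_set, model=model, loss=loss,
#                     optimizer=optimizer, n_epochs=10)
#   trainer.train()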