
Loss.py

#!/usr/bin/python
# -*- coding: utf-8 -*-
# __author__="Danqing Wang"
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch
import torch.nn.functional as F

from fastNLP.core.losses import LossBase

from tools.logger import *
class MyCrossEntropyLoss(LossBase):
    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='none'):
        super().__init__()
        self._init_param_map(pred=pred, target=target, mask=mask)
        self.padding_idx = padding_idx
        # The per-sentence masking in get_loss needs elementwise losses, so the
        # reduction passed to F.cross_entropy must be 'none' (a 'mean' default
        # would return a scalar and break the view/masked_fill below).
        self.reduce = reduce

    def get_loss(self, pred, target, mask):
        """
        :param pred: [batch, N, 2] sentence-level logits
        :param target: [batch, N] 0/1 extraction labels
        :param mask: [batch, N] 1 for real sentences, 0 for padding
        :return: scalar loss, summed over sentences and averaged over the batch
        """
        # logger.debug(pred[0:5, :, :])
        # logger.debug(target[0:5, :])
        batch, N, _ = pred.size()
        pred = pred.view(-1, 2)
        target = target.view(-1)
        loss = F.cross_entropy(input=pred, target=target,
                               ignore_index=self.padding_idx, reduction=self.reduce)
        loss = loss.view(batch, -1)             # back to [batch, N]
        loss = loss.masked_fill(mask.eq(0), 0)  # zero out padded sentences
        loss = loss.sum(1).mean()               # sum over sentences, mean over batch
        logger.debug("loss %f", loss)
        return loss
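
A minimal usage sketch (not part of the original file): it calls get_loss directly with dummy tensors shaped as documented in the docstring. The import path, tensor sizes, and values are illustrative assumptions only.

# Usage sketch (illustrative): exercise get_loss with dummy inputs.
import torch

from Loss import MyCrossEntropyLoss  # assumes this file is importable as Loss.py

loss_fn = MyCrossEntropyLoss()                 # default field mapping, reduce='none'
batch, N = 4, 10
pred = torch.randn(batch, N, 2)                # per-sentence binary logits
target = torch.randint(0, 2, (batch, N))       # 0/1 extraction labels
mask = torch.ones(batch, N, dtype=torch.long)  # 1 = real sentence, 0 = padding
mask[:, 7:] = 0                                # pretend the last sentences are padding
print(loss_fn.get_loss(pred, target, mask).item())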