
kd_loss.py

# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weighted_loss


@mmcv.jit(derivate=True, coderize=True)
@weighted_loss
def knowledge_distillation_kl_div_loss(pred,
                                       soft_label,
                                       T,
                                       detach_target=True):
    r"""Loss function for knowledge distilling using KL divergence.

    Args:
        pred (Tensor): Predicted logits with shape (N, n + 1).
        soft_label (Tensor): Target logits with shape (N, n + 1).
        T (int): Temperature for distillation.
        detach_target (bool): Remove soft_label from automatic
            differentiation. Defaults to True.

    Returns:
        torch.Tensor: Loss tensor with shape (N,).
    """
    assert pred.size() == soft_label.size()
    target = F.softmax(soft_label / T, dim=1)
    if detach_target:
        target = target.detach()

    # Scale by T * T to compensate for the 1 / T^2 gradient scaling
    # introduced by softening the logits with temperature T.
    kd_loss = F.kl_div(
        F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
            T * T)

    return kd_loss


@LOSSES.register_module()
class KnowledgeDistillationKLDivLoss(nn.Module):
    """Loss function for knowledge distilling using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
            Defaults to 'mean'.
        loss_weight (float): Loss weight of current loss. Defaults to 1.0.
        T (int): Temperature for distillation. Defaults to 10.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super(KnowledgeDistillationKLDivLoss, self).__init__()
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to
                average the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
            pred,
            soft_label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            T=self.T)

        return loss_kd
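For reference, knowledge_distillation_kl_div_loss computes T^2 times the KL divergence between the temperature-softened teacher distribution softmax(soft_label / T) and the student distribution softmax(pred / T), averaged (rather than summed) over the class dimension. Below is a minimal usage sketch, assuming this file ships as part of mmdet (2.x), which exposes the class as mmdet.models.losses.KnowledgeDistillationKLDivLoss; the batch size and class count are illustrative only.

import torch
from mmdet.models.losses import KnowledgeDistillationKLDivLoss

torch.manual_seed(0)

# Illustrative shapes: a batch of 4 predictions over 80 classes + background.
student_logits = torch.randn(4, 81, requires_grad=True)
teacher_logits = torch.randn(4, 81)

criterion = KnowledgeDistillationKLDivLoss(
    reduction='mean', loss_weight=1.0, T=10)

loss = criterion(student_logits, teacher_logits)  # scalar under 'mean'
loss.backward()
print(float(loss))

# The weighted_loss decorator also accepts a per-prediction weight of
# shape (N,), e.g. to ignore the third sample entirely:
weights = torch.tensor([1.0, 0.5, 0.0, 1.0])
weighted = criterion(student_logits, teacher_logits, weight=weights)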
