You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ae_loss.py 3.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from ..builder import LOSSES
  7. @mmcv.jit(derivate=True, coderize=True)
  8. def ae_loss_per_image(tl_preds, br_preds, match):
  9. """Associative Embedding Loss in one image.
  10. Associative Embedding Loss including two parts: pull loss and push loss.
  11. Pull loss makes embedding vectors from same object closer to each other.
  12. Push loss distinguish embedding vector from different objects, and makes
  13. the gap between them is large enough.
  14. During computing, usually there are 3 cases:
  15. - no object in image: both pull loss and push loss will be 0.
  16. - one object in image: push loss will be 0 and pull loss is computed
  17. by the two corner of the only object.
  18. - more than one objects in image: pull loss is computed by corner pairs
  19. from each object, push loss is computed by each object with all
  20. other objects. We use confusion matrix with 0 in diagonal to
  21. compute the push loss.
  22. Args:
  23. tl_preds (tensor): Embedding feature map of left-top corner.
  24. br_preds (tensor): Embedding feature map of bottim-right corner.
  25. match (list): Downsampled coordinates pair of each ground truth box.
  26. """
  27. tl_list, br_list, me_list = [], [], []
  28. if len(match) == 0: # no object in image
  29. pull_loss = tl_preds.sum() * 0.
  30. push_loss = tl_preds.sum() * 0.
  31. else:
  32. for m in match:
  33. [tl_y, tl_x], [br_y, br_x] = m
  34. tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
  35. br_e = br_preds[:, br_y, br_x].view(-1, 1)
  36. tl_list.append(tl_e)
  37. br_list.append(br_e)
  38. me_list.append((tl_e + br_e) / 2.0)
  39. tl_list = torch.cat(tl_list)
  40. br_list = torch.cat(br_list)
  41. me_list = torch.cat(me_list)
  42. assert tl_list.size() == br_list.size()
  43. # N is object number in image, M is dimension of embedding vector
  44. N, M = tl_list.size()
  45. pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
  46. pull_loss = pull_loss.sum() / N
  47. margin = 1 # exp setting of CornerNet, details in section 3.3 of paper
  48. # confusion matrix of push loss
  49. conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
  50. conf_weight = 1 - torch.eye(N).type_as(me_list)
  51. conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
  52. if N > 1: # more than one object in current image
  53. push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
  54. else:
  55. push_loss = tl_preds.sum() * 0.
  56. return pull_loss, push_loss
  57. @LOSSES.register_module()
  58. class AssociativeEmbeddingLoss(nn.Module):
  59. """Associative Embedding Loss.
  60. More details can be found in
  61. `Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
  62. `CornerNet <https://arxiv.org/abs/1808.01244>`_ .
  63. Code is modified from `kp_utils.py <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py#L180>`_ # noqa: E501
  64. Args:
  65. pull_weight (float): Loss weight for corners from same object.
  66. push_weight (float): Loss weight for corners from different object.
  67. """
  68. def __init__(self, pull_weight=0.25, push_weight=0.25):
  69. super(AssociativeEmbeddingLoss, self).__init__()
  70. self.pull_weight = pull_weight
  71. self.push_weight = push_weight
  72. def forward(self, pred, target, match):
  73. """Forward function."""
  74. batch = pred.size(0)
  75. pull_all, push_all = 0.0, 0.0
  76. for i in range(batch):
  77. pull, push = ae_loss_per_image(pred[i], target[i], match[i])
  78. pull_all += self.pull_weight * pull
  79. push_all += self.push_weight * push
  80. return pull_all, push_all

No Description

Contributors (3)