You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

keypoints.py 7.8 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import numpy as np
  3. from typing import Any, List, Tuple, Union
  4. import torch
  5. from detectron2.layers import interpolate
  6. class Keypoints:
  7. """
  8. Stores keypoint annotation data. GT Instances have a `gt_keypoints` property
  9. containing the x,y location and visibility flag of each keypoint. This tensor has shape
  10. (N, K, 3) where N is the number of instances and K is the number of keypoints per instance.
  11. The visibility flag follows the COCO format and must be one of three integers:
  12. * v=0: not labeled (in which case x=y=0)
  13. * v=1: labeled but not visible
  14. * v=2: labeled and visible
  15. """
  16. def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]):
  17. """
  18. Arguments:
  19. keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint.
  20. The shape should be (N, K, 3) where N is the number of
  21. instances, and K is the number of keypoints per instance.
  22. """
  23. device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu")
  24. keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device)
  25. assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape
  26. self.tensor = keypoints
  27. def __len__(self) -> int:
  28. return self.tensor.size(0)
  29. def to(self, *args: Any, **kwargs: Any) -> "Keypoints":
  30. return type(self)(self.tensor.to(*args, **kwargs))
  31. def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor:
  32. """
  33. Arguments:
  34. boxes: Nx4 tensor, the boxes to draw the keypoints to
  35. Returns:
  36. heatmaps:
  37. A tensor of shape (N, K) containing an integer spatial label
  38. in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
  39. valid:
  40. A tensor of shape (N, K) containing whether each keypoint is in the roi or not.
  41. """
  42. return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size)
  43. def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints":
  44. """
  45. Create a new `Keypoints` by indexing on this `Keypoints`.
  46. The following usage are allowed:
  47. 1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance.
  48. 2. `new_kpts = kpts[2:10]`: return a slice of key points.
  49. 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor
  50. with `length = len(kpts)`. Nonzero elements in the vector will be selected.
  51. Note that the returned Keypoints might share storage with this Keypoints,
  52. subject to Pytorch's indexing semantics.
  53. """
  54. if isinstance(item, int):
  55. return Keypoints([self.tensor[item]])
  56. return Keypoints(self.tensor[item])
  57. def __repr__(self) -> str:
  58. s = self.__class__.__name__ + "("
  59. s += "num_instances={})".format(len(self.tensor))
  60. return s
  61. # TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
  62. def _keypoints_to_heatmap(
  63. keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int
  64. ) -> Tuple[torch.Tensor, torch.Tensor]:
  65. """
  66. Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space.
  67. Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the
  68. closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the
  69. continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"):
  70. d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
  71. Arguments:
  72. keypoints: tensor of keypoint locations in of shape (N, K, 3).
  73. rois: Nx4 tensor of rois in xyxy format
  74. heatmap_size: integer side length of square heatmap.
  75. Returns:
  76. heatmaps: A tensor of shape (N, K) containing an integer spatial label
  77. in the range [0, heatmap_size**2 - 1] for each keypoint in the input.
  78. valid: A tensor of shape (N, K) containing whether each keypoint is in
  79. the roi or not.
  80. """
  81. if rois.numel() == 0:
  82. return rois.new().long(), rois.new().long()
  83. offset_x = rois[:, 0]
  84. offset_y = rois[:, 1]
  85. scale_x = heatmap_size / (rois[:, 2] - rois[:, 0])
  86. scale_y = heatmap_size / (rois[:, 3] - rois[:, 1])
  87. offset_x = offset_x[:, None]
  88. offset_y = offset_y[:, None]
  89. scale_x = scale_x[:, None]
  90. scale_y = scale_y[:, None]
  91. x = keypoints[..., 0]
  92. y = keypoints[..., 1]
  93. x_boundary_inds = x == rois[:, 2][:, None]
  94. y_boundary_inds = y == rois[:, 3][:, None]
  95. x = (x - offset_x) * scale_x
  96. x = x.floor().long()
  97. y = (y - offset_y) * scale_y
  98. y = y.floor().long()
  99. x[x_boundary_inds] = heatmap_size - 1
  100. y[y_boundary_inds] = heatmap_size - 1
  101. valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size)
  102. vis = keypoints[..., 2] > 0
  103. valid = (valid_loc & vis).long()
  104. lin_ind = y * heatmap_size + x
  105. heatmaps = lin_ind * valid
  106. return heatmaps, valid
  107. @torch.no_grad()
  108. def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
  109. """
  110. Args:
  111. maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W)
  112. rois (Tensor): (#ROIs, 4)
  113. Extract predicted keypoint locations from heatmaps. Output has shape
  114. (#rois, #keypoints, 4) with the last dimension corresponding to (x, y, logit, prob)
  115. for each keypoint.
  116. Converts a discrete image coordinate in an NxN image to a continuous keypoint coordinate. We
  117. maintain consistency with keypoints_to_heatmap by using the conversion from Heckbert 1990:
  118. c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate.
  119. """
  120. offset_x = rois[:, 0]
  121. offset_y = rois[:, 1]
  122. widths = (rois[:, 2] - rois[:, 0]).clamp(min=1)
  123. heights = (rois[:, 3] - rois[:, 1]).clamp(min=1)
  124. widths_ceil = widths.ceil()
  125. heights_ceil = heights.ceil()
  126. num_rois, num_keypoints = maps.shape[:2]
  127. xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4)
  128. width_corrections = widths / widths_ceil
  129. height_corrections = heights / heights_ceil
  130. keypoints_idx = torch.arange(num_keypoints, device=maps.device)
  131. for i in range(num_rois):
  132. outsize = (int(heights_ceil[i]), int(widths_ceil[i]))
  133. roi_map = interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False).squeeze(
  134. 0
  135. ) # #keypoints x H x W
  136. # softmax over the spatial region
  137. max_score, _ = roi_map.view(num_keypoints, -1).max(1)
  138. max_score = max_score.view(num_keypoints, 1, 1)
  139. tmp_full_resolution = (roi_map - max_score).exp_()
  140. tmp_pool_resolution = (maps[i] - max_score).exp_()
  141. # Produce scores over the region H x W, but normalize with POOL_H x POOL_W
  142. # So that the scores of objects of different absolute sizes will be more comparable
  143. roi_map_probs = tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True)
  144. w = roi_map.shape[2]
  145. pos = roi_map.view(num_keypoints, -1).argmax(1)
  146. x_int = pos % w
  147. y_int = (pos - x_int) // w
  148. assert (
  149. roi_map_probs[keypoints_idx, y_int, x_int]
  150. == roi_map_probs.view(num_keypoints, -1).max(1)[0]
  151. ).all()
  152. x = (x_int.float() + 0.5) * width_corrections[i]
  153. y = (y_int.float() + 0.5) * height_corrections[i]
  154. xy_preds[i, :, 0] = x + offset_x[i]
  155. xy_preds[i, :, 1] = y + offset_y[i]
  156. xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int]
  157. xy_preds[i, :, 3] = roi_map_probs[keypoints_idx, y_int, x_int]
  158. return xy_preds

No Description