# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend  # noqa
import os
import warnings

import numpy as np
import torch
import torch.nn.functional as F
from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point

from mmdet.core import bbox2roi, bbox_mapping, merge_aug_masks
from .. import builder
from ..builder import HEADS
from .standard_roi_head import StandardRoIHead


@HEADS.register_module()
class PointRendRoIHead(StandardRoIHead):
    """`PointRend <https://arxiv.org/abs/1912.08193>`_."""

    def __init__(self, point_head, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.with_bbox and self.with_mask
        self.init_point_head(point_head)

    def init_point_head(self, point_head):
        """Initialize ``point_head``."""
        self.point_head = builder.build_head(point_head)

    def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
                            img_metas):
        """Run forward function and calculate loss for mask head and point
        head in training."""
        mask_results = super()._mask_forward_train(x, sampling_results,
                                                   bbox_feats, gt_masks,
                                                   img_metas)
        if mask_results['loss_mask'] is not None:
            loss_point = self._mask_point_forward_train(
                x, sampling_results, mask_results['mask_pred'], gt_masks,
                img_metas)
            mask_results['loss_mask'].update(loss_point)

        return mask_results

    def _mask_point_forward_train(self, x, sampling_results, mask_pred,
                                  gt_masks, img_metas):
        """Run forward function and calculate loss for point head in
        training."""
        pos_labels = torch.cat(
            [res.pos_gt_labels for res in sampling_results])
        rel_roi_points = self.point_head.get_roi_rel_points_train(
            mask_pred, pos_labels, cfg=self.train_cfg)
        rois = bbox2roi([res.pos_bboxes for res in sampling_results])

        fine_grained_point_feats = self._get_fine_grained_point_feats(
            x, rois, rel_roi_points, img_metas)
        coarse_point_feats = point_sample(mask_pred, rel_roi_points)
        mask_point_pred = self.point_head(fine_grained_point_feats,
                                          coarse_point_feats)
        mask_point_target = self.point_head.get_targets(
            rois, rel_roi_points, sampling_results, gt_masks, self.train_cfg)
        loss_mask_point = self.point_head.loss(mask_point_pred,
                                               mask_point_target, pos_labels)

        return loss_mask_point
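
    # Illustrative shape walk-through (not part of the original file),
    # assuming the standard train_cfg with num_points=196: for R positive
    # RoIs, ``get_roi_rel_points_train`` picks (R, 196, 2) normalized
    # coordinates of the most uncertain mask points; fine-grained features
    # of shape (R, C, 196) and coarse features of shape
    # (R, num_classes, 196) are then fed to ``point_head``, which predicts
    # (R, num_classes, 196) point logits supervised by ``get_targets``.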

    def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
                                      img_metas):
        """Sample fine grained feats from each level feature map and
        concatenate them together.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.
            img_metas (list[dict]): Image meta info.

        Returns:
            Tensor: The fine-grained features of each point, with shape
                (num_rois, feats_channels, num_points).
        """
        num_imgs = len(img_metas)
        fine_grained_feats = []
        for idx in range(self.mask_roi_extractor.num_inputs):
            feats = x[idx]
            spatial_scale = 1. / float(
                self.mask_roi_extractor.featmap_strides[idx])
            point_feats = []
            for batch_ind in range(num_imgs):
                # unravel batch dim
                feat = feats[batch_ind].unsqueeze(0)
                inds = (rois[:, 0].long() == batch_ind)
                if inds.any():
                    rel_img_points = rel_roi_point_to_rel_img_point(
                        rois[inds], rel_roi_points[inds], feat.shape[2:],
                        spatial_scale).unsqueeze(0)
                    point_feat = point_sample(feat, rel_img_points)
                    point_feat = point_feat.squeeze(0).transpose(0, 1)
                    point_feats.append(point_feat)
            fine_grained_feats.append(torch.cat(point_feats, dim=0))
        return torch.cat(fine_grained_feats, dim=1)
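
    # Illustrative sketch (not part of the original file): ``point_sample``
    # bilinearly samples a (N, C, H, W) feature map at normalized [0, 1]
    # point coordinates and returns (N, C, num_points), which is the shape
    # contract the method above relies on. Tensor sizes are hypothetical:
    #
    #   feat = torch.rand(1, 256, 14, 14)   # one image's feature map
    #   points = torch.rand(1, 784, 2)      # (x, y) in [0, 1] x [0, 1]
    #   out = point_sample(feat, points)    # -> (1, 256, 784)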

    def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
                                 img_metas):
        """Mask refining process with point head in testing.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            label_pred (Tensor): The predicted class for each roi.
            mask_pred (Tensor): The predicted coarse masks of
                shape (num_rois, num_classes, small_size, small_size).
            img_metas (list[dict]): Image meta info.

        Returns:
            Tensor: The refined masks of shape (num_rois, num_classes,
                large_size, large_size).
        """
        refined_mask_pred = mask_pred.clone()
        for subdivision_step in range(self.test_cfg.subdivision_steps):
            refined_mask_pred = F.interpolate(
                refined_mask_pred,
                scale_factor=self.test_cfg.scale_factor,
                mode='bilinear',
                align_corners=False)
            # If `subdivision_num_points` is larger than or equal to the
            # resolution of the next step, then we can skip this step
            num_rois, channels, mask_height, mask_width = \
                refined_mask_pred.shape
            if (self.test_cfg.subdivision_num_points >=
                    self.test_cfg.scale_factor**2 * mask_height * mask_width
                    and
                    subdivision_step < self.test_cfg.subdivision_steps - 1):
                continue
            point_indices, rel_roi_points = \
                self.point_head.get_roi_rel_points_test(
                    refined_mask_pred, label_pred, cfg=self.test_cfg)
            fine_grained_point_feats = self._get_fine_grained_point_feats(
                x, rois, rel_roi_points, img_metas)
            coarse_point_feats = point_sample(mask_pred, rel_roi_points)
            mask_point_pred = self.point_head(fine_grained_point_feats,
                                              coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(
                -1, channels, -1)
            refined_mask_pred = refined_mask_pred.reshape(
                num_rois, channels, mask_height * mask_width)
            refined_mask_pred = refined_mask_pred.scatter_(
                2, point_indices, mask_point_pred)
            refined_mask_pred = refined_mask_pred.view(
                num_rois, channels, mask_height, mask_width)

        return refined_mask_pred
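
    # Illustrative sketch (not part of the original file): each subdivision
    # step flattens the upsampled mask to (num_rois, channels, H * W),
    # overwrites the logits at the selected uncertain points in place with
    # ``scatter_`` along dim 2, then restores the spatial shape. A toy
    # version with hypothetical sizes:
    #
    #   mask = torch.zeros(2, 1, 4 * 4)            # two flattened 4x4 masks
    #   inds = torch.tensor([[[0, 5]], [[3, 7]]])  # (rois, channels, points)
    #   vals = torch.ones(2, 1, 2)                 # refined point logits
    #   mask.scatter_(2, inds, vals).view(2, 1, 4, 4)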

    def simple_test_mask(self,
                         x,
                         img_metas,
                         det_bboxes,
                         det_labels,
                         rescale=False):
        """Obtain mask prediction without augmentation."""
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        if isinstance(scale_factors[0], float):
            warnings.warn(
                'Scale factor in img_metas should be an '
                'ndarray with shape (4,) '
                'arranged as (factor_w, factor_h, factor_w, factor_h). '
                'The scale_factor with float type has been deprecated.')
            scale_factors = np.array([scale_factors] * 4, dtype=np.float32)

        num_imgs = len(det_bboxes)
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            segm_results = [[[] for _ in range(self.mask_head.num_classes)]
                            for _ in range(num_imgs)]
        else:
            # if det_bboxes is rescaled to the original image size, we need to
            # rescale it back to the testing scale to obtain RoIs.
            _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
            if rescale:
                scale_factors = [
                    torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                    for scale_factor in scale_factors
                ]
                _bboxes = [
                    _bboxes[i] * scale_factors[i] for i in range(len(_bboxes))
                ]
            mask_rois = bbox2roi(_bboxes)
            mask_results = self._mask_forward(x, mask_rois)
            # split batch mask prediction back to each image
            mask_pred = mask_results['mask_pred']
            num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
            mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
            mask_rois = mask_rois.split(num_mask_roi_per_img, 0)

            # apply mask post-processing to each image individually
            segm_results = []
            for i in range(num_imgs):
                if det_bboxes[i].shape[0] == 0:
                    segm_results.append(
                        [[] for _ in range(self.mask_head.num_classes)])
                else:
                    x_i = [xx[[i]] for xx in x]
                    mask_rois_i = mask_rois[i]
                    mask_rois_i[:, 0] = 0  # TODO: remove this hack
                    mask_pred_i = self._mask_point_forward_test(
                        x_i, mask_rois_i, det_labels[i], mask_preds[i],
                        [img_metas])
                    segm_result = self.mask_head.get_seg_masks(
                        mask_pred_i, _bboxes[i], det_labels[i], self.test_cfg,
                        ori_shapes[i], scale_factors[i], rescale)
                    segm_results.append(segm_result)
        return segm_results

    def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
        """Test for mask head with test time augmentation."""
        if det_bboxes.shape[0] == 0:
            segm_result = [[] for _ in range(self.mask_head.num_classes)]
        else:
            aug_masks = []
            for x, img_meta in zip(feats, img_metas):
                img_shape = img_meta[0]['img_shape']
                scale_factor = img_meta[0]['scale_factor']
                flip = img_meta[0]['flip']
                _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                       scale_factor, flip)
                mask_rois = bbox2roi([_bboxes])
                mask_results = self._mask_forward(x, mask_rois)
                mask_results['mask_pred'] = self._mask_point_forward_test(
                    x, mask_rois, det_labels, mask_results['mask_pred'],
                    img_meta)
                # convert to numpy array to save memory
                aug_masks.append(
                    mask_results['mask_pred'].sigmoid().cpu().numpy())
            merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)

            ori_shape = img_metas[0][0]['ori_shape']
            segm_result = self.mask_head.get_seg_masks(
                merged_masks,
                det_bboxes,
                det_labels,
                self.test_cfg,
                ori_shape,
                scale_factor=1.0,
                rescale=False)
        return segm_result

    def _onnx_get_fine_grained_point_feats(self, x, rois, rel_roi_points):
        """Export the process of sampling fine grained feats to onnx.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            rel_roi_points (Tensor): A tensor of shape (num_rois, num_points,
                2) that contains [0, 1] x [0, 1] normalized coordinates of the
                most uncertain points from the [mask_height, mask_width] grid.

        Returns:
            Tensor: The fine-grained features of each point, with shape
                (num_rois, feats_channels, num_points).
        """
        batch_size = x[0].shape[0]
        num_rois = rois.shape[0]
        fine_grained_feats = []
        for idx in range(self.mask_roi_extractor.num_inputs):
            feats = x[idx]
            spatial_scale = 1. / float(
                self.mask_roi_extractor.featmap_strides[idx])

            rel_img_points = rel_roi_point_to_rel_img_point(
                rois, rel_roi_points, feats, spatial_scale)
            channels = feats.shape[1]
            num_points = rel_img_points.shape[1]
            rel_img_points = rel_img_points.reshape(batch_size, -1,
                                                    num_points, 2)
            point_feats = point_sample(feats, rel_img_points)
            point_feats = point_feats.transpose(1, 2).reshape(
                num_rois, channels, num_points)
            fine_grained_feats.append(point_feats)
        return torch.cat(fine_grained_feats, dim=1)

    def _mask_point_onnx_export(self, x, rois, label_pred, mask_pred):
        """Export mask refining process with point head to onnx.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            rois (Tensor): shape (num_rois, 5).
            label_pred (Tensor): The predicted class for each roi.
            mask_pred (Tensor): The predicted coarse masks of
                shape (num_rois, num_classes, small_size, small_size).

        Returns:
            Tensor: The refined masks of shape (num_rois, num_classes,
                large_size, large_size).
        """
        refined_mask_pred = mask_pred.clone()
        for subdivision_step in range(self.test_cfg.subdivision_steps):
            refined_mask_pred = F.interpolate(
                refined_mask_pred,
                scale_factor=self.test_cfg.scale_factor,
                mode='bilinear',
                align_corners=False)
            # If `subdivision_num_points` is larger than or equal to the
            # resolution of the next step, then we can skip this step
            num_rois, channels, mask_height, mask_width = \
                refined_mask_pred.shape
            if (self.test_cfg.subdivision_num_points >=
                    self.test_cfg.scale_factor**2 * mask_height * mask_width
                    and
                    subdivision_step < self.test_cfg.subdivision_steps - 1):
                continue
            point_indices, rel_roi_points = \
                self.point_head.get_roi_rel_points_test(
                    refined_mask_pred, label_pred, cfg=self.test_cfg)
            fine_grained_point_feats = \
                self._onnx_get_fine_grained_point_feats(
                    x, rois, rel_roi_points)
            coarse_point_feats = point_sample(mask_pred, rel_roi_points)
            mask_point_pred = self.point_head(fine_grained_point_feats,
                                              coarse_point_feats)

            point_indices = point_indices.unsqueeze(1).expand(
                -1, channels, -1)
            refined_mask_pred = refined_mask_pred.reshape(
                num_rois, channels, mask_height * mask_width)

            is_trt_backend = os.environ.get('ONNX_BACKEND') == 'MMCVTensorRT'
            # avoid ScatterElements op in ONNX for TensorRT
            if is_trt_backend:
                mask_shape = refined_mask_pred.shape
                point_shape = point_indices.shape
                inds_dim0 = torch.arange(point_shape[0]).reshape(
                    point_shape[0], 1, 1).expand_as(point_indices)
                inds_dim1 = torch.arange(point_shape[1]).reshape(
                    1, point_shape[1], 1).expand_as(point_indices)
                inds_1d = inds_dim0.reshape(
                    -1) * mask_shape[1] * mask_shape[2] + inds_dim1.reshape(
                        -1) * mask_shape[2] + point_indices.reshape(-1)
                refined_mask_pred = refined_mask_pred.reshape(-1)
                refined_mask_pred[inds_1d] = mask_point_pred.reshape(-1)
                refined_mask_pred = refined_mask_pred.reshape(*mask_shape)
            else:
                refined_mask_pred = refined_mask_pred.scatter_(
                    2, point_indices, mask_point_pred)
            refined_mask_pred = refined_mask_pred.view(
                num_rois, channels, mask_height, mask_width)

        return refined_mask_pred

    def mask_onnx_export(self, x, img_metas, det_bboxes, det_labels,
                         **kwargs):
        """Export mask branch to onnx which supports batch inference.

        Args:
            x (tuple[Tensor]): Feature maps of all scale levels.
            img_metas (list[dict]): Image meta info.
            det_bboxes (Tensor): Bboxes and corresponding scores,
                of shape [N, num_bboxes, 5].
            det_labels (Tensor): Class labels of shape [N, num_bboxes].

        Returns:
            Tensor: The segmentation results of shape [N, num_bboxes,
                image_height, image_width].
        """
        if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
            raise RuntimeError('[ONNX Error] Can not record MaskHead '
                               'as it has not been executed this time')
        batch_size = det_bboxes.size(0)
        # if det_bboxes is rescaled to the original image size, we need to
        # rescale it back to the testing scale to obtain RoIs.
        det_bboxes = det_bboxes[..., :4]
        batch_index = torch.arange(
            det_bboxes.size(0), device=det_bboxes.device).float().view(
                -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
        mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
        mask_rois = mask_rois.view(-1, 5)
        mask_results = self._mask_forward(x, mask_rois)
        mask_pred = mask_results['mask_pred']
        max_shape = img_metas[0]['img_shape_for_onnx']
        num_det = det_bboxes.shape[1]
        det_bboxes = det_bboxes.reshape(-1, 4)
        det_labels = det_labels.reshape(-1)

        mask_pred = self._mask_point_onnx_export(x, mask_rois, det_labels,
                                                 mask_pred)

        segm_results = self.mask_head.onnx_export(mask_pred, det_bboxes,
                                                  det_labels, self.test_cfg,
                                                  max_shape)
        segm_results = segm_results.reshape(batch_size, num_det, max_shape[0],
                                            max_shape[1])
        return segm_results
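

# Illustrative usage sketch (not part of the original file): this head is
# normally built from a config rather than instantiated directly. The values
# below follow the standard PointRend R-50 config shipped with MMDetection
# and are an example, not the only valid settings:
#
#   roi_head=dict(
#       type='PointRendRoIHead',
#       ...,  # bbox/mask heads and RoI extractors as in the PointRend configs
#       point_head=dict(
#           type='MaskPointHead',
#           num_fcs=3,
#           in_channels=256,
#           fc_channels=256,
#           num_classes=80,
#           coarse_pred_each_layer=True,
#           loss_point=dict(
#               type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)))
#
#   # point-sampling hyper-parameters consumed by the methods above:
#   # train_cfg.rcnn: num_points=196, oversample_ratio=3,
#   #                 importance_sample_ratio=0.75
#   # test_cfg.rcnn: subdivision_steps=5, subdivision_num_points=784,
#   #                scale_factor=2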
