You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

solo_head.py 47 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from mmcv.cnn import ConvModule
  8. from mmdet.core import InstanceData, mask_matrix_nms, multi_apply
  9. from mmdet.core.utils import center_of_mass, generate_coordinate
  10. from mmdet.models.builder import HEADS, build_loss
  11. from .base_mask_head import BaseMaskHead
  12. @HEADS.register_module()
  13. class SOLOHead(BaseMaskHead):
  14. """SOLO mask head used in `SOLO: Segmenting Objects by Locations.
  15. <https://arxiv.org/abs/1912.04488>`_
  16. Args:
  17. num_classes (int): Number of categories excluding the background
  18. category.
  19. in_channels (int): Number of channels in the input feature map.
  20. feat_channels (int): Number of hidden channels. Used in child classes.
  21. Default: 256.
  22. stacked_convs (int): Number of stacking convs of the head.
  23. Default: 4.
  24. strides (tuple): Downsample factor of each feature map.
  25. scale_ranges (tuple[tuple[int, int]]): Area range of multiple
  26. level masks, in the format [(min1, max1), (min2, max2), ...].
  27. A range of (16, 64) means the area range between (16, 64).
  28. pos_scale (float): Constant scale factor to control the center region.
  29. num_grids (list[int]): Divided image into a uniform grids, each
  30. feature map has a different grid value. The number of output
  31. channels is grid ** 2. Default: [40, 36, 24, 16, 12].
  32. cls_down_index (int): The index of downsample operation in
  33. classification branch. Default: 0.
  34. loss_mask (dict): Config of mask loss.
  35. loss_cls (dict): Config of classification loss.
  36. norm_cfg (dict): dictionary to construct and config norm layer.
  37. Default: norm_cfg=dict(type='GN', num_groups=32,
  38. requires_grad=True).
  39. train_cfg (dict): Training config of head.
  40. test_cfg (dict): Testing config of head.
  41. init_cfg (dict or list[dict], optional): Initialization config dict.
  42. """
  43. def __init__(
  44. self,
  45. num_classes,
  46. in_channels,
  47. feat_channels=256,
  48. stacked_convs=4,
  49. strides=(4, 8, 16, 32, 64),
  50. scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
  51. pos_scale=0.2,
  52. num_grids=[40, 36, 24, 16, 12],
  53. cls_down_index=0,
  54. loss_mask=None,
  55. loss_cls=None,
  56. norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
  57. train_cfg=None,
  58. test_cfg=None,
  59. init_cfg=[
  60. dict(type='Normal', layer='Conv2d', std=0.01),
  61. dict(
  62. type='Normal',
  63. std=0.01,
  64. bias_prob=0.01,
  65. override=dict(name='conv_mask_list')),
  66. dict(
  67. type='Normal',
  68. std=0.01,
  69. bias_prob=0.01,
  70. override=dict(name='conv_cls'))
  71. ],
  72. ):
  73. super(SOLOHead, self).__init__(init_cfg)
  74. self.num_classes = num_classes
  75. self.cls_out_channels = self.num_classes
  76. self.in_channels = in_channels
  77. self.feat_channels = feat_channels
  78. self.stacked_convs = stacked_convs
  79. self.strides = strides
  80. self.num_grids = num_grids
  81. # number of FPN feats
  82. self.num_levels = len(strides)
  83. assert self.num_levels == len(scale_ranges) == len(num_grids)
  84. self.scale_ranges = scale_ranges
  85. self.pos_scale = pos_scale
  86. self.cls_down_index = cls_down_index
  87. self.loss_cls = build_loss(loss_cls)
  88. self.loss_mask = build_loss(loss_mask)
  89. self.norm_cfg = norm_cfg
  90. self.init_cfg = init_cfg
  91. self.train_cfg = train_cfg
  92. self.test_cfg = test_cfg
  93. self._init_layers()
  94. def _init_layers(self):
  95. self.mask_convs = nn.ModuleList()
  96. self.cls_convs = nn.ModuleList()
  97. for i in range(self.stacked_convs):
  98. chn = self.in_channels + 2 if i == 0 else self.feat_channels
  99. self.mask_convs.append(
  100. ConvModule(
  101. chn,
  102. self.feat_channels,
  103. 3,
  104. stride=1,
  105. padding=1,
  106. norm_cfg=self.norm_cfg))
  107. chn = self.in_channels if i == 0 else self.feat_channels
  108. self.cls_convs.append(
  109. ConvModule(
  110. chn,
  111. self.feat_channels,
  112. 3,
  113. stride=1,
  114. padding=1,
  115. norm_cfg=self.norm_cfg))
  116. self.conv_mask_list = nn.ModuleList()
  117. for num_grid in self.num_grids:
  118. self.conv_mask_list.append(
  119. nn.Conv2d(self.feat_channels, num_grid**2, 1))
  120. self.conv_cls = nn.Conv2d(
  121. self.feat_channels, self.cls_out_channels, 3, padding=1)
  122. def resize_feats(self, feats):
  123. """Downsample the first feat and upsample last feat in feats."""
  124. out = []
  125. for i in range(len(feats)):
  126. if i == 0:
  127. out.append(
  128. F.interpolate(feats[0], scale_factor=0.5, mode='bilinear'))
  129. elif i == len(feats) - 1:
  130. out.append(
  131. F.interpolate(
  132. feats[i],
  133. size=feats[i - 1].shape[-2:],
  134. mode='bilinear'))
  135. else:
  136. out.append(feats[i])
  137. return out
  138. def forward(self, feats):
  139. assert len(feats) == self.num_levels
  140. feats = self.resize_feats(feats)
  141. mlvl_mask_preds = []
  142. mlvl_cls_preds = []
  143. for i in range(self.num_levels):
  144. x = feats[i]
  145. mask_feat = x
  146. cls_feat = x
  147. # generate and concat the coordinate
  148. coord_feat = generate_coordinate(mask_feat.size(),
  149. mask_feat.device)
  150. mask_feat = torch.cat([mask_feat, coord_feat], 1)
  151. for mask_layer in (self.mask_convs):
  152. mask_feat = mask_layer(mask_feat)
  153. mask_feat = F.interpolate(
  154. mask_feat, scale_factor=2, mode='bilinear')
  155. mask_pred = self.conv_mask_list[i](mask_feat)
  156. # cls branch
  157. for j, cls_layer in enumerate(self.cls_convs):
  158. if j == self.cls_down_index:
  159. num_grid = self.num_grids[i]
  160. cls_feat = F.interpolate(
  161. cls_feat, size=num_grid, mode='bilinear')
  162. cls_feat = cls_layer(cls_feat)
  163. cls_pred = self.conv_cls(cls_feat)
  164. if not self.training:
  165. feat_wh = feats[0].size()[-2:]
  166. upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
  167. mask_pred = F.interpolate(
  168. mask_pred.sigmoid(), size=upsampled_size, mode='bilinear')
  169. cls_pred = cls_pred.sigmoid()
  170. # get local maximum
  171. local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
  172. keep_mask = local_max[:, :, :-1, :-1] == cls_pred
  173. cls_pred = cls_pred * keep_mask
  174. mlvl_mask_preds.append(mask_pred)
  175. mlvl_cls_preds.append(cls_pred)
  176. return mlvl_mask_preds, mlvl_cls_preds
  177. def loss(self,
  178. mlvl_mask_preds,
  179. mlvl_cls_preds,
  180. gt_labels,
  181. gt_masks,
  182. img_metas,
  183. gt_bboxes=None,
  184. **kwargs):
  185. """Calculate the loss of total batch.
  186. Args:
  187. mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
  188. Each element in the list has shape
  189. (batch_size, num_grids**2 ,h ,w).
  190. mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
  191. in the list has shape
  192. (batch_size, num_classes, num_grids ,num_grids).
  193. gt_labels (list[Tensor]): Labels of multiple images.
  194. gt_masks (list[Tensor]): Ground truth masks of multiple images.
  195. Each has shape (num_instances, h, w).
  196. img_metas (list[dict]): Meta information of multiple images.
  197. gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
  198. images. Default: None.
  199. Returns:
  200. dict[str, Tensor]: A dictionary of loss components.
  201. """
  202. num_levels = self.num_levels
  203. num_imgs = len(gt_labels)
  204. featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]
  205. # `BoolTensor` in `pos_masks` represent
  206. # whether the corresponding point is
  207. # positive
  208. pos_mask_targets, labels, pos_masks = multi_apply(
  209. self._get_targets_single,
  210. gt_bboxes,
  211. gt_labels,
  212. gt_masks,
  213. featmap_sizes=featmap_sizes)
  214. # change from the outside list meaning multi images
  215. # to the outside list meaning multi levels
  216. mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
  217. mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
  218. mlvl_pos_masks = [[] for _ in range(num_levels)]
  219. mlvl_labels = [[] for _ in range(num_levels)]
  220. for img_id in range(num_imgs):
  221. assert num_levels == len(pos_mask_targets[img_id])
  222. for lvl in range(num_levels):
  223. mlvl_pos_mask_targets[lvl].append(
  224. pos_mask_targets[img_id][lvl])
  225. mlvl_pos_mask_preds[lvl].append(
  226. mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
  227. mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
  228. mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
  229. # cat multiple image
  230. temp_mlvl_cls_preds = []
  231. for lvl in range(num_levels):
  232. mlvl_pos_mask_targets[lvl] = torch.cat(
  233. mlvl_pos_mask_targets[lvl], dim=0)
  234. mlvl_pos_mask_preds[lvl] = torch.cat(
  235. mlvl_pos_mask_preds[lvl], dim=0)
  236. mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
  237. mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
  238. temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
  239. 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
  240. num_pos = sum(item.sum() for item in mlvl_pos_masks)
  241. # dice loss
  242. loss_mask = []
  243. for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
  244. if pred.size()[0] == 0:
  245. loss_mask.append(pred.sum().unsqueeze(0))
  246. continue
  247. loss_mask.append(
  248. self.loss_mask(pred, target, reduction_override='none'))
  249. if num_pos > 0:
  250. loss_mask = torch.cat(loss_mask).sum() / num_pos
  251. else:
  252. loss_mask = torch.cat(loss_mask).mean()
  253. flatten_labels = torch.cat(mlvl_labels)
  254. flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
  255. loss_cls = self.loss_cls(
  256. flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
  257. return dict(loss_mask=loss_mask, loss_cls=loss_cls)
  258. def _get_targets_single(self,
  259. gt_bboxes,
  260. gt_labels,
  261. gt_masks,
  262. featmap_sizes=None):
  263. """Compute targets for predictions of single image.
  264. Args:
  265. gt_bboxes (Tensor): Ground truth bbox of each instance,
  266. shape (num_gts, 4).
  267. gt_labels (Tensor): Ground truth label of each instance,
  268. shape (num_gts,).
  269. gt_masks (Tensor): Ground truth mask of each instance,
  270. shape (num_gts, h, w).
  271. featmap_sizes (list[:obj:`torch.size`]): Size of each
  272. feature map from feature pyramid, each element
  273. means (feat_h, feat_w). Default: None.
  274. Returns:
  275. Tuple: Usually returns a tuple containing targets for predictions.
  276. - mlvl_pos_mask_targets (list[Tensor]): Each element represent
  277. the binary mask targets for positive points in this
  278. level, has shape (num_pos, out_h, out_w).
  279. - mlvl_labels (list[Tensor]): Each element is
  280. classification labels for all
  281. points in this level, has shape
  282. (num_grid, num_grid).
  283. - mlvl_pos_masks (list[Tensor]): Each element is
  284. a `BoolTensor` to represent whether the
  285. corresponding point in single level
  286. is positive, has shape (num_grid **2).
  287. """
  288. device = gt_labels.device
  289. gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
  290. (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
  291. mlvl_pos_mask_targets = []
  292. mlvl_labels = []
  293. mlvl_pos_masks = []
  294. for (lower_bound, upper_bound), stride, featmap_size, num_grid \
  295. in zip(self.scale_ranges, self.strides,
  296. featmap_sizes, self.num_grids):
  297. mask_target = torch.zeros(
  298. [num_grid**2, featmap_size[0], featmap_size[1]],
  299. dtype=torch.uint8,
  300. device=device)
  301. # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
  302. labels = torch.zeros([num_grid, num_grid],
  303. dtype=torch.int64,
  304. device=device) + self.num_classes
  305. pos_mask = torch.zeros([num_grid**2],
  306. dtype=torch.bool,
  307. device=device)
  308. gt_inds = ((gt_areas >= lower_bound) &
  309. (gt_areas <= upper_bound)).nonzero().flatten()
  310. if len(gt_inds) == 0:
  311. mlvl_pos_mask_targets.append(
  312. mask_target.new_zeros(0, featmap_size[0], featmap_size[1]))
  313. mlvl_labels.append(labels)
  314. mlvl_pos_masks.append(pos_mask)
  315. continue
  316. hit_gt_bboxes = gt_bboxes[gt_inds]
  317. hit_gt_labels = gt_labels[gt_inds]
  318. hit_gt_masks = gt_masks[gt_inds, ...]
  319. pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] -
  320. hit_gt_bboxes[:, 0]) * self.pos_scale
  321. pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] -
  322. hit_gt_bboxes[:, 1]) * self.pos_scale
  323. # Make sure hit_gt_masks has a value
  324. valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0
  325. output_stride = stride / 2
  326. for gt_mask, gt_label, pos_h_range, pos_w_range, \
  327. valid_mask_flag in \
  328. zip(hit_gt_masks, hit_gt_labels, pos_h_ranges,
  329. pos_w_ranges, valid_mask_flags):
  330. if not valid_mask_flag:
  331. continue
  332. upsampled_size = (featmap_sizes[0][0] * 4,
  333. featmap_sizes[0][1] * 4)
  334. center_h, center_w = center_of_mass(gt_mask)
  335. coord_w = int(
  336. (center_w / upsampled_size[1]) // (1. / num_grid))
  337. coord_h = int(
  338. (center_h / upsampled_size[0]) // (1. / num_grid))
  339. # left, top, right, down
  340. top_box = max(
  341. 0,
  342. int(((center_h - pos_h_range) / upsampled_size[0]) //
  343. (1. / num_grid)))
  344. down_box = min(
  345. num_grid - 1,
  346. int(((center_h + pos_h_range) / upsampled_size[0]) //
  347. (1. / num_grid)))
  348. left_box = max(
  349. 0,
  350. int(((center_w - pos_w_range) / upsampled_size[1]) //
  351. (1. / num_grid)))
  352. right_box = min(
  353. num_grid - 1,
  354. int(((center_w + pos_w_range) / upsampled_size[1]) //
  355. (1. / num_grid)))
  356. top = max(top_box, coord_h - 1)
  357. down = min(down_box, coord_h + 1)
  358. left = max(coord_w - 1, left_box)
  359. right = min(right_box, coord_w + 1)
  360. labels[top:(down + 1), left:(right + 1)] = gt_label
  361. # ins
  362. gt_mask = np.uint8(gt_mask.cpu().numpy())
  363. # Follow the original implementation, F.interpolate is
  364. # different from cv2 and opencv
  365. gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
  366. gt_mask = torch.from_numpy(gt_mask).to(device=device)
  367. for i in range(top, down + 1):
  368. for j in range(left, right + 1):
  369. index = int(i * num_grid + j)
  370. mask_target[index, :gt_mask.shape[0], :gt_mask.
  371. shape[1]] = gt_mask
  372. pos_mask[index] = True
  373. mlvl_pos_mask_targets.append(mask_target[pos_mask])
  374. mlvl_labels.append(labels)
  375. mlvl_pos_masks.append(pos_mask)
  376. return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
  377. def get_results(self, mlvl_mask_preds, mlvl_cls_scores, img_metas,
  378. **kwargs):
  379. """Get multi-image mask results.
  380. Args:
  381. mlvl_mask_preds (list[Tensor]): Multi-level mask prediction.
  382. Each element in the list has shape
  383. (batch_size, num_grids**2 ,h ,w).
  384. mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
  385. in the list has shape
  386. (batch_size, num_classes, num_grids ,num_grids).
  387. img_metas (list[dict]): Meta information of all images.
  388. Returns:
  389. list[:obj:`InstanceData`]: Processed results of multiple
  390. images.Each :obj:`InstanceData` usually contains
  391. following keys.
  392. - scores (Tensor): Classification scores, has shape
  393. (num_instance,).
  394. - labels (Tensor): Has shape (num_instances,).
  395. - masks (Tensor): Processed mask results, has
  396. shape (num_instances, h, w).
  397. """
  398. mlvl_cls_scores = [
  399. item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
  400. ]
  401. assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
  402. num_levels = len(mlvl_cls_scores)
  403. results_list = []
  404. for img_id in range(len(img_metas)):
  405. cls_pred_list = [
  406. mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
  407. for lvl in range(num_levels)
  408. ]
  409. mask_pred_list = [
  410. mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
  411. ]
  412. cls_pred_list = torch.cat(cls_pred_list, dim=0)
  413. mask_pred_list = torch.cat(mask_pred_list, dim=0)
  414. results = self._get_results_single(
  415. cls_pred_list, mask_pred_list, img_meta=img_metas[img_id])
  416. results_list.append(results)
  417. return results_list
  418. def _get_results_single(self, cls_scores, mask_preds, img_meta, cfg=None):
  419. """Get processed mask related results of single image.
  420. Args:
  421. cls_scores (Tensor): Classification score of all points
  422. in single image, has shape (num_points, num_classes).
  423. mask_preds (Tensor): Mask prediction of all points in
  424. single image, has shape (num_points, feat_h, feat_w).
  425. img_meta (dict): Meta information of corresponding image.
  426. cfg (dict, optional): Config used in test phase.
  427. Default: None.
  428. Returns:
  429. :obj:`InstanceData`: Processed results of single image.
  430. it usually contains following keys.
  431. - scores (Tensor): Classification scores, has shape
  432. (num_instance,).
  433. - labels (Tensor): Has shape (num_instances,).
  434. - masks (Tensor): Processed mask results, has
  435. shape (num_instances, h, w).
  436. """
  437. def empty_results(results, cls_scores):
  438. """Generate a empty results."""
  439. results.scores = cls_scores.new_ones(0)
  440. results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
  441. results.labels = cls_scores.new_ones(0)
  442. return results
  443. cfg = self.test_cfg if cfg is None else cfg
  444. assert len(cls_scores) == len(mask_preds)
  445. results = InstanceData(img_meta)
  446. featmap_size = mask_preds.size()[-2:]
  447. img_shape = results.img_shape
  448. ori_shape = results.ori_shape
  449. h, w, _ = img_shape
  450. upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
  451. score_mask = (cls_scores > cfg.score_thr)
  452. cls_scores = cls_scores[score_mask]
  453. if len(cls_scores) == 0:
  454. return empty_results(results, cls_scores)
  455. inds = score_mask.nonzero()
  456. cls_labels = inds[:, 1]
  457. # Filter the mask mask with an area is smaller than
  458. # stride of corresponding feature level
  459. lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0)
  460. strides = cls_scores.new_ones(lvl_interval[-1])
  461. strides[:lvl_interval[0]] *= self.strides[0]
  462. for lvl in range(1, self.num_levels):
  463. strides[lvl_interval[lvl -
  464. 1]:lvl_interval[lvl]] *= self.strides[lvl]
  465. strides = strides[inds[:, 0]]
  466. mask_preds = mask_preds[inds[:, 0]]
  467. masks = mask_preds > cfg.mask_thr
  468. sum_masks = masks.sum((1, 2)).float()
  469. keep = sum_masks > strides
  470. if keep.sum() == 0:
  471. return empty_results(results, cls_scores)
  472. masks = masks[keep]
  473. mask_preds = mask_preds[keep]
  474. sum_masks = sum_masks[keep]
  475. cls_scores = cls_scores[keep]
  476. cls_labels = cls_labels[keep]
  477. # maskness.
  478. mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
  479. cls_scores *= mask_scores
  480. scores, labels, _, keep_inds = mask_matrix_nms(
  481. masks,
  482. cls_labels,
  483. cls_scores,
  484. mask_area=sum_masks,
  485. nms_pre=cfg.nms_pre,
  486. max_num=cfg.max_per_img,
  487. kernel=cfg.kernel,
  488. sigma=cfg.sigma,
  489. filter_thr=cfg.filter_thr)
  490. mask_preds = mask_preds[keep_inds]
  491. mask_preds = F.interpolate(
  492. mask_preds.unsqueeze(0), size=upsampled_size,
  493. mode='bilinear')[:, :, :h, :w]
  494. mask_preds = F.interpolate(
  495. mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
  496. masks = mask_preds > cfg.mask_thr
  497. results.masks = masks
  498. results.labels = labels
  499. results.scores = scores
  500. return results
  501. @HEADS.register_module()
  502. class DecoupledSOLOHead(SOLOHead):
  503. """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations.
  504. <https://arxiv.org/abs/1912.04488>`_
  505. Args:
  506. init_cfg (dict or list[dict], optional): Initialization config dict.
  507. """
  508. def __init__(self,
  509. *args,
  510. init_cfg=[
  511. dict(type='Normal', layer='Conv2d', std=0.01),
  512. dict(
  513. type='Normal',
  514. std=0.01,
  515. bias_prob=0.01,
  516. override=dict(name='conv_mask_list_x')),
  517. dict(
  518. type='Normal',
  519. std=0.01,
  520. bias_prob=0.01,
  521. override=dict(name='conv_mask_list_y')),
  522. dict(
  523. type='Normal',
  524. std=0.01,
  525. bias_prob=0.01,
  526. override=dict(name='conv_cls'))
  527. ],
  528. **kwargs):
  529. super(DecoupledSOLOHead, self).__init__(
  530. *args, init_cfg=init_cfg, **kwargs)
  531. def _init_layers(self):
  532. self.mask_convs_x = nn.ModuleList()
  533. self.mask_convs_y = nn.ModuleList()
  534. self.cls_convs = nn.ModuleList()
  535. for i in range(self.stacked_convs):
  536. chn = self.in_channels + 1 if i == 0 else self.feat_channels
  537. self.mask_convs_x.append(
  538. ConvModule(
  539. chn,
  540. self.feat_channels,
  541. 3,
  542. stride=1,
  543. padding=1,
  544. norm_cfg=self.norm_cfg))
  545. self.mask_convs_y.append(
  546. ConvModule(
  547. chn,
  548. self.feat_channels,
  549. 3,
  550. stride=1,
  551. padding=1,
  552. norm_cfg=self.norm_cfg))
  553. chn = self.in_channels if i == 0 else self.feat_channels
  554. self.cls_convs.append(
  555. ConvModule(
  556. chn,
  557. self.feat_channels,
  558. 3,
  559. stride=1,
  560. padding=1,
  561. norm_cfg=self.norm_cfg))
  562. self.conv_mask_list_x = nn.ModuleList()
  563. self.conv_mask_list_y = nn.ModuleList()
  564. for num_grid in self.num_grids:
  565. self.conv_mask_list_x.append(
  566. nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
  567. self.conv_mask_list_y.append(
  568. nn.Conv2d(self.feat_channels, num_grid, 3, padding=1))
  569. self.conv_cls = nn.Conv2d(
  570. self.feat_channels, self.cls_out_channels, 3, padding=1)
  571. def forward(self, feats):
  572. assert len(feats) == self.num_levels
  573. feats = self.resize_feats(feats)
  574. mask_preds_x = []
  575. mask_preds_y = []
  576. cls_preds = []
  577. for i in range(self.num_levels):
  578. x = feats[i]
  579. mask_feat = x
  580. cls_feat = x
  581. # generate and concat the coordinate
  582. coord_feat = generate_coordinate(mask_feat.size(),
  583. mask_feat.device)
  584. mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1)
  585. mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1)
  586. for mask_layer_x, mask_layer_y in \
  587. zip(self.mask_convs_x, self.mask_convs_y):
  588. mask_feat_x = mask_layer_x(mask_feat_x)
  589. mask_feat_y = mask_layer_y(mask_feat_y)
  590. mask_feat_x = F.interpolate(
  591. mask_feat_x, scale_factor=2, mode='bilinear')
  592. mask_feat_y = F.interpolate(
  593. mask_feat_y, scale_factor=2, mode='bilinear')
  594. mask_pred_x = self.conv_mask_list_x[i](mask_feat_x)
  595. mask_pred_y = self.conv_mask_list_y[i](mask_feat_y)
  596. # cls branch
  597. for j, cls_layer in enumerate(self.cls_convs):
  598. if j == self.cls_down_index:
  599. num_grid = self.num_grids[i]
  600. cls_feat = F.interpolate(
  601. cls_feat, size=num_grid, mode='bilinear')
  602. cls_feat = cls_layer(cls_feat)
  603. cls_pred = self.conv_cls(cls_feat)
  604. if not self.training:
  605. feat_wh = feats[0].size()[-2:]
  606. upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2)
  607. mask_pred_x = F.interpolate(
  608. mask_pred_x.sigmoid(),
  609. size=upsampled_size,
  610. mode='bilinear')
  611. mask_pred_y = F.interpolate(
  612. mask_pred_y.sigmoid(),
  613. size=upsampled_size,
  614. mode='bilinear')
  615. cls_pred = cls_pred.sigmoid()
  616. # get local maximum
  617. local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
  618. keep_mask = local_max[:, :, :-1, :-1] == cls_pred
  619. cls_pred = cls_pred * keep_mask
  620. mask_preds_x.append(mask_pred_x)
  621. mask_preds_y.append(mask_pred_y)
  622. cls_preds.append(cls_pred)
  623. return mask_preds_x, mask_preds_y, cls_preds
  624. def loss(self,
  625. mlvl_mask_preds_x,
  626. mlvl_mask_preds_y,
  627. mlvl_cls_preds,
  628. gt_labels,
  629. gt_masks,
  630. img_metas,
  631. gt_bboxes=None,
  632. **kwargs):
  633. """Calculate the loss of total batch.
  634. Args:
  635. mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
  636. from x branch. Each element in the list has shape
  637. (batch_size, num_grids ,h ,w).
  638. mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
  639. from y branch. Each element in the list has shape
  640. (batch_size, num_grids ,h ,w).
  641. mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
  642. in the list has shape
  643. (batch_size, num_classes, num_grids ,num_grids).
  644. gt_labels (list[Tensor]): Labels of multiple images.
  645. gt_masks (list[Tensor]): Ground truth masks of multiple images.
  646. Each has shape (num_instances, h, w).
  647. img_metas (list[dict]): Meta information of multiple images.
  648. gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
  649. images. Default: None.
  650. Returns:
  651. dict[str, Tensor]: A dictionary of loss components.
  652. """
  653. num_levels = self.num_levels
  654. num_imgs = len(gt_labels)
  655. featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
  656. pos_mask_targets, labels, \
  657. xy_pos_indexes = \
  658. multi_apply(self._get_targets_single,
  659. gt_bboxes,
  660. gt_labels,
  661. gt_masks,
  662. featmap_sizes=featmap_sizes)
  663. # change from the outside list meaning multi images
  664. # to the outside list meaning multi levels
  665. mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
  666. mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
  667. mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
  668. mlvl_labels = [[] for _ in range(num_levels)]
  669. for img_id in range(num_imgs):
  670. for lvl in range(num_levels):
  671. mlvl_pos_mask_targets[lvl].append(
  672. pos_mask_targets[img_id][lvl])
  673. mlvl_pos_mask_preds_x[lvl].append(
  674. mlvl_mask_preds_x[lvl][img_id,
  675. xy_pos_indexes[img_id][lvl][:, 1]])
  676. mlvl_pos_mask_preds_y[lvl].append(
  677. mlvl_mask_preds_y[lvl][img_id,
  678. xy_pos_indexes[img_id][lvl][:, 0]])
  679. mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
  680. # cat multiple image
  681. temp_mlvl_cls_preds = []
  682. for lvl in range(num_levels):
  683. mlvl_pos_mask_targets[lvl] = torch.cat(
  684. mlvl_pos_mask_targets[lvl], dim=0)
  685. mlvl_pos_mask_preds_x[lvl] = torch.cat(
  686. mlvl_pos_mask_preds_x[lvl], dim=0)
  687. mlvl_pos_mask_preds_y[lvl] = torch.cat(
  688. mlvl_pos_mask_preds_y[lvl], dim=0)
  689. mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
  690. temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
  691. 0, 2, 3, 1).reshape(-1, self.cls_out_channels))
  692. num_pos = 0.
  693. # dice loss
  694. loss_mask = []
  695. for pred_x, pred_y, target in \
  696. zip(mlvl_pos_mask_preds_x,
  697. mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
  698. num_masks = pred_x.size(0)
  699. if num_masks == 0:
  700. # make sure can get grad
  701. loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
  702. continue
  703. num_pos += num_masks
  704. pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
  705. loss_mask.append(
  706. self.loss_mask(pred_mask, target, reduction_override='none'))
  707. if num_pos > 0:
  708. loss_mask = torch.cat(loss_mask).sum() / num_pos
  709. else:
  710. loss_mask = torch.cat(loss_mask).mean()
  711. # cate
  712. flatten_labels = torch.cat(mlvl_labels)
  713. flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
  714. loss_cls = self.loss_cls(
  715. flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
  716. return dict(loss_mask=loss_mask, loss_cls=loss_cls)
  def _get_targets_single(self,
                          gt_bboxes,
                          gt_labels,
                          gt_masks,
                          featmap_sizes=None):
      """Build the per-level training targets of one image.

      The mask targets and label maps come straight from the parent
      implementation; this override only replaces the third returned item
      with the grid coordinates of the positive cells of every level.

      Args:
          gt_bboxes (Tensor): Ground truth bbox of each instance,
              shape (num_gts, 4).
          gt_labels (Tensor): Ground truth label of each instance,
              shape (num_gts,).
          gt_masks (Tensor): Ground truth mask of each instance,
              shape (num_gts, h, w).
          featmap_sizes (list[:obj:`torch.size`]): Size of each
              feature map from feature pyramid, each element
              means (feat_h, feat_w). Default: None.

      Returns:
          tuple: Targets for the predictions of this image.

              - mlvl_pos_mask_targets (list[Tensor]): Binary mask targets
                of the positive points of each level, each of shape
                (num_pos, out_h, out_w).
              - mlvl_labels (list[Tensor]): Classification labels of all
                points of each level, each of shape
                (num_grid, num_grid).
              - mlvl_xy_pos_indexes (list[Tensor]): Output of
                ``nonzero()`` on each label map, each of shape
                (num_pos, 2). Because the label map is 2-D, column 0 is
                the row (y) index and column 1 is the column (x) index
                of a positive cell.
      """
      parent_targets = super()._get_targets_single(
          gt_bboxes, gt_labels, gt_masks, featmap_sizes=featmap_sizes)
      # The third item produced by the parent is not used by this head.
      mlvl_pos_mask_targets, mlvl_labels, _ = parent_targets

      mlvl_xy_pos_indexes = []
      for lvl_labels in mlvl_labels:
          # A cell is positive when its label differs from
          # ``self.num_classes``; subtracting makes those cells non-zero.
          shifted_labels = lvl_labels - self.num_classes
          mlvl_xy_pos_indexes.append(shifted_labels.nonzero())

      return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
  def get_results(self,
                  mlvl_mask_preds_x,
                  mlvl_mask_preds_y,
                  mlvl_cls_scores,
                  img_metas,
                  rescale=None,
                  **kwargs):
      """Turn the multi-level predictions of a batch into per-image results.

      For every image the predictions of all levels are concatenated and
      handed to ``self._get_results_single`` together with the image meta
      and ``self.test_cfg``.

      Args:
          mlvl_mask_preds_x (list[Tensor]): Multi-level mask prediction
              from x branch. Each element in the list has shape
              (batch_size, num_grids ,h ,w).
          mlvl_mask_preds_y (list[Tensor]): Multi-level mask prediction
              from y branch. Each element in the list has shape
              (batch_size, num_grids ,h ,w).
          mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
              in the list has shape
              (batch_size, num_classes ,num_grids ,num_grids).
          img_metas (list[dict]): Meta information of all images.
          rescale: Accepted for interface compatibility; not read here.
          **kwargs: Accepted for interface compatibility; not read here.

      Returns:
          list[:obj:`InstanceData`]: Processed results of multiple
          images. Each :obj:`InstanceData` usually contains
          following keys.

              - scores (Tensor): Classification scores, has shape
                (num_instance,).
              - labels (Tensor): Has shape (num_instances,).
              - masks (Tensor): Processed mask results, has
                shape (num_instances, h, w).
      """
      # Move the class channel last so each grid cell's scores are
      # adjacent and can be flattened to (num_points, num_classes).
      channels_last_scores = [
          scores.permute(0, 2, 3, 1) for scores in mlvl_cls_scores
      ]
      assert len(mlvl_mask_preds_x) == len(channels_last_scores)
      num_levels = len(channels_last_scores)

      results_list = []
      for img_id, img_meta in enumerate(img_metas):
          per_level_cls = []
          per_level_x = []
          per_level_y = []
          for lvl in range(num_levels):
              flat_scores = channels_last_scores[lvl][img_id].view(
                  -1, self.cls_out_channels)
              per_level_cls.append(flat_scores.detach())
              per_level_x.append(mlvl_mask_preds_x[lvl][img_id])
              per_level_y.append(mlvl_mask_preds_y[lvl][img_id])

          img_cls_scores = torch.cat(per_level_cls, dim=0)
          img_mask_preds_x = torch.cat(per_level_x, dim=0)
          img_mask_preds_y = torch.cat(per_level_y, dim=0)

          single_result = self._get_results_single(
              img_cls_scores,
              img_mask_preds_x,
              img_mask_preds_y,
              img_meta=img_meta,
              cfg=self.test_cfg)
          results_list.append(single_result)

      return results_list
  def _get_results_single(self, cls_scores, mask_preds_x, mask_preds_y,
                          img_meta, cfg):
      """Get processed mask related results of single image.

      Args:
          cls_scores (Tensor): Classification score of all points
              in single image, has shape (num_points, num_classes).
          mask_preds_x (Tensor): Mask prediction of x branch of
              all points in single image, has shape
              (sum_num_grids, feat_h, feat_w).
          mask_preds_y (Tensor): Mask prediction of y branch of
              all points in single image, has shape
              (sum_num_grids, feat_h, feat_w).
          img_meta (dict): Meta information of corresponding image.
          cfg (dict): Config used in test phase. ``self.test_cfg`` is
              used when it is None.

      Returns:
          :obj:`InstanceData`: Processed results of single image.
          It usually contains following keys.

              - scores (Tensor): Classification scores, has shape
                (num_instance,).
              - labels (Tensor): Has shape (num_instances,).
              - masks (Tensor): Processed mask results, has
                shape (num_instances, h, w).

          All three are zero-length when no candidate survives the score
          threshold, the mask-area filter or the NMS.
      """

      def empty_results(results, cls_scores):
          """Fill ``results`` with zero-length scores, masks and labels."""
          results.scores = cls_scores.new_ones(0)
          results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
          results.labels = cls_scores.new_ones(0)
          return results

      cfg = self.test_cfg if cfg is None else cfg

      results = InstanceData(img_meta)
      img_shape = results.img_shape
      ori_shape = results.ori_shape
      h, w, _ = img_shape
      featmap_size = mask_preds_x.size()[-2:]
      upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)

      # Keep only the (point, class) pairs whose score passes the threshold.
      score_mask = (cls_scores > cfg.score_thr)
      cls_scores = cls_scores[score_mask]
      if len(cls_scores) == 0:
          # Nothing to decode; skip building the lookup tables below.
          return empty_results(results, cls_scores)
      # Column 0: flat point index over all levels, column 1: class index.
      inds = score_mask.nonzero()

      # Per-point lookup tables over the concatenated levels:
      #   lvl_start_index      - first flat point index of the point's level
      #   mask_lvl_start_index - first row of that level inside
      #                          mask_preds_x / mask_preds_y
      #   num_grids / strides  - grid size and stride of the point's level
      lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
      num_all_points = lvl_interval[-1]
      lvl_start_index = inds.new_ones(num_all_points)
      num_grids = inds.new_ones(num_all_points)
      seg_size = inds.new_tensor(self.num_grids).cumsum(0)
      mask_lvl_start_index = inds.new_ones(num_all_points)
      strides = inds.new_ones(num_all_points)

      # The tables start as ones, so multiplying a slice writes the value.
      lvl_start_index[:lvl_interval[0]] *= 0
      mask_lvl_start_index[:lvl_interval[0]] *= 0
      num_grids[:lvl_interval[0]] *= self.num_grids[0]
      strides[:lvl_interval[0]] *= self.strides[0]

      for lvl in range(1, self.num_levels):
          lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
              lvl_interval[lvl - 1]
          mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
              seg_size[lvl - 1]
          num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
              self.num_grids[lvl]
          strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
              self.strides[lvl]

      # Gather the table entries of the surviving points.
      lvl_start_index = lvl_start_index[inds[:, 0]]
      mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
      num_grids = num_grids[inds[:, 0]]
      strides = strides[inds[:, 0]]

      # Split the in-level flat index into grid row (y) and column (x),
      # then shift into the concatenated x / y mask prediction tensors.
      y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
      x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
      y_inds = mask_lvl_start_index + y_lvl_offset
      x_inds = mask_lvl_start_index + x_lvl_offset

      cls_labels = inds[:, 1]
      # The instance mask is the product of its x and y branch maps.
      mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]

      masks = mask_preds > cfg.mask_thr
      sum_masks = masks.sum((1, 2)).float()
      # Drop masks whose pixel count does not exceed their level's stride.
      keep = sum_masks > strides
      if keep.sum() == 0:
          return empty_results(results, cls_scores)

      masks = masks[keep]
      mask_preds = mask_preds[keep]
      sum_masks = sum_masks[keep]
      cls_scores = cls_scores[keep]
      cls_labels = cls_labels[keep]

      # maskness: mean predicted value inside the binarized mask, used to
      # rescale the classification score.
      mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks
      cls_scores *= mask_scores

      scores, labels, _, keep_inds = mask_matrix_nms(
          masks,
          cls_labels,
          cls_scores,
          mask_area=sum_masks,
          nms_pre=cfg.nms_pre,
          max_num=cfg.max_per_img,
          kernel=cfg.kernel,
          sigma=cfg.sigma,
          filter_thr=cfg.filter_thr)
      # The NMS may filter out every candidate; return the empty result
      # instead of interpolating a tensor without any mask.
      if len(keep_inds) == 0:
          return empty_results(results, cls_scores)

      mask_preds = mask_preds[keep_inds]
      # Upsample to 4x the feature map, crop to the image shape, then
      # resize to the original image shape.
      mask_preds = F.interpolate(
          mask_preds.unsqueeze(0), size=upsampled_size,
          mode='bilinear')[:, :, :h, :w]
      mask_preds = F.interpolate(
          mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0)
      masks = mask_preds > cfg.mask_thr

      results.masks = masks
      results.labels = labels
      results.scores = scores

      return results
  @HEADS.register_module()
  class DecoupledSOLOLightHead(DecoupledSOLOHead):
      """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by
      Locations <https://arxiv.org/abs/1912.04488>`_

      Unlike a head with one conv tower per branch, the x and y mask
      branches here share one conv tower (``mask_convs``) and only differ
      in their final per-level prediction conv.

      Args:
          dcn_cfg (dict, optional): Conv config used for the last stacked
              conv of both the mask and the cls tower. When None, every
              stacked conv uses the default conv. Default: None.
          init_cfg (dict or list[dict], optional): Initialization config dict.
      """

      def __init__(self,
                   *args,
                   dcn_cfg=None,
                   init_cfg=[
                       dict(type='Normal', layer='Conv2d', std=0.01),
                       dict(
                           type='Normal',
                           std=0.01,
                           bias_prob=0.01,
                           override=dict(name='conv_mask_list_x')),
                       dict(
                           type='Normal',
                           std=0.01,
                           bias_prob=0.01,
                           override=dict(name='conv_mask_list_y')),
                       dict(
                           type='Normal',
                           std=0.01,
                           bias_prob=0.01,
                           override=dict(name='conv_cls'))
                   ],
                   **kwargs):
          assert dcn_cfg is None or isinstance(dcn_cfg, dict)
          # Stored before the parent constructor runs, which presumably
          # triggers ``_init_layers`` where ``dcn_cfg`` is read.
          self.dcn_cfg = dcn_cfg
          super().__init__(*args, init_cfg=init_cfg, **kwargs)

      def _init_layers(self):
          """Create the shared mask tower, the cls tower and output convs."""
          self.mask_convs = nn.ModuleList()
          self.cls_convs = nn.ModuleList()
          last_conv_idx = self.stacked_convs - 1
          for conv_idx in range(self.stacked_convs):
              # Only the last stacked conv may use the dcn config.
              use_dcn = self.dcn_cfg is not None and conv_idx == last_conv_idx
              conv_cfg = self.dcn_cfg if use_dcn else None
              is_first = conv_idx == 0

              # The mask tower input carries two extra coordinate channels.
              mask_in = self.in_channels + 2 if is_first \
                  else self.feat_channels
              self.mask_convs.append(
                  ConvModule(
                      mask_in,
                      self.feat_channels,
                      3,
                      stride=1,
                      padding=1,
                      conv_cfg=conv_cfg,
                      norm_cfg=self.norm_cfg))

              cls_in = self.in_channels if is_first else self.feat_channels
              self.cls_convs.append(
                  ConvModule(
                      cls_in,
                      self.feat_channels,
                      3,
                      stride=1,
                      padding=1,
                      conv_cfg=conv_cfg,
                      norm_cfg=self.norm_cfg))

          # One x / y prediction conv per level; each outputs as many
          # channels as that level has grid cells per side.
          self.conv_mask_list_x = nn.ModuleList()
          self.conv_mask_list_y = nn.ModuleList()
          for grid_size in self.num_grids:
              x_conv = nn.Conv2d(
                  self.feat_channels, grid_size, 3, padding=1)
              self.conv_mask_list_x.append(x_conv)
              y_conv = nn.Conv2d(
                  self.feat_channels, grid_size, 3, padding=1)
              self.conv_mask_list_y.append(y_conv)

          self.conv_cls = nn.Conv2d(
              self.feat_channels, self.cls_out_channels, 3, padding=1)

      def forward(self, feats):
          """Compute x / y mask predictions and cls predictions per level.

          Args:
              feats (sequence of Tensor): One feature map per level; its
                  length must equal ``self.num_levels``.

          Returns:
              tuple: ``(mask_preds_x, mask_preds_y, cls_preds)``, three
              lists holding one tensor per level. In eval mode the mask
              predictions are sigmoid activated and resized to a common
              size, and the cls predictions are sigmoid activated with
              non local-maximum entries set to zero.
          """
          assert len(feats) == self.num_levels
          feats = self.resize_feats(feats)

          mask_preds_x, mask_preds_y, cls_preds = [], [], []
          for lvl in range(self.num_levels):
              level_feat = feats[lvl]

              # mask branch: append the coordinate channels, run the
              # shared tower, upsample by 2 and predict x / y maps.
              coord_feat = generate_coordinate(level_feat.size(),
                                               level_feat.device)
              mask_feat = torch.cat([level_feat, coord_feat], 1)
              for mask_conv in self.mask_convs:
                  mask_feat = mask_conv(mask_feat)
              mask_feat = F.interpolate(
                  mask_feat, scale_factor=2, mode='bilinear')
              mask_pred_x = self.conv_mask_list_x[lvl](mask_feat)
              mask_pred_y = self.conv_mask_list_y[lvl](mask_feat)

              # cls branch: resize to the level's grid size right before
              # the conv at ``cls_down_index``.
              cls_feat = level_feat
              for conv_idx, cls_conv in enumerate(self.cls_convs):
                  if conv_idx == self.cls_down_index:
                      cls_feat = F.interpolate(
                          cls_feat,
                          size=self.num_grids[lvl],
                          mode='bilinear')
                  cls_feat = cls_conv(cls_feat)
              cls_pred = self.conv_cls(cls_feat)

              if not self.training:
                  # Bring every level's masks to twice the spatial size
                  # of the first resized feature map.
                  base_h, base_w = feats[0].size()[-2:]
                  upsampled_size = (base_h * 2, base_w * 2)
                  mask_pred_x = F.interpolate(
                      mask_pred_x.sigmoid(),
                      size=upsampled_size,
                      mode='bilinear')
                  mask_pred_y = F.interpolate(
                      mask_pred_y.sigmoid(),
                      size=upsampled_size,
                      mode='bilinear')
                  cls_pred = cls_pred.sigmoid()
                  # get local maximum: keep a score only where it equals
                  # the max of its 2x2 pooling window.
                  pooled = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
                  is_local_max = pooled[:, :, :-1, :-1] == cls_pred
                  cls_pred = cls_pred * is_local_max

              mask_preds_x.append(mask_pred_x)
              mask_preds_y.append(mask_pred_y)
              cls_preds.append(cls_pred)

          return mask_preds_x, mask_preds_y, cls_preds

No Description

Contributors (3)