
grid_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from mmdet.models.builder import HEADS, build_loss


@HEADS.register_module()
class GridHead(BaseModule):
    """Grid head used in Grid R-CNN.

    It predicts one heatmap per grid point and derives box boundaries from
    the heatmap peaks; see Grid R-CNN Plus
    (https://arxiv.org/abs/1906.05688).
    """

    def __init__(self,
                 grid_points=9,
                 num_convs=8,
                 roi_feat_size=14,
                 in_channels=256,
                 conv_kernel_size=3,
                 point_feat_channels=64,
                 deconv_kernel_size=4,
                 class_agnostic=False,
                 loss_grid=dict(
                     type='CrossEntropyLoss', use_sigmoid=True,
                     loss_weight=15),
                 conv_cfg=None,
                 norm_cfg=dict(type='GN', num_groups=36),
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d', 'Linear']),
                     dict(
                         type='Normal',
                         layer='ConvTranspose2d',
                         std=0.001,
                         override=dict(
                             type='Normal',
                             name='deconv2',
                             std=0.001,
                             bias=-np.log(0.99 / 0.01)))
                 ]):
        super(GridHead, self).__init__(init_cfg)
        self.grid_points = grid_points
        self.num_convs = num_convs
        self.roi_feat_size = roi_feat_size
        self.in_channels = in_channels
        self.conv_kernel_size = conv_kernel_size
        self.point_feat_channels = point_feat_channels
        self.conv_out_channels = self.point_feat_channels * self.grid_points
        self.class_agnostic = class_agnostic
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
            assert self.conv_out_channels % norm_cfg['num_groups'] == 0

        assert self.grid_points >= 4
        self.grid_size = int(np.sqrt(self.grid_points))
        if self.grid_size * self.grid_size != self.grid_points:
            raise ValueError('grid_points must be a square number')

        # the predicted heatmap is half of whole_map_size
        if not isinstance(self.roi_feat_size, int):
            raise ValueError('Only square RoIs are supported in Grid R-CNN')
        self.whole_map_size = self.roi_feat_size * 4

        # compute point-wise sub-regions
        self.sub_regions = self.calc_sub_regions()

        self.convs = []
        for i in range(self.num_convs):
            in_channels = (
                self.in_channels if i == 0 else self.conv_out_channels)
            stride = 2 if i == 0 else 1
            padding = (self.conv_kernel_size - 1) // 2
            self.convs.append(
                ConvModule(
                    in_channels,
                    self.conv_out_channels,
                    self.conv_kernel_size,
                    stride=stride,
                    padding=padding,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    bias=True))
        self.convs = nn.Sequential(*self.convs)

        self.deconv1 = nn.ConvTranspose2d(
            self.conv_out_channels,
            self.conv_out_channels,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)
        self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
        self.deconv2 = nn.ConvTranspose2d(
            self.conv_out_channels,
            grid_points,
            kernel_size=deconv_kernel_size,
            stride=2,
            padding=(deconv_kernel_size - 2) // 2,
            groups=grid_points)

        # find the 4-neighbor of each grid point
        self.neighbor_points = []
        grid_size = self.grid_size
        for i in range(grid_size):  # i-th column
            for j in range(grid_size):  # j-th row
                neighbors = []
                if i > 0:  # left: (i - 1, j)
                    neighbors.append((i - 1) * grid_size + j)
                if j > 0:  # up: (i, j - 1)
                    neighbors.append(i * grid_size + j - 1)
                if j < grid_size - 1:  # down: (i, j + 1)
                    neighbors.append(i * grid_size + j + 1)
                if i < grid_size - 1:  # right: (i + 1, j)
                    neighbors.append((i + 1) * grid_size + j)
                self.neighbor_points.append(tuple(neighbors))
        # total edges in the grid
        self.num_edges = sum([len(p) for p in self.neighbor_points])

        self.forder_trans = nn.ModuleList()  # first-order feature transition
        self.sorder_trans = nn.ModuleList()  # second-order feature transition
        for neighbors in self.neighbor_points:
            fo_trans = nn.ModuleList()
            so_trans = nn.ModuleList()
            for _ in range(len(neighbors)):
                # each transition module consists of a 5x5 depth-wise conv
                # and a 1x1 conv.
                fo_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            stride=1,
                            padding=2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
                so_trans.append(
                    nn.Sequential(
                        nn.Conv2d(
                            self.point_feat_channels,
                            self.point_feat_channels,
                            5,
                            1,
                            2,
                            groups=self.point_feat_channels),
                        nn.Conv2d(self.point_feat_channels,
                                  self.point_feat_channels, 1)))
            self.forder_trans.append(fo_trans)
            self.sorder_trans.append(so_trans)

        self.loss_grid = build_loss(loss_grid)
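
    # In forward() below, each grid point's features are fused with those of
    # its 4-connected neighbors before the heatmaps are predicted: the
    # first-order step adds transformed raw neighbor features, and the
    # second-order step adds the already-fused neighbor features, so every
    # point also receives information from points two edges away.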
    def forward(self, x):
        assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
        # RoI feature transformation, downsample 2x
        x = self.convs(x)

        c = self.point_feat_channels
        # first-order fusion
        x_fo = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_fo[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_fo[i] = x_fo[i] + self.forder_trans[i][j](
                    x[:, point_idx * c:(point_idx + 1) * c])

        # second-order fusion
        x_so = [None for _ in range(self.grid_points)]
        for i, points in enumerate(self.neighbor_points):
            x_so[i] = x[:, i * c:(i + 1) * c]
            for j, point_idx in enumerate(points):
                x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])

        # predicted heatmap with fused features
        x2 = torch.cat(x_so, dim=1)
        x2 = self.deconv1(x2)
        x2 = F.relu(self.norm1(x2), inplace=True)
        heatmap = self.deconv2(x2)

        # predicted heatmap with original features (applicable during
        # training)
        if self.training:
            x1 = x
            x1 = self.deconv1(x1)
            x1 = F.relu(self.norm1(x1), inplace=True)
            heatmap_unfused = self.deconv2(x1)
        else:
            heatmap_unfused = heatmap

        return dict(fused=heatmap, unfused=heatmap_unfused)

    def calc_sub_regions(self):
        """Compute point specific representation regions.

        See Grid R-CNN Plus (https://arxiv.org/abs/1906.05688) for details.
        """
        # to make it consistent with the original implementation, half_size
        # is computed as 2 * quarter_size, which is smaller
        half_size = self.whole_map_size // 4 * 2
        sub_regions = []
        for i in range(self.grid_points):
            x_idx = i // self.grid_size
            y_idx = i % self.grid_size
            if x_idx == 0:
                sub_x1 = 0
            elif x_idx == self.grid_size - 1:
                sub_x1 = half_size
            else:
                ratio = x_idx / (self.grid_size - 1) - 0.25
                sub_x1 = max(int(ratio * self.whole_map_size), 0)
            if y_idx == 0:
                sub_y1 = 0
            elif y_idx == self.grid_size - 1:
                sub_y1 = half_size
            else:
                ratio = y_idx / (self.grid_size - 1) - 0.25
                sub_y1 = max(int(ratio * self.whole_map_size), 0)
            sub_regions.append(
                (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
        return sub_regions
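
    # With the default config (roi_feat_size=14, grid_points=9), the regions
    # computed above are 28x28 crops of the 56x56 map: (0, 0, 28, 28) for the
    # top-left point, (14, 14, 42, 42) for the center point, and so on.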

    def get_targets(self, sampling_results, rcnn_train_cfg):
        # mix all samples (across images) together.
        pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
                               dim=0).cpu()
        pos_gt_bboxes = torch.cat(
            [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
        assert pos_bboxes.shape == pos_gt_bboxes.shape

        # expand pos_bboxes to 2x of original size
        x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
        y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
        pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
        pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
        pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)

        num_rois = pos_bboxes.shape[0]
        map_size = self.whole_map_size
        # this is not the final target shape
        targets = torch.zeros((num_rois, self.grid_points, map_size,
                               map_size), dtype=torch.float)

        # pre-compute interpolation factors for all grid points.
        # the first item is the factor of x-dim, and the second is y-dim.
        # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
        factors = []
        for j in range(self.grid_points):
            x_idx = j // self.grid_size
            y_idx = j % self.grid_size
            factors.append((1 - x_idx / (self.grid_size - 1),
                            1 - y_idx / (self.grid_size - 1)))

        radius = rcnn_train_cfg.pos_radius
        radius2 = radius**2
        for i in range(num_rois):
            # ignore small bboxes
            if (pos_bbox_ws[i] <= self.grid_size
                    or pos_bbox_hs[i] <= self.grid_size):
                continue
            # for each grid point, mark a small circle as positive
            for j in range(self.grid_points):
                factor_x, factor_y = factors[j]
                gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
                    1 - factor_x) * pos_gt_bboxes[i, 2]
                gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
                    1 - factor_y) * pos_gt_bboxes[i, 3]
                cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
                         map_size)
                cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
                         map_size)
                for x in range(cx - radius, cx + radius + 1):
                    for y in range(cy - radius, cy + radius + 1):
                        if x >= 0 and x < map_size and y >= 0 and y < map_size:
                            if (x - cx)**2 + (y - cy)**2 <= radius2:
                                targets[i, j, y, x] = 1

        # reduce the target heatmap size by a half
        # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
        sub_targets = []
        for i in range(self.grid_points):
            sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
            sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
        sub_targets = torch.cat(sub_targets, dim=1)
        sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
        return sub_targets
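
    # get_targets() paints, for each positive RoI and each grid point, a disc
    # of radius `pos_radius` around the ground-truth point location on the
    # whole_map_size map, then keeps only that point's half-sized sub-region.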

    def loss(self, grid_pred, grid_targets):
        loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
        loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
        loss_grid = loss_fused + loss_unfused
        return dict(loss_grid=loss_grid)

    def get_bboxes(self, det_bboxes, grid_pred, img_metas):
        # TODO: refactoring
        assert det_bboxes.shape[0] == grid_pred.shape[0]
        det_bboxes = det_bboxes.cpu()
        cls_scores = det_bboxes[:, [4]]
        det_bboxes = det_bboxes[:, :4]
        grid_pred = grid_pred.sigmoid().cpu()

        R, c, h, w = grid_pred.shape
        half_size = self.whole_map_size // 4 * 2
        assert h == w == half_size
        assert c == self.grid_points

        # find the point with max scores in the half-sized heatmap
        grid_pred = grid_pred.view(R * c, h * w)
        pred_scores, pred_position = grid_pred.max(dim=1)
        xs = pred_position % w
        ys = pred_position // w

        # get the position in the whole heatmap instead of half-sized heatmap
        for i in range(self.grid_points):
            xs[i::self.grid_points] += self.sub_regions[i][0]
            ys[i::self.grid_points] += self.sub_regions[i][1]

        # reshape to (num_rois, grid_points)
        pred_scores, xs, ys = tuple(
            map(lambda x: x.view(R, c), [pred_scores, xs, ys]))

        # get expanded pos_bboxes
        widths = (det_bboxes[:, 2] - det_bboxes[:, 0]).unsqueeze(-1)
        heights = (det_bboxes[:, 3] - det_bboxes[:, 1]).unsqueeze(-1)
        x1 = (det_bboxes[:, 0, None] - widths / 2)
        y1 = (det_bboxes[:, 1, None] - heights / 2)
        # map the grid point to the absolute coordinates
        abs_xs = (xs.float() + 0.5) / w * widths + x1
        abs_ys = (ys.float() + 0.5) / h * heights + y1

        # get the grid points indices that fall on the bbox boundaries
        x1_inds = [i for i in range(self.grid_size)]
        y1_inds = [i * self.grid_size for i in range(self.grid_size)]
        x2_inds = [
            self.grid_points - self.grid_size + i
            for i in range(self.grid_size)
        ]
        y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]

        # voting of all grid points on some boundary
        bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
        bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
        bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
        bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
            dim=1, keepdim=True) / (
                pred_scores[:, y2_inds].sum(dim=1, keepdim=True))

        bbox_res = torch.cat(
            [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
        bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1])
        bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0])

        return bbox_res
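
A minimal smoke test for this head, assuming mmcv-full and mmdet 2.x are
installed and this file is importable as `grid_head` (the import path is an
assumption, not part of the file above). With the default configuration, a
14x14 RoI feature is downsampled to 7x7 by the stacked convs and upsampled by
the two group-wise deconvs to a 28x28 heatmap per grid point:

    import torch
    from grid_head import GridHead  # hypothetical import path

    head = GridHead()  # defaults: 9 grid points, 256 input channels
    head.eval()  # in eval mode 'unfused' is simply an alias of 'fused'
    rois = torch.randn(2, 256, 14, 14)  # two RoIs of roi_feat_size 14
    with torch.no_grad():
        out = head(rois)
    print(out['fused'].shape)  # torch.Size([2, 9, 28, 28])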
