You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

ct_resnet_neck.py 3.6 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import math
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmcv.runner import BaseModule, auto_fp16
  6. from mmdet.models.builder import NECKS
  7. @NECKS.register_module()
  8. class CTResNetNeck(BaseModule):
  9. """The neck used in `CenterNet <https://arxiv.org/abs/1904.07850>`_ for
  10. object classification and box regression.
  11. Args:
  12. in_channel (int): Number of input channels.
  13. num_deconv_filters (tuple[int]): Number of filters per stage.
  14. num_deconv_kernels (tuple[int]): Number of kernels per stage.
  15. use_dcn (bool): If True, use DCNv2. Default: True.
  16. init_cfg (dict or list[dict], optional): Initialization config dict.
  17. """
  18. def __init__(self,
  19. in_channel,
  20. num_deconv_filters,
  21. num_deconv_kernels,
  22. use_dcn=True,
  23. init_cfg=None):
  24. super(CTResNetNeck, self).__init__(init_cfg)
  25. assert len(num_deconv_filters) == len(num_deconv_kernels)
  26. self.fp16_enabled = False
  27. self.use_dcn = use_dcn
  28. self.in_channel = in_channel
  29. self.deconv_layers = self._make_deconv_layer(num_deconv_filters,
  30. num_deconv_kernels)
  31. def _make_deconv_layer(self, num_deconv_filters, num_deconv_kernels):
  32. """use deconv layers to upsample backbone's output."""
  33. layers = []
  34. for i in range(len(num_deconv_filters)):
  35. feat_channel = num_deconv_filters[i]
  36. conv_module = ConvModule(
  37. self.in_channel,
  38. feat_channel,
  39. 3,
  40. padding=1,
  41. conv_cfg=dict(type='DCNv2') if self.use_dcn else None,
  42. norm_cfg=dict(type='BN'))
  43. layers.append(conv_module)
  44. upsample_module = ConvModule(
  45. feat_channel,
  46. feat_channel,
  47. num_deconv_kernels[i],
  48. stride=2,
  49. padding=1,
  50. conv_cfg=dict(type='deconv'),
  51. norm_cfg=dict(type='BN'))
  52. layers.append(upsample_module)
  53. self.in_channel = feat_channel
  54. return nn.Sequential(*layers)
  55. def init_weights(self):
  56. for m in self.modules():
  57. if isinstance(m, nn.ConvTranspose2d):
  58. # In order to be consistent with the source code,
  59. # reset the ConvTranspose2d initialization parameters
  60. m.reset_parameters()
  61. # Simulated bilinear upsampling kernel
  62. w = m.weight.data
  63. f = math.ceil(w.size(2) / 2)
  64. c = (2 * f - 1 - f % 2) / (2. * f)
  65. for i in range(w.size(2)):
  66. for j in range(w.size(3)):
  67. w[0, 0, i, j] = \
  68. (1 - math.fabs(i / f - c)) * (
  69. 1 - math.fabs(j / f - c))
  70. for c in range(1, w.size(0)):
  71. w[c, 0, :, :] = w[0, 0, :, :]
  72. elif isinstance(m, nn.BatchNorm2d):
  73. nn.init.constant_(m.weight, 1)
  74. nn.init.constant_(m.bias, 0)
  75. # self.use_dcn is False
  76. elif not self.use_dcn and isinstance(m, nn.Conv2d):
  77. # In order to be consistent with the source code,
  78. # reset the Conv2d initialization parameters
  79. m.reset_parameters()
  80. @auto_fp16()
  81. def forward(self, inputs):
  82. assert isinstance(inputs, (list, tuple))
  83. outs = self.deconv_layers(inputs[-1])
  84. return outs,

No Description

Contributors (2)