You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

retina_sepbn_head.py 4.6 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
  4. from ..builder import HEADS
  5. from .anchor_head import AnchorHead
  6. @HEADS.register_module()
  7. class RetinaSepBNHead(AnchorHead):
  8. """"RetinaHead with separate BN.
  9. In RetinaHead, conv/norm layers are shared across different FPN levels,
  10. while in RetinaSepBNHead, conv layers are shared across different FPN
  11. levels, but BN layers are separated.
  12. """
  13. def __init__(self,
  14. num_classes,
  15. num_ins,
  16. in_channels,
  17. stacked_convs=4,
  18. conv_cfg=None,
  19. norm_cfg=None,
  20. init_cfg=None,
  21. **kwargs):
  22. assert init_cfg is None, 'To prevent abnormal initialization ' \
  23. 'behavior, init_cfg is not allowed to be set'
  24. self.stacked_convs = stacked_convs
  25. self.conv_cfg = conv_cfg
  26. self.norm_cfg = norm_cfg
  27. self.num_ins = num_ins
  28. super(RetinaSepBNHead, self).__init__(
  29. num_classes, in_channels, init_cfg=init_cfg, **kwargs)
  30. def _init_layers(self):
  31. """Initialize layers of the head."""
  32. self.relu = nn.ReLU(inplace=True)
  33. self.cls_convs = nn.ModuleList()
  34. self.reg_convs = nn.ModuleList()
  35. for i in range(self.num_ins):
  36. cls_convs = nn.ModuleList()
  37. reg_convs = nn.ModuleList()
  38. for i in range(self.stacked_convs):
  39. chn = self.in_channels if i == 0 else self.feat_channels
  40. cls_convs.append(
  41. ConvModule(
  42. chn,
  43. self.feat_channels,
  44. 3,
  45. stride=1,
  46. padding=1,
  47. conv_cfg=self.conv_cfg,
  48. norm_cfg=self.norm_cfg))
  49. reg_convs.append(
  50. ConvModule(
  51. chn,
  52. self.feat_channels,
  53. 3,
  54. stride=1,
  55. padding=1,
  56. conv_cfg=self.conv_cfg,
  57. norm_cfg=self.norm_cfg))
  58. self.cls_convs.append(cls_convs)
  59. self.reg_convs.append(reg_convs)
  60. for i in range(self.stacked_convs):
  61. for j in range(1, self.num_ins):
  62. self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
  63. self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
  64. self.retina_cls = nn.Conv2d(
  65. self.feat_channels,
  66. self.num_base_priors * self.cls_out_channels,
  67. 3,
  68. padding=1)
  69. self.retina_reg = nn.Conv2d(
  70. self.feat_channels, self.num_base_priors * 4, 3, padding=1)
  71. def init_weights(self):
  72. """Initialize weights of the head."""
  73. super(RetinaSepBNHead, self).init_weights()
  74. for m in self.cls_convs[0]:
  75. normal_init(m.conv, std=0.01)
  76. for m in self.reg_convs[0]:
  77. normal_init(m.conv, std=0.01)
  78. bias_cls = bias_init_with_prob(0.01)
  79. normal_init(self.retina_cls, std=0.01, bias=bias_cls)
  80. normal_init(self.retina_reg, std=0.01)
  81. def forward(self, feats):
  82. """Forward features from the upstream network.
  83. Args:
  84. feats (tuple[Tensor]): Features from the upstream network, each is
  85. a 4D-tensor.
  86. Returns:
  87. tuple: Usually a tuple of classification scores and bbox prediction
  88. cls_scores (list[Tensor]): Classification scores for all scale
  89. levels, each is a 4D-tensor, the channels number is
  90. num_anchors * num_classes.
  91. bbox_preds (list[Tensor]): Box energies / deltas for all scale
  92. levels, each is a 4D-tensor, the channels number is
  93. num_anchors * 4.
  94. """
  95. cls_scores = []
  96. bbox_preds = []
  97. for i, x in enumerate(feats):
  98. cls_feat = feats[i]
  99. reg_feat = feats[i]
  100. for cls_conv in self.cls_convs[i]:
  101. cls_feat = cls_conv(cls_feat)
  102. for reg_conv in self.reg_convs[i]:
  103. reg_feat = reg_conv(reg_feat)
  104. cls_score = self.retina_cls(cls_feat)
  105. bbox_pred = self.retina_reg(reg_feat)
  106. cls_scores.append(cls_score)
  107. bbox_preds.append(bbox_pred)
  108. return cls_scores, bbox_preds

No Description

Contributors (3)