
res_layer.py
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential
from torch import nn as nn


class ResLayer(Sequential):
    """ResLayer to build ResNet style backbone.

    Args:
        block (nn.Module): block used to build ResLayer.
        inplanes (int): inplanes of block.
        planes (int): planes of block.
        num_blocks (int): number of blocks.
        stride (int): stride of the first block. Default: 1
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
    """

    def __init__(self,
                 block,
                 inplanes,
                 planes,
                 num_blocks,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 **kwargs):
        self.block = block

        # Build a projection shortcut when the first block changes the
        # spatial resolution or the number of channels.
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = []
            conv_stride = stride
            if avg_down:
                # Move the downsampling into an AvgPool so the 1x1 conv
                # keeps stride 1.
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_norm_layer(norm_cfg, planes * block.expansion)[1]
            ])
            downsample = nn.Sequential(*downsample)

        layers = []
        if downsample_first:
            # ResNet style: the first block downsamples, the remaining
            # blocks keep resolution and channels.
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
            inplanes = planes * block.expansion
            for _ in range(1, num_blocks):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=planes,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))
        else:  # downsample_first=False is for HourglassModule
            for _ in range(num_blocks - 1):
                layers.append(
                    block(
                        inplanes=inplanes,
                        planes=inplanes,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        **kwargs))
            layers.append(
                block(
                    inplanes=inplanes,
                    planes=planes,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    **kwargs))
        super(ResLayer, self).__init__(*layers)
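
# Usage sketch (illustrative, not part of the original file): ResLayer chains
# ``num_blocks`` instances of ``block`` into one nn.Sequential. Assuming a
# ResNet bottleneck block such as mmdet's ``Bottleneck`` (expansion = 4), a
# standard "layer2" could be built roughly as:
#
#   res_layer = ResLayer(
#       Bottleneck, inplanes=256, planes=128, num_blocks=4, stride=2)
#
# The first block then downsamples 256 -> 128 * 4 channels with stride 2 via
# the 1x1-conv + norm shortcut built in ``__init__``; the remaining three
# blocks keep 512 channels at stride 1.
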
class SimplifiedBasicBlock(BaseModule):
    """Simplified version of original basic residual block. This is used in
    `SCNet <https://arxiv.org/abs/2012.10150>`_.

    - Norm layer is now optional
    - Last ReLU in forward function is removed
    """

    expansion = 1

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 dcn=None,
                 plugins=None,
                 init_cfg=None):
        super(SimplifiedBasicBlock, self).__init__(init_cfg)
        assert dcn is None, 'Not implemented yet.'
        assert plugins is None, 'Not implemented yet.'
        assert not with_cp, 'Not implemented yet.'
        self.with_norm = norm_cfg is not None
        # Use a conv bias only when no norm layer follows the conv.
        with_bias = True if norm_cfg is None else False
        self.conv1 = build_conv_layer(
            conv_cfg,
            inplanes,
            planes,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            bias=with_bias)
        if self.with_norm:
            self.norm1_name, norm1 = build_norm_layer(
                norm_cfg, planes, postfix=1)
            self.add_module(self.norm1_name, norm1)
        self.conv2 = build_conv_layer(
            conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
        if self.with_norm:
            self.norm2_name, norm2 = build_norm_layer(
                norm_cfg, planes, postfix=2)
            self.add_module(self.norm2_name, norm2)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation
        self.with_cp = with_cp

    @property
    def norm1(self):
        """nn.Module: normalization layer after the first convolution layer"""
        return getattr(self, self.norm1_name) if self.with_norm else None

    @property
    def norm2(self):
        """nn.Module: normalization layer after the second convolution layer"""
        return getattr(self, self.norm2_name) if self.with_norm else None

    def forward(self, x):
        """Forward function."""
        identity = x

        out = self.conv1(x)
        if self.with_norm:
            out = self.norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        if self.with_norm:
            out = self.norm2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # No final ReLU here, by design (see class docstring).
        out += identity

        return out
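
The sketch below is an illustrative usage example, not part of res_layer.py: it assumes this module is importable (an mmcv version that still provides mmcv.runner, plus torch), and it stacks SimplifiedBasicBlock with ResLayer, which works because ResLayer only relies on the block(inplanes=..., planes=..., stride=..., downsample=..., ...) interface and the class-level expansion attribute.

    import torch

    # Hypothetical example: three SimplifiedBasicBlocks, the first with
    # stride 2, so ResLayer also builds a 1x1-conv + BN shortcut for the
    # identity path.
    layer = ResLayer(
        SimplifiedBasicBlock,
        inplanes=64,
        planes=128,
        num_blocks=3,
        stride=2,
        norm_cfg=dict(type='BN'))

    x = torch.randn(2, 64, 56, 56)
    out = layer(x)  # (2, 128, 28, 28), since SimplifiedBasicBlock.expansion == 1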