You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hourglass.py 7.5 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from mmcv.cnn import ConvModule
  5. from mmcv.runner import BaseModule
  6. from ..builder import BACKBONES
  7. from ..utils import ResLayer
  8. from .resnet import BasicBlock
  9. class HourglassModule(BaseModule):
  10. """Hourglass Module for HourglassNet backbone.
  11. Generate module recursively and use BasicBlock as the base unit.
  12. Args:
  13. depth (int): Depth of current HourglassModule.
  14. stage_channels (list[int]): Feature channels of sub-modules in current
  15. and follow-up HourglassModule.
  16. stage_blocks (list[int]): Number of sub-modules stacked in current and
  17. follow-up HourglassModule.
  18. norm_cfg (dict): Dictionary to construct and config norm layer.
  19. init_cfg (dict or list[dict], optional): Initialization config dict.
  20. Default: None
  21. upsample_cfg (dict, optional): Config dict for interpolate layer.
  22. Default: `dict(mode='nearest')`
  23. """
  24. def __init__(self,
  25. depth,
  26. stage_channels,
  27. stage_blocks,
  28. norm_cfg=dict(type='BN', requires_grad=True),
  29. init_cfg=None,
  30. upsample_cfg=dict(mode='nearest')):
  31. super(HourglassModule, self).__init__(init_cfg)
  32. self.depth = depth
  33. cur_block = stage_blocks[0]
  34. next_block = stage_blocks[1]
  35. cur_channel = stage_channels[0]
  36. next_channel = stage_channels[1]
  37. self.up1 = ResLayer(
  38. BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)
  39. self.low1 = ResLayer(
  40. BasicBlock,
  41. cur_channel,
  42. next_channel,
  43. cur_block,
  44. stride=2,
  45. norm_cfg=norm_cfg)
  46. if self.depth > 1:
  47. self.low2 = HourglassModule(depth - 1, stage_channels[1:],
  48. stage_blocks[1:])
  49. else:
  50. self.low2 = ResLayer(
  51. BasicBlock,
  52. next_channel,
  53. next_channel,
  54. next_block,
  55. norm_cfg=norm_cfg)
  56. self.low3 = ResLayer(
  57. BasicBlock,
  58. next_channel,
  59. cur_channel,
  60. cur_block,
  61. norm_cfg=norm_cfg,
  62. downsample_first=False)
  63. self.up2 = F.interpolate
  64. self.upsample_cfg = upsample_cfg
  65. def forward(self, x):
  66. """Forward function."""
  67. up1 = self.up1(x)
  68. low1 = self.low1(x)
  69. low2 = self.low2(low1)
  70. low3 = self.low3(low2)
  71. # Fixing `scale factor` (e.g. 2) is common for upsampling, but
  72. # in some cases the spatial size is mismatched and error will arise.
  73. if 'scale_factor' in self.upsample_cfg:
  74. up2 = self.up2(low3, **self.upsample_cfg)
  75. else:
  76. shape = up1.shape[2:]
  77. up2 = self.up2(low3, size=shape, **self.upsample_cfg)
  78. return up1 + up2
  79. @BACKBONES.register_module()
  80. class HourglassNet(BaseModule):
  81. """HourglassNet backbone.
  82. Stacked Hourglass Networks for Human Pose Estimation.
  83. More details can be found in the `paper
  84. <https://arxiv.org/abs/1603.06937>`_ .
  85. Args:
  86. downsample_times (int): Downsample times in a HourglassModule.
  87. num_stacks (int): Number of HourglassModule modules stacked,
  88. 1 for Hourglass-52, 2 for Hourglass-104.
  89. stage_channels (list[int]): Feature channel of each sub-module in a
  90. HourglassModule.
  91. stage_blocks (list[int]): Number of sub-modules stacked in a
  92. HourglassModule.
  93. feat_channel (int): Feature channel of conv after a HourglassModule.
  94. norm_cfg (dict): Dictionary to construct and config norm layer.
  95. pretrained (str, optional): model pretrained path. Default: None
  96. init_cfg (dict or list[dict], optional): Initialization config dict.
  97. Default: None
  98. Example:
  99. >>> from mmdet.models import HourglassNet
  100. >>> import torch
  101. >>> self = HourglassNet()
  102. >>> self.eval()
  103. >>> inputs = torch.rand(1, 3, 511, 511)
  104. >>> level_outputs = self.forward(inputs)
  105. >>> for level_output in level_outputs:
  106. ... print(tuple(level_output.shape))
  107. (1, 256, 128, 128)
  108. (1, 256, 128, 128)
  109. """
  110. def __init__(self,
  111. downsample_times=5,
  112. num_stacks=2,
  113. stage_channels=(256, 256, 384, 384, 384, 512),
  114. stage_blocks=(2, 2, 2, 2, 2, 4),
  115. feat_channel=256,
  116. norm_cfg=dict(type='BN', requires_grad=True),
  117. pretrained=None,
  118. init_cfg=None):
  119. assert init_cfg is None, 'To prevent abnormal initialization ' \
  120. 'behavior, init_cfg is not allowed to be set'
  121. super(HourglassNet, self).__init__(init_cfg)
  122. self.num_stacks = num_stacks
  123. assert self.num_stacks >= 1
  124. assert len(stage_channels) == len(stage_blocks)
  125. assert len(stage_channels) > downsample_times
  126. cur_channel = stage_channels[0]
  127. self.stem = nn.Sequential(
  128. ConvModule(
  129. 3, cur_channel // 2, 7, padding=3, stride=2,
  130. norm_cfg=norm_cfg),
  131. ResLayer(
  132. BasicBlock,
  133. cur_channel // 2,
  134. cur_channel,
  135. 1,
  136. stride=2,
  137. norm_cfg=norm_cfg))
  138. self.hourglass_modules = nn.ModuleList([
  139. HourglassModule(downsample_times, stage_channels, stage_blocks)
  140. for _ in range(num_stacks)
  141. ])
  142. self.inters = ResLayer(
  143. BasicBlock,
  144. cur_channel,
  145. cur_channel,
  146. num_stacks - 1,
  147. norm_cfg=norm_cfg)
  148. self.conv1x1s = nn.ModuleList([
  149. ConvModule(
  150. cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
  151. for _ in range(num_stacks - 1)
  152. ])
  153. self.out_convs = nn.ModuleList([
  154. ConvModule(
  155. cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
  156. for _ in range(num_stacks)
  157. ])
  158. self.remap_convs = nn.ModuleList([
  159. ConvModule(
  160. feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
  161. for _ in range(num_stacks - 1)
  162. ])
  163. self.relu = nn.ReLU(inplace=True)
  164. def init_weights(self):
  165. """Init module weights."""
  166. # Training Centripetal Model needs to reset parameters for Conv2d
  167. super(HourglassNet, self).init_weights()
  168. for m in self.modules():
  169. if isinstance(m, nn.Conv2d):
  170. m.reset_parameters()
  171. def forward(self, x):
  172. """Forward function."""
  173. inter_feat = self.stem(x)
  174. out_feats = []
  175. for ind in range(self.num_stacks):
  176. single_hourglass = self.hourglass_modules[ind]
  177. out_conv = self.out_convs[ind]
  178. hourglass_feat = single_hourglass(inter_feat)
  179. out_feat = out_conv(hourglass_feat)
  180. out_feats.append(out_feat)
  181. if ind < self.num_stacks - 1:
  182. inter_feat = self.conv1x1s[ind](
  183. inter_feat) + self.remap_convs[ind](
  184. out_feat)
  185. inter_feat = self.inters[ind](self.relu(inter_feat))
  186. return out_feats

No Description

Contributors (3)