
inverted_residual.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from .se_layer import SELayer


class InvertedResidual(BaseModule):
    """Inverted Residual Block.

    Args:
        in_channels (int): The input channels of this Module.
        out_channels (int): The output channels of this Module.
        mid_channels (int): The input channels of the depthwise convolution.
        kernel_size (int): The kernel size of the depthwise convolution.
            Default: 3.
        stride (int): The stride of the depthwise convolution. Default: 1.
        se_cfg (dict): Config dict for SE layer. Default: None, which means
            no SE layer.
        with_expand_conv (bool): Use expand conv or not. If set to False,
            mid_channels must be the same as in_channels.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        Tensor: The output tensor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 mid_channels,
                 kernel_size=3,
                 stride=1,
                 se_cfg=None,
                 with_expand_conv=True,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 init_cfg=None):
        super(InvertedResidual, self).__init__(init_cfg)
        # The residual shortcut is only valid when the spatial size and
        # channel count are preserved.
        self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
            f'But received {stride}.'
        self.with_cp = with_cp
        self.with_se = se_cfg is not None
        self.with_expand_conv = with_expand_conv

        if self.with_se:
            assert isinstance(se_cfg, dict)
        if not self.with_expand_conv:
            assert mid_channels == in_channels

        if self.with_expand_conv:
            # 1x1 pointwise expansion: in_channels -> mid_channels.
            self.expand_conv = ConvModule(
                in_channels=in_channels,
                out_channels=mid_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        # Depthwise convolution (groups == channels); stride controls
        # optional downsampling.
        self.depthwise_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if self.with_se:
            self.se = SELayer(**se_cfg)
        # 1x1 linear projection back to out_channels (no activation).
        self.linear_conv = ConvModule(
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, x):

        def _inner_forward(x):
            out = x

            if self.with_expand_conv:
                out = self.expand_conv(out)

            out = self.depthwise_conv(out)

            if self.with_se:
                out = self.se(out)

            out = self.linear_conv(out)

            if self.with_res_shortcut:
                return x + out
            else:
                return out

        if self.with_cp and x.requires_grad:
            # Trade compute for memory: recompute activations in backward.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
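For context, this block is the MobileNetV2/V3-style inverted bottleneck: an optional 1x1 expansion, a depthwise k x k convolution, an optional squeeze-and-excitation step, and a 1x1 linear projection, with a residual shortcut only when stride is 1 and in_channels equals out_channels. A minimal usage sketch follows. It assumes mmcv is installed and the module is imported from inside this package (so the relative se_layer import resolves); the channel sizes and input shape are illustrative, not from this repository.

import torch

# Shortcut active: stride == 1 and in_channels == out_channels (32 == 32).
block = InvertedResidual(
    in_channels=32,
    out_channels=32,
    mid_channels=128,  # expansion ratio of 4, chosen for illustration
    kernel_size=3,
    stride=1)

x = torch.randn(1, 32, 56, 56)
print(block(x).shape)  # torch.Size([1, 32, 56, 56])

# With stride=2 (and/or differing channel counts) the shortcut is skipped
# and the spatial resolution is halved.
down = InvertedResidual(
    in_channels=32, out_channels=64, mid_channels=128, stride=2)
print(down(torch.randn(1, 32, 56, 56)).shape)  # torch.Size([1, 64, 28, 28])

Passing with_cp=True routes the forward pass through torch.utils.checkpoint, which, as the docstring notes, recomputes activations during the backward pass to save memory at the cost of training speed; it only takes effect when the input requires gradients.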
