You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

se_layer.py 2.2 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import mmcv
  3. import torch.nn as nn
  4. from mmcv.cnn import ConvModule
  5. from mmcv.runner import BaseModule
  class SELayer(BaseModule):
      """Squeeze-and-Excitation Module.

      The input is globally average-pooled to a single value per channel.
      That result passes through two 1x1 ``ConvModule`` layers: the first
      reduces the channels to ``int(channels / ratio)`` and the second
      expands them back to ``channels``. The input is then multiplied by
      the resulting per-channel weights.

      Args:
          channels (int): The input (and output) channels of the SE layer.
          ratio (int): Squeeze ratio in SELayer. The intermediate channel
              count is ``int(channels / ratio)``. Default: 16.
          conv_cfg (None or dict): Config dict for the convolution layers.
              Default: None, which means using conv2d.
          act_cfg (dict or Sequence[dict]): Config dict for the activation
              layers. If act_cfg is a dict, both activation layers are
              configured by this dict. If act_cfg is a sequence of two dicts
              (a tuple or a list), the first activation layer is configured
              by the first dict and the second activation layer by the
              second dict.
              Default: (dict(type='ReLU'), dict(type='Sigmoid'))
          init_cfg (dict or list[dict], optional): Initialization config dict.
              Default: None
      """

      def __init__(self,
                   channels,
                   ratio=16,
                   conv_cfg=None,
                   act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
                   init_cfg=None):
          super().__init__(init_cfg)
          if isinstance(act_cfg, dict):
              # A single dict configures both activation layers.
              act_cfg = (act_cfg, act_cfg)
          # Normalise any sequence (e.g. a list read from a JSON/YAML config)
          # to a tuple. ``mmcv.is_tuple_of`` rejects lists, so without this
          # the documented ``Sequence[dict]`` form would fail validation.
          act_cfg = tuple(act_cfg)
          assert len(act_cfg) == 2
          assert mmcv.is_tuple_of(act_cfg, dict)

          # Bottleneck width shared by both 1x1 convolutions.
          # NOTE(review): this is 0 when channels < ratio, which yields a
          # degenerate squeeze layer. Kept as-is for checkpoint compatibility.
          mid_channels = int(channels / ratio)

          # Pools every spatial map down to 1x1 (one value per channel).
          self.global_avgpool = nn.AdaptiveAvgPool2d(1)
          # Squeeze: channels -> mid_channels, followed by act_cfg[0].
          self.conv1 = ConvModule(
              in_channels=channels,
              out_channels=mid_channels,
              kernel_size=1,
              stride=1,
              conv_cfg=conv_cfg,
              act_cfg=act_cfg[0])
          # Excite: mid_channels -> channels, followed by act_cfg[1]
          # (Sigmoid by default, giving per-channel weights in (0, 1)).
          self.conv2 = ConvModule(
              in_channels=mid_channels,
              out_channels=channels,
              kernel_size=1,
              stride=1,
              conv_cfg=conv_cfg,
              act_cfg=act_cfg[1])

      def forward(self, x):
          """Rescale ``x`` channel-wise by the learned excitation weights.

          Args:
              x (Tensor): Input feature map, presumably of shape
                  (N, channels, H, W). TODO: confirm against callers.

          Returns:
              Tensor: ``x`` multiplied by per-channel weights of shape
              (N, channels, 1, 1), broadcast over the spatial dimensions.
          """
          weights = self.global_avgpool(x)
          weights = self.conv1(weights)
          weights = self.conv2(weights)
          return x * weights

No Description

Contributors (2)