You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

channel_mapper.py 4.0 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import torch.nn as nn
  3. from mmcv.cnn import ConvModule
  4. from mmcv.runner import BaseModule
  5. from ..builder import NECKS
  6. @NECKS.register_module()
  7. class ChannelMapper(BaseModule):
  8. r"""Channel Mapper to reduce/increase channels of backbone features.
  9. This is used to reduce/increase channels of backbone features.
  10. Args:
  11. in_channels (List[int]): Number of input channels per scale.
  12. out_channels (int): Number of output channels (used at each scale).
  13. kernel_size (int, optional): kernel_size for reducing channels (used
  14. at each scale). Default: 3.
  15. conv_cfg (dict, optional): Config dict for convolution layer.
  16. Default: None.
  17. norm_cfg (dict, optional): Config dict for normalization layer.
  18. Default: None.
  19. act_cfg (dict, optional): Config dict for activation layer in
  20. ConvModule. Default: dict(type='ReLU').
  21. num_outs (int, optional): Number of output feature maps. There
  22. would be extra_convs when num_outs larger than the length
  23. of in_channels.
  24. init_cfg (dict or list[dict], optional): Initialization config dict.
  25. Example:
  26. >>> import torch
  27. >>> in_channels = [2, 3, 5, 7]
  28. >>> scales = [340, 170, 84, 43]
  29. >>> inputs = [torch.rand(1, c, s, s)
  30. ... for c, s in zip(in_channels, scales)]
  31. >>> self = ChannelMapper(in_channels, 11, 3).eval()
  32. >>> outputs = self.forward(inputs)
  33. >>> for i in range(len(outputs)):
  34. ... print(f'outputs[{i}].shape = {outputs[i].shape}')
  35. outputs[0].shape = torch.Size([1, 11, 340, 340])
  36. outputs[1].shape = torch.Size([1, 11, 170, 170])
  37. outputs[2].shape = torch.Size([1, 11, 84, 84])
  38. outputs[3].shape = torch.Size([1, 11, 43, 43])
  39. """
  40. def __init__(self,
  41. in_channels,
  42. out_channels,
  43. kernel_size=3,
  44. conv_cfg=None,
  45. norm_cfg=None,
  46. act_cfg=dict(type='ReLU'),
  47. num_outs=None,
  48. init_cfg=dict(
  49. type='Xavier', layer='Conv2d', distribution='uniform')):
  50. super(ChannelMapper, self).__init__(init_cfg)
  51. assert isinstance(in_channels, list)
  52. self.extra_convs = None
  53. if num_outs is None:
  54. num_outs = len(in_channels)
  55. self.convs = nn.ModuleList()
  56. for in_channel in in_channels:
  57. self.convs.append(
  58. ConvModule(
  59. in_channel,
  60. out_channels,
  61. kernel_size,
  62. padding=(kernel_size - 1) // 2,
  63. conv_cfg=conv_cfg,
  64. norm_cfg=norm_cfg,
  65. act_cfg=act_cfg))
  66. if num_outs > len(in_channels):
  67. self.extra_convs = nn.ModuleList()
  68. for i in range(len(in_channels), num_outs):
  69. if i == len(in_channels):
  70. in_channel = in_channels[-1]
  71. else:
  72. in_channel = out_channels
  73. self.extra_convs.append(
  74. ConvModule(
  75. in_channel,
  76. out_channels,
  77. 3,
  78. stride=2,
  79. padding=1,
  80. conv_cfg=conv_cfg,
  81. norm_cfg=norm_cfg,
  82. act_cfg=act_cfg))
  83. def forward(self, inputs):
  84. """Forward function."""
  85. assert len(inputs) == len(self.convs)
  86. outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
  87. if self.extra_convs:
  88. for i in range(len(self.extra_convs)):
  89. if i == 0:
  90. outs.append(self.extra_convs[0](inputs[-1]))
  91. else:
  92. outs.append(self.extra_convs[i](outs[-1]))
  93. return tuple(outs)

No Description

Contributors (1)