
test_se_layer.py 674 B

# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmdet.models.utils import SELayer


def test_se_layer():
    with pytest.raises(AssertionError):
        # act_cfg sequence length must equal 2
        SELayer(channels=32, act_cfg=(dict(type='ReLU'), ))

    with pytest.raises(AssertionError):
        # act_cfg sequence must be a tuple of dicts
        SELayer(channels=32, act_cfg=[dict(type='ReLU'), dict(type='ReLU')])

    # Test SELayer forward
    layer = SELayer(channels=32)
    layer.init_weights()
    layer.train()

    x = torch.randn((1, 32, 10, 10))
    x_out = layer(x)
    assert x_out.shape == torch.Size((1, 32, 10, 10))
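The test pins down only SELayer's observable contract: `act_cfg` must be a tuple of exactly two dicts, and the forward pass returns a tensor with the same shape as its input. To illustrate why the shape is preserved, here is a rough, dependency-free sketch of a squeeze-and-excitation block. It assumes the usual design (global average pooling, a channel-reducing layer with ReLU, a channel-expanding layer with Sigmoid, then per-channel rescaling) and a reduction ratio of 16. `normalize_act_cfg` and `se_forward` are hypothetical names; this is not mmdet's implementation, which works on torch tensors with 1x1 convolutions.

```python
import math
import random


def normalize_act_cfg(act_cfg):
    """Hypothetical helper mirroring the act_cfg checks the test exercises.

    Assumption: a single dict is reused for both activations; otherwise
    act_cfg must be a tuple of exactly two dicts (a list is rejected).
    """
    if isinstance(act_cfg, dict):
        act_cfg = (act_cfg, act_cfg)
    assert isinstance(act_cfg, tuple), 'act_cfg must be a tuple, not a list'
    assert len(act_cfg) == 2, 'act_cfg must contain exactly 2 configs'
    assert all(isinstance(cfg, dict) for cfg in act_cfg), 'configs must be dicts'
    return act_cfg


def se_forward(x, w1, b1, w2, b2):
    """Squeeze-and-excitation on one sample x given as [C][H][W] nested lists."""
    # Squeeze: global average pool each channel down to one scalar.
    s = [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]
    # Excitation, step 1: reduce C -> C // ratio, then ReLU.
    h = [max(0.0, sum(w * v for w, v in zip(row, s)) + b)
         for row, b in zip(w1, b1)]
    # Excitation, step 2: expand back to C, then Sigmoid gives gates in (0, 1).
    g = [1.0 / (1.0 + math.exp(-(sum(w * v for w, v in zip(row, h)) + b)))
         for row, b in zip(w2, b2)]
    # Scale: multiply every position of channel c by its gate g[c].
    return [[[v * g[c] for v in row] for row in ch] for c, ch in enumerate(x)]


random.seed(0)
channels, height, width, ratio = 32, 10, 10, 16  # ratio=16 is an assumption
hidden = channels // ratio
x = [[[random.gauss(0.0, 1.0) for _ in range(width)] for _ in range(height)]
     for _ in range(channels)]
w1 = [[random.gauss(0.0, 0.1) for _ in range(channels)] for _ in range(hidden)]
w2 = [[random.gauss(0.0, 0.1) for _ in range(hidden)] for _ in range(channels)]
out = se_forward(x, w1, [0.0] * hidden, w2, [0.0] * channels)
print(len(out), len(out[0]), len(out[0][0]))  # 32 10 10

# The two act_cfg values the test expects to fail are rejected here too.
for bad in [(dict(type='ReLU'),), [dict(type='ReLU'), dict(type='ReLU')]]:
    try:
        normalize_act_cfg(bad)
    except AssertionError as err:
        print('rejected:', err)
```

Because the block only reweights channels with gates in (0, 1), every output value is its input scaled per channel, and the (C, H, W) shape never changes, which is what the final assertion in the test relies on.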
