You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_plugins.py 840 B

2 years ago
1234567891011121314151617181920212223242526272829
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import pytest
  3. import torch
  4. from mmdet.models.plugins import DropBlock
  def test_dropblock():
      """Exercise DropBlock's forward pass and its constructor validation.

      Checks two configurations on a random 1x1x11x11 tensor:

      * a block covering the whole 11x11 map with drop_prob=1.0 and
        warmup_iters=0 must produce an all-zero output of the same shape;
      * drop_prob=0.5 with block_size=5 only needs to preserve the shape.

      Then checks that invalid constructor arguments raise AssertionError.
      """
      feat = torch.rand(1, 1, 11, 11)

      # block_size equals the spatial size, so a certain drop (1.0) is
      # expected to wipe out every element.
      full_drop = DropBlock(1.0, block_size=11, warmup_iters=0)
      dropped = full_drop(feat)
      assert (dropped == 0).all()
      assert dropped.shape == feat.shape

      # With a partial drop the exact mask is random; only the shape is stable.
      partial_drop = DropBlock(0.5, block_size=5, warmup_iters=0)
      assert partial_drop(feat).shape == feat.shape

      # Each tuple is passed positionally as (drop_prob, block_size[, third]).
      invalid_args = [
          (1.5, 3),      # drop_prob must be (0,1]
          (0.5, 2),      # block_size cannot be an even number
          (0.5, 3, -1),  # warmup_iters cannot be less than 0
      ]
      for args in invalid_args:
          with pytest.raises(AssertionError):
              DropBlock(*args)

No Description

Contributors (2)