You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_neck.py 4.8 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import mmcv
  4. import pytest
  5. import torch
  6. from mmdet import digit_version
  7. from mmdet.models.necks import FPN, YOLOV3Neck
  8. from .utils import ort_validate
  9. if digit_version(torch.__version__) <= digit_version('1.5.0'):
  10. pytest.skip(
  11. 'ort backend does not support version below 1.5.0',
  12. allow_module_level=True)
  13. # Control the returned model of fpn_neck_config()
  14. fpn_test_step_names = {
  15. 'fpn_normal': 0,
  16. 'fpn_wo_extra_convs': 1,
  17. 'fpn_lateral_bns': 2,
  18. 'fpn_bilinear_upsample': 3,
  19. 'fpn_scale_factor': 4,
  20. 'fpn_extra_convs_inputs': 5,
  21. 'fpn_extra_convs_laterals': 6,
  22. 'fpn_extra_convs_outputs': 7,
  23. }
  24. # Control the returned model of yolo_neck_config()
  25. yolo_test_step_names = {'yolo_normal': 0}
  26. data_path = osp.join(osp.dirname(__file__), 'data')
  27. def fpn_neck_config(test_step_name):
  28. """Return the class containing the corresponding attributes according to
  29. the fpn_test_step_names."""
  30. s = 64
  31. in_channels = [8, 16, 32, 64]
  32. feat_sizes = [s // 2**i for i in range(4)] # [64, 32, 16, 8]
  33. out_channels = 8
  34. feats = [
  35. torch.rand(1, in_channels[i], feat_sizes[i], feat_sizes[i])
  36. for i in range(len(in_channels))
  37. ]
  38. if (fpn_test_step_names[test_step_name] == 0):
  39. fpn_model = FPN(
  40. in_channels=in_channels,
  41. out_channels=out_channels,
  42. add_extra_convs=True,
  43. num_outs=5)
  44. elif (fpn_test_step_names[test_step_name] == 1):
  45. fpn_model = FPN(
  46. in_channels=in_channels,
  47. out_channels=out_channels,
  48. add_extra_convs=False,
  49. num_outs=5)
  50. elif (fpn_test_step_names[test_step_name] == 2):
  51. fpn_model = FPN(
  52. in_channels=in_channels,
  53. out_channels=out_channels,
  54. add_extra_convs=True,
  55. no_norm_on_lateral=False,
  56. norm_cfg=dict(type='BN', requires_grad=True),
  57. num_outs=5)
  58. elif (fpn_test_step_names[test_step_name] == 3):
  59. fpn_model = FPN(
  60. in_channels=in_channels,
  61. out_channels=out_channels,
  62. add_extra_convs=True,
  63. upsample_cfg=dict(mode='bilinear', align_corners=True),
  64. num_outs=5)
  65. elif (fpn_test_step_names[test_step_name] == 4):
  66. fpn_model = FPN(
  67. in_channels=in_channels,
  68. out_channels=out_channels,
  69. add_extra_convs=True,
  70. upsample_cfg=dict(scale_factor=2),
  71. num_outs=5)
  72. elif (fpn_test_step_names[test_step_name] == 5):
  73. fpn_model = FPN(
  74. in_channels=in_channels,
  75. out_channels=out_channels,
  76. add_extra_convs='on_input',
  77. num_outs=5)
  78. elif (fpn_test_step_names[test_step_name] == 6):
  79. fpn_model = FPN(
  80. in_channels=in_channels,
  81. out_channels=out_channels,
  82. add_extra_convs='on_lateral',
  83. num_outs=5)
  84. elif (fpn_test_step_names[test_step_name] == 7):
  85. fpn_model = FPN(
  86. in_channels=in_channels,
  87. out_channels=out_channels,
  88. add_extra_convs='on_output',
  89. num_outs=5)
  90. return fpn_model, feats
  91. def yolo_neck_config(test_step_name):
  92. """Config yolov3 Neck."""
  93. in_channels = [16, 8, 4]
  94. out_channels = [8, 4, 2]
  95. # The data of yolov3_neck.pkl contains a list of
  96. # torch.Tensor, where each torch.Tensor is generated by
  97. # torch.rand and each tensor size is:
  98. # (1, 4, 64, 64), (1, 8, 32, 32), (1, 16, 16, 16).
  99. yolov3_neck_data = 'yolov3_neck.pkl'
  100. feats = mmcv.load(osp.join(data_path, yolov3_neck_data))
  101. if (yolo_test_step_names[test_step_name] == 0):
  102. yolo_model = YOLOV3Neck(
  103. in_channels=in_channels, out_channels=out_channels, num_scales=3)
  104. return yolo_model, feats
  105. def test_fpn_normal():
  106. outs = fpn_neck_config('fpn_normal')
  107. ort_validate(*outs)
  108. def test_fpn_wo_extra_convs():
  109. outs = fpn_neck_config('fpn_wo_extra_convs')
  110. ort_validate(*outs)
  111. def test_fpn_lateral_bns():
  112. outs = fpn_neck_config('fpn_lateral_bns')
  113. ort_validate(*outs)
  114. def test_fpn_bilinear_upsample():
  115. outs = fpn_neck_config('fpn_bilinear_upsample')
  116. ort_validate(*outs)
  117. def test_fpn_scale_factor():
  118. outs = fpn_neck_config('fpn_scale_factor')
  119. ort_validate(*outs)
  120. def test_fpn_extra_convs_inputs():
  121. outs = fpn_neck_config('fpn_extra_convs_inputs')
  122. ort_validate(*outs)
  123. def test_fpn_extra_convs_laterals():
  124. outs = fpn_neck_config('fpn_extra_convs_laterals')
  125. ort_validate(*outs)
  126. def test_fpn_extra_convs_outputs():
  127. outs = fpn_neck_config('fpn_extra_convs_outputs')
  128. ort_validate(*outs)
  129. def test_yolo_normal():
  130. outs = yolo_neck_config('yolo_normal')
  131. ort_validate(*outs)

No Description

Contributors (2)