You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_roi_align.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. import pytest
  16. from mindspore import Tensor, context
  17. from mindspore.common.api import _cell_graph_executor
  18. from mindspore.nn import Cell
  19. from mindspore.ops import operations as P
  20. from parallel.utils.utils import ParallelValidator
  21. POOLED_HEIGHT = 2
  22. POOLED_WIDTH = 2
  23. SPATIAL_SCALE = 0.5
  24. BATCH_SIZE = 32
  25. FEATURES_HEIGHT = 256
  26. FEATURES_WIDTH = 256
  27. CHANNELS = 3
  28. NUM_ROIS = 16
  29. _features = Tensor(np.random.normal(size=[BATCH_SIZE, CHANNELS, FEATURES_HEIGHT, FEATURES_WIDTH]).astype(np.float32))
  30. _rois = Tensor(
  31. np.hstack((np.random.randint(0, BATCH_SIZE, [NUM_ROIS, 1]).astype(np.float32),
  32. np.random.uniform(low=0, high=FEATURES_HEIGHT / SPATIAL_SCALE, size=[NUM_ROIS, 4]).astype(np.float32))))
  33. class Net(Cell):
  34. def __init__(self, pooled_h, pooled_w, spatial_scale, strategy=None):
  35. super(Net, self).__init__()
  36. self.roi_align = P.ROIAlign(pooled_h, pooled_w, spatial_scale).shard(strategy)
  37. def construct(self, features, rois):
  38. output = self.roi_align(features, rois)
  39. return output
  40. def compile_net(net: Cell, *inputs):
  41. net.set_auto_parallel()
  42. net.set_train()
  43. phase, _ = _cell_graph_executor.compile(net, *inputs, auto_parallel_mode=True)
  44. context.reset_auto_parallel_context()
  45. return phase
  46. def test_roi_align_auto_parallel():
  47. """
  48. Feature: test ROIAlign auto parallel
  49. Description: auto parallel
  50. Expectation: compile success
  51. """
  52. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  53. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE)
  54. compile_net(net, _features, _rois)
  55. def test_roi_align_data_parallel():
  56. """
  57. Feature: test ROIAlign data parallel
  58. Description: data parallel
  59. Expectation: compile success
  60. """
  61. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  62. strategy = ((4, 1, 1, 1), (2, 1))
  63. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE, strategy)
  64. compile_net(net, _features, _rois)
  65. def test_roi_align_strategy_error():
  66. """
  67. Feature: test invalid strategy for ROIAlign
  68. Description: illegal strategy
  69. Expectation: raise RuntimeError
  70. """
  71. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  72. strategy = ((2, 1, 2, 2), (1, 1))
  73. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE, strategy)
  74. with pytest.raises(RuntimeError):
  75. compile_net(net, _features, _rois)
  76. context.reset_auto_parallel_context()
  77. def test_roi_align_layout():
  78. """
  79. Features: ROIAlignInfo
  80. Description: validate layout and structure
  81. Expectation: No raise RuntimeError
  82. """
  83. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  84. strategy = ((4, 1, 1, 1), (2, 1))
  85. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE, strategy)
  86. phase = compile_net(net, _features, _rois)
  87. validator = ParallelValidator(net, phase)
  88. # check layout
  89. features_expect_layout = ([8], [0, -1, -1, -1], [4, 3, 256, 256], 0, True, '')
  90. assert validator.check_parameter_layout('features', features_expect_layout)
  91. # check attrs
  92. roi_expect_attrs = {'pooled_height': POOLED_HEIGHT, 'pooled_width': POOLED_WIDTH, 'spatial_scale': SPATIAL_SCALE}
  93. assert validator.check_node_attrs('ROIAlign-0', roi_expect_attrs)
  94. # check inputs
  95. roi_expect_inputs = ['Reshape-1', 'TensorScatterUpdate-0']
  96. assert validator.check_node_inputs('ROIAlign-0', roi_expect_inputs)
  97. # check sub_graph
  98. sub_graph = {
  99. 'TensorScatterUpdate-0': ['Reshape-3', 'Stack-0', 'Minimum-0'],
  100. 'Equal-0': ['Sub-0', 'Minimum-0'],
  101. 'ROIAlign-0': ['Reshape-1', 'TensorScatterUpdate-0'],
  102. 'Mul-0': ['ROIAlign-0', 'ExpandDims-2'],
  103. 'AllReduce-0': ['Mul-0']
  104. }
  105. assert validator.check_graph_structure(sub_graph)