You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_roi_align.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. from mindspore import Tensor, context
  17. from mindspore.common.api import _cell_graph_executor
  18. from mindspore.nn import Cell
  19. from mindspore.ops import operations as P
  20. POOLED_HEIGHT = 2
  21. POOLED_WIDTH = 2
  22. SPATIAL_SCALE = 0.5
  23. BATCH_SIZE = 32
  24. FEATURES_HEIGHT = 256
  25. FEATURES_WIDTH = 256
  26. CHANNELS = 3
  27. NUM_ROIS = 16
  28. _features = Tensor(np.random.normal(size=[BATCH_SIZE, CHANNELS, FEATURES_HEIGHT, FEATURES_WIDTH]).astype(np.float32))
  29. _rois = Tensor(
  30. np.hstack((np.random.randint(0, BATCH_SIZE, [NUM_ROIS, 1]).astype(np.float32),
  31. np.random.uniform(low=0, high=FEATURES_HEIGHT / SPATIAL_SCALE, size=[NUM_ROIS, 4]).astype(np.float32))))
  32. class Net(Cell):
  33. def __init__(self, pooled_h, pooled_w, spatial_scale, strategy=None):
  34. super(Net, self).__init__()
  35. self.roi_align = P.ROIAlign(pooled_h, pooled_w, spatial_scale).shard(strategy)
  36. def construct(self, features, rois):
  37. output = self.roi_align(features, rois)
  38. return output
  39. def compile_net(net: Cell, *inputs):
  40. net.set_auto_parallel()
  41. net.set_train()
  42. _cell_graph_executor.compile(net, *inputs)
  43. context.reset_auto_parallel_context()
  44. def test_roi_align_auto_parallel():
  45. """
  46. Feature: test ROIAlign auto parallel
  47. Description: auto parallel
  48. Expectation: compile success
  49. """
  50. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  51. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE)
  52. compile_net(net, _features, _rois)
  53. def test_roi_align_data_parallel():
  54. """
  55. Feature: test ROIAlign data parallel
  56. Description: data parallel
  57. Expectation: compile success
  58. """
  59. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  60. strategy = ((4, 1, 1, 1), (2, 1))
  61. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE, strategy)
  62. compile_net(net, _features, _rois)
  63. def test_roi_align_strategy_error():
  64. """
  65. Feature: test invalid strategy for ROIAlign
  66. Description: illegal strategy
  67. Expectation: raise RuntimeError
  68. """
  69. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  70. strategy = ((2, 1, 2, 2), (1, 1))
  71. net = Net(POOLED_HEIGHT, POOLED_WIDTH, SPATIAL_SCALE, strategy)
  72. with pytest.raises(RuntimeError):
  73. compile_net(net, _features, _rois)
  74. context.reset_auto_parallel_context()