You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_maxpool_avgpool.py 8.6 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride).shard(strategy1)
  27. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  28. self.max_pool = P.MaxPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.max_pool(out)
  32. return out
  33. class Net2(Cell):
  34. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
  35. strategy1=None, strategy2=None):
  36. super().__init__()
  37. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  38. pad_mode=pad_mode, stride=stride).shard(strategy1)
  39. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  40. self.avg_pool = P.AvgPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)
  41. def construct(self, x, b):
  42. out = self.conv2d(x, self.conv2d_weight)
  43. out = self.avg_pool(out)
  44. return out
  45. _x0 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  46. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  47. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  48. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  49. def compile_net(net, inputs=_x):
  50. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  51. train_net = TrainOneStepCell(net, optimizer)
  52. train_net.set_auto_parallel()
  53. train_net.set_train()
  54. _cell_graph_executor.compile(train_net, inputs, _b)
  55. context.reset_auto_parallel_context()
  56. def test_maxpool_data_parallel():
  57. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  58. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  59. strategy2 = ((8, 1, 1, 1),)
  60. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  61. strategy1=strategy1, strategy2=strategy2)
  62. compile_net(net)
  63. def test_maxpool_model_parallel1():
  64. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  65. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  66. strategy2 = ((2, 1, 2, 2),)
  67. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  68. strategy1=strategy1, strategy2=strategy2)
  69. compile_net(net)
  70. def test_maxpool_model_parallel2():
  71. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  72. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  73. strategy2 = ((2, 1, 2, 2),)
  74. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
  75. strategy1=strategy1, strategy2=strategy2)
  76. compile_net(net)
  77. def test_maxpool_auto_parallel():
  78. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  79. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
  80. compile_net(net)
  81. def test_maxpool_output_is_not_divisible_by_strategy_w_dimension():
  82. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  83. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  84. strategy2 = ((1, 1, 1, 8),)
  85. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  86. strategy1=strategy1, strategy2=strategy2)
  87. with pytest.raises(RuntimeError):
  88. compile_net(net)
  89. def test_maxpool_output_is_not_divisible_by_strategy_h_dimension():
  90. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  91. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  92. strategy2 = ((1, 1, 8, 1),)
  93. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  94. strategy1=strategy1, strategy2=strategy2)
  95. with pytest.raises(RuntimeError):
  96. compile_net(net)
  97. def test_maxpool_shard_h_and_kernel_size_larger_than_stride():
  98. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  99. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  100. strategy2 = ((1, 1, 2, 1),)
  101. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
  102. strategy1=strategy1, strategy2=strategy2)
  103. with pytest.raises(RuntimeError):
  104. compile_net(net)
  105. def test_maxpool_shard_w_and_kernel_size_larger_than_stride():
  106. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  107. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  108. strategy2 = ((1, 1, 1, 2),)
  109. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
  110. strategy1=strategy1, strategy2=strategy2)
  111. with pytest.raises(RuntimeError):
  112. compile_net(net)
  113. def test_maxpool_shard_h_and_input_slice_is_not_divisible_by_stride():
  114. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  115. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  116. strategy2 = ((1, 1, 2, 1),)
  117. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
  118. strategy1=strategy1, strategy2=strategy2)
  119. with pytest.raises(RuntimeError):
  120. compile_net(net, inputs=_x0)
  121. def test_maxpool_shard_w_and_input_slice_is_not_divisible_by_stride():
  122. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  123. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  124. strategy2 = ((1, 1, 2, 1),)
  125. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
  126. strategy1=strategy1, strategy2=strategy2)
  127. with pytest.raises(RuntimeError):
  128. compile_net(net, inputs=_x0)
  129. def test_avgpool_data_parallel():
  130. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  131. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  132. strategy2 = ((8, 1, 1, 1),)
  133. net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  134. strategy1=strategy1, strategy2=strategy2)
  135. compile_net(net)
  136. def test_avgpool_model_parallel1():
  137. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  138. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  139. strategy2 = ((2, 1, 2, 2),)
  140. net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
  141. strategy1=strategy1, strategy2=strategy2)
  142. compile_net(net)
  143. def test_avgpool_model_parallel2():
  144. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  145. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  146. strategy2 = ((2, 1, 2, 2),)
  147. net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
  148. strategy1=strategy1, strategy2=strategy2)
  149. compile_net(net)
  150. def test_avgpool_auto_parallel():
  151. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  152. net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
  153. compile_net(net)