You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_conv2d.py 13 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  36. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  37. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  38. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  39. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  40. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  41. def compile_net(net, input_x=_x):
  42. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  43. train_net = TrainOneStepCell(net, optimizer)
  44. train_net.set_auto_parallel()
  45. train_net.set_train()
  46. _cell_graph_executor.compile(train_net, input_x, _b)
  47. context.reset_auto_parallel_context()
  48. def test_conv2d_data_parallel():
  49. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  50. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  51. strategy2 = ((8, 1, 1, 1),)
  52. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  53. compile_net(net)
  54. def test_conv2d_data_parallel_invalid_stride():
  55. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  56. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  57. strategy2 = ((8, 1, 1, 1),)
  58. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  59. strategy1=strategy1, strategy2=strategy2)
  60. with pytest.raises(RuntimeError):
  61. compile_net(net)
  62. def test_conv2d_data_parallel_dilation():
  63. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  64. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  65. strategy2 = ((8, 1, 1, 1),)
  66. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  67. strategy1=strategy1, strategy2=strategy2)
  68. compile_net(net)
  69. def test_conv2d_data_parallel_group():
  70. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  71. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  72. strategy2 = ((8, 1, 1, 1),)
  73. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  74. strategy1=strategy1, strategy2=strategy2)
  75. compile_net(net)
  76. def test_conv2d_model_parallel1():
  77. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  78. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  79. strategy2 = ((8, 1, 1, 1),)
  80. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  81. compile_net(net)
  82. def test_conv2d_model_parallel_dilation():
  83. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  84. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  85. strategy2 = ((8, 1, 1, 1),)
  86. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  87. strategy1=strategy1, strategy2=strategy2)
  88. with pytest.raises(RuntimeError):
  89. compile_net(net)
  90. def test_conv2d_model_parallel_group():
  91. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  92. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  93. strategy2 = ((8, 1, 1, 1),)
  94. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  95. strategy1=strategy1, strategy2=strategy2)
  96. with pytest.raises(RuntimeError):
  97. compile_net(net)
  98. def test_conv2d_model_parallel2():
  99. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  100. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  101. strategy2 = ((32, 1, 1, 1),)
  102. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  103. compile_net(net)
  104. def test_conv2d_model_parallel3():
  105. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  106. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  107. strategy2 = ((2, 1, 1, 4),)
  108. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  109. compile_net(net)
  110. def test_conv2d_auto_parallel():
  111. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  112. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  113. compile_net(net)
  114. def test_conv2d_model_parallel4():
  115. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  116. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  117. strategy2 = ((2, 2, 1, 4),)
  118. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  119. compile_net(net)
  120. def test_conv2d_left_and_right_no_need_to_send():
  121. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  122. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  123. strategy2 = ((2, 1, 1, 4),)
  124. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  125. with pytest.raises(RuntimeError):
  126. compile_net(net)
  127. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  128. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  129. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  130. strategy2 = ((2, 2, 4, 1),)
  131. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  132. with pytest.raises(RuntimeError):
  133. compile_net(net)
  134. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  135. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  136. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  137. strategy2 = ((2, 1, 1, 4),)
  138. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  139. with pytest.raises(RuntimeError):
  140. compile_net(net)
  141. def test_conv2d_output_can_not_divisible_by_strategy():
  142. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  143. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  144. strategy2 = ((1, 1, 1, 8),)
  145. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  146. with pytest.raises(RuntimeError):
  147. compile_net(net)
  148. def test_conv2d_output_can_not_divisible_by_strategy2():
  149. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  150. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  151. strategy2 = ((1, 1, 1, 8),)
  152. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  153. with pytest.raises(RuntimeError):
  154. compile_net(net)
  155. def test_split_kernel():
  156. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  157. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  158. strategy2 = ((1, 1, 1, 8),)
  159. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  160. with pytest.raises(RuntimeError):
  161. compile_net(net)
  162. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  163. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  164. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  165. strategy2 = ((1, 1, 1, 8),)
  166. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  167. with pytest.raises(RuntimeError):
  168. compile_net(net, _x2)
  169. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  170. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  171. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  172. strategy2 = ((1, 1, 1, 8),)
  173. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  174. with pytest.raises(RuntimeError):
  175. compile_net(net, _x2)
  176. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  177. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  178. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  179. strategy2 = ((1, 1, 1, 8),)
  180. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  181. with pytest.raises(RuntimeError):
  182. compile_net(net, _x2)
  183. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  184. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  185. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  186. strategy2 = ((1, 1, 1, 8),)
  187. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  188. with pytest.raises(RuntimeError):
  189. compile_net(net, _x2)
  190. def test_split_h_dimension_and_pad_mode_is_pad():
  191. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  192. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  193. strategy2 = ((1, 1, 1, 8),)
  194. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  195. with pytest.raises(RuntimeError):
  196. compile_net(net)
  197. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  198. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  199. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  200. strategy2 = ((1, 1, 1, 8),)
  201. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  202. with pytest.raises(RuntimeError):
  203. compile_net(net, _x2)
  204. def test_kernel_size_larger_than_stride_and_slice_too_small():
  205. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  206. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  207. strategy2 = ((1, 1, 1, 8),)
  208. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  209. with pytest.raises(RuntimeError):
  210. compile_net(net)
  211. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  212. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  213. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  214. strategy2 = ((2, 1, 1, 4),)
  215. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  216. with pytest.raises(RuntimeError):
  217. compile_net(net)
  218. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  219. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  220. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  221. strategy2 = ((1, 1, 1, 8),)
  222. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  223. with pytest.raises(RuntimeError):
  224. compile_net(net)