You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_conv2d.py 17 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  36. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  37. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  38. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  39. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  40. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  41. def compile_net(net, input_x=_x):
  42. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  43. train_net = TrainOneStepCell(net, optimizer)
  44. train_net.set_auto_parallel()
  45. train_net.set_train()
  46. _cell_graph_executor.compile(train_net, input_x, _b)
  47. context.reset_auto_parallel_context()
  48. def test_conv2d_data_parallel():
  49. """
  50. Feature: test conv2d data parallel
  51. Description: shard n dimension
  52. Expectation: compile success
  53. """
  54. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  55. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  56. strategy2 = ((8, 1, 1, 1),)
  57. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  58. compile_net(net)
  59. def test_conv2d_data_parallel_invalid_stride():
  60. """
  61. Feature: test conv2d invalid stride
  62. Description: the first two elements of stride must be 1, but set 2
  63. Expectation: compile success
  64. """
  65. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  66. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  67. strategy2 = ((8, 1, 1, 1),)
  68. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  69. strategy1=strategy1, strategy2=strategy2)
  70. with pytest.raises(RuntimeError):
  71. compile_net(net)
  72. def test_conv2d_data_parallel_dilation():
  73. """
  74. Feature: test conv2d data parallel and dilation is not 1
  75. Description: data parallel and dilation is not 1
  76. Expectation: compile success
  77. """
  78. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  79. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  80. strategy2 = ((8, 1, 1, 1),)
  81. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  82. strategy1=strategy1, strategy2=strategy2)
  83. compile_net(net)
  84. def test_conv2d_data_parallel_group():
  85. """
  86. Feature: test conv2d data parallel and group is not 1
  87. Description: data parallel and group is not 1
  88. Expectation: compile success
  89. """
  90. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  91. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  92. strategy2 = ((8, 1, 1, 1),)
  93. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  94. strategy1=strategy1, strategy2=strategy2)
  95. compile_net(net)
  96. def test_conv2d_model_parallel1():
  97. """
  98. Feature: test conv2d model parallel
  99. Description: split n/c-in/c-out
  100. Expectation: compile success
  101. """
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  103. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  104. strategy2 = ((8, 1, 1, 1),)
  105. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  106. compile_net(net)
  107. def test_conv2d_model_parallel_dilation():
  108. """
  109. Feature: test conv2d model parallel and dilation is not 1
  110. Description: model parallel and dilation is not 1
  111. Expectation: compile failed
  112. """
  113. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  114. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  115. strategy2 = ((8, 1, 1, 1),)
  116. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  117. strategy1=strategy1, strategy2=strategy2)
  118. with pytest.raises(RuntimeError):
  119. compile_net(net)
  120. def test_conv2d_model_parallel_group():
  121. """
  122. Feature: test conv2d model parallel and group is not 1
  123. Description: model parallel and group is not 1
  124. Expectation: compile failed
  125. """
  126. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  127. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  128. strategy2 = ((8, 1, 1, 1),)
  129. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  130. strategy1=strategy1, strategy2=strategy2)
  131. with pytest.raises(RuntimeError):
  132. compile_net(net)
  133. def test_conv2d_model_parallel2():
  134. """
  135. Feature: same mode, stride = kernel_size, no need exchange
  136. Description: split n/c-in/c-out/h/w
  137. Expectation: compile success
  138. """
  139. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  140. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  141. strategy2 = ((32, 1, 1, 1),)
  142. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  143. compile_net(net)
  144. def test_conv2d_model_parallel3():
  145. """
  146. Feature: same mode, stride < kernel_size, need exchange
  147. Description: split n/w
  148. Expectation: compile success
  149. """
  150. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  151. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  152. strategy2 = ((2, 1, 1, 4),)
  153. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  154. compile_net(net)
  155. def test_conv2d_auto_parallel():
  156. """
  157. Feature: same mode, auto parallel
  158. Description: generate data parallel strategy
  159. Expectation: compile success
  160. """
  161. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  162. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  163. compile_net(net)
  164. def test_conv2d_model_parallel4():
  165. """
  166. Feature: same mode, stride < kernel_size, need exchange
  167. Description: split n/c-in/c-out/w
  168. Expectation: compile success
  169. """
  170. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  171. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  172. strategy2 = ((2, 2, 1, 4),)
  173. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  174. compile_net(net)
  175. def test_conv2d_left_and_right_no_need_to_send():
  176. """
  177. Feature: same mode, k - s = 1, left pad is 0
  178. Description: do not support that the left no need to send
  179. Expectation: compile failed
  180. """
  181. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  182. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  183. strategy2 = ((2, 1, 1, 4),)
  184. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  185. with pytest.raises(RuntimeError):
  186. compile_net(net)
  187. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  188. """
  189. Feature: same mode, stride < kernel_size, need exchange
  190. Description: split n/c-in/c-out/h
  191. Expectation: compile success
  192. """
  193. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  194. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  195. strategy2 = ((2, 2, 4, 1),)
  196. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  197. compile_net(net)
  198. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  199. """
  200. Feature: valid mode, stride < kernel_size, need exchange
  201. Description: do not support to split w
  202. Expectation: compile failed
  203. """
  204. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  205. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  206. strategy2 = ((2, 1, 1, 4),)
  207. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  208. with pytest.raises(RuntimeError):
  209. compile_net(net)
  210. def test_conv2d_output_can_not_divisible_by_strategy():
  211. """
  212. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  213. Description: split w dimension
  214. Expectation: compile failed
  215. """
  216. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  217. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  218. strategy2 = ((1, 1, 1, 8),)
  219. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  220. with pytest.raises(RuntimeError):
  221. compile_net(net)
  222. def test_conv2d_output_can_not_divisible_by_strategy2():
  223. """
  224. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  225. Description: split h dimension
  226. Expectation: compile failed
  227. """
  228. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  229. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  230. strategy2 = ((1, 1, 1, 8),)
  231. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  232. with pytest.raises(RuntimeError):
  233. compile_net(net)
  234. def test_split_kernel():
  235. """
  236. Feature: split kernel size
  237. Description: do not support to split kernel size
  238. Expectation: compile failed
  239. """
  240. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  241. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  242. strategy2 = ((1, 1, 1, 8),)
  243. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  244. with pytest.raises(RuntimeError):
  245. compile_net(net)
  246. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  247. """
  248. Feature: same mode, slice shape can not be divided by stride
  249. Description: split w
  250. Expectation: compile failed
  251. """
  252. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  253. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  254. strategy2 = ((1, 1, 1, 8),)
  255. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  256. with pytest.raises(RuntimeError):
  257. compile_net(net, _x2)
  258. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  259. """
  260. Feature: valid mode, slice shape can not be divided by stride
  261. Description: split w
  262. Expectation: compile failed
  263. """
  264. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  265. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  266. strategy2 = ((1, 1, 1, 8),)
  267. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  268. with pytest.raises(RuntimeError):
  269. compile_net(net, _x2)
  270. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  271. """
  272. Feature: same mode, slice shape can not be divided by stride
  273. Description: split h
  274. Expectation: compile failed
  275. """
  276. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  277. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  278. strategy2 = ((1, 1, 1, 8),)
  279. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  280. with pytest.raises(RuntimeError):
  281. compile_net(net, _x2)
  282. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  283. """
  284. Feature: valid mode, slice shape can not be divided by stride
  285. Description: split h
  286. Expectation: compile failed
  287. """
  288. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  289. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  290. strategy2 = ((1, 1, 1, 8),)
  291. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  292. with pytest.raises(RuntimeError):
  293. compile_net(net, _x2)
  294. def test_split_h_dimension_and_pad_mode_is_pad():
  295. """
  296. Feature: pad mode
  297. Description: split h
  298. Expectation: compile failed
  299. """
  300. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  301. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  302. strategy2 = ((1, 1, 1, 8),)
  303. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  304. with pytest.raises(RuntimeError):
  305. compile_net(net)
  306. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  307. """
  308. Feature: same mode, input shape can not be divided by stride
  309. Description: split w
  310. Expectation: compile failed
  311. """
  312. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  313. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  314. strategy2 = ((1, 1, 1, 8),)
  315. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  316. with pytest.raises(RuntimeError):
  317. compile_net(net, _x2)
  318. def test_kernel_size_larger_than_stride_and_slice_too_small():
  319. """
  320. Feature: same mode, slice shape is small than overlap shape
  321. Description: split w
  322. Expectation: compile failed
  323. """
  324. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  325. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  326. strategy2 = ((1, 1, 1, 8),)
  327. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  328. with pytest.raises(RuntimeError):
  329. compile_net(net)
  330. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  331. """
  332. Feature: same mode, slice shape is equal to overlap shape
  333. Description: split w
  334. Expectation: compile failed
  335. """
  336. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  337. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  338. strategy2 = ((2, 1, 1, 4),)
  339. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  340. with pytest.raises(RuntimeError):
  341. compile_net(net)
  342. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  343. """
  344. Feature: same mode, kernel_size > stride and left pad is 0
  345. Description: split w
  346. Expectation: compile failed
  347. """
  348. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  349. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  350. strategy2 = ((1, 1, 1, 8),)
  351. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  352. with pytest.raises(RuntimeError):
  353. compile_net(net)
  354. def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
  355. """
  356. Feature: same mode, stride < kernel_size, need exchange
  357. Description: split n/c-in/c-out/h/w
  358. Expectation: compile success
  359. """
  360. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  361. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  362. strategy2 = ((2, 2, 2, 2),)
  363. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  364. compile_net(net)