You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_conv2d.py 20 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1, pad=0,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, pad=pad,
  26. stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _x3 = Tensor(np.ones([32, 16, 16, 16]), dtype=ms.float32)
  36. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  37. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  38. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  39. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  40. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  41. _w5 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
  42. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  43. def compile_net(net, input_x=_x):
  44. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  45. train_net = TrainOneStepCell(net, optimizer)
  46. train_net.set_auto_parallel()
  47. train_net.set_train()
  48. _cell_graph_executor.compile(train_net, input_x, _b)
  49. context.reset_auto_parallel_context()
  50. def test_conv2d_data_parallel():
  51. """
  52. Feature: test conv2d data parallel
  53. Description: shard n dimension
  54. Expectation: compile success
  55. """
  56. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  57. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  58. strategy2 = ((8, 1, 1, 1),)
  59. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  60. compile_net(net)
  61. def test_conv2d_pad_mode_overlap_is_negative():
  62. """
  63. Feature: test conv2d pad mode and overlap is negative
  64. Description: shard h/w
  65. Expectation: compile failed
  66. """
  67. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  68. strategy1 = ((1, 1, 4, 4), (1, 1, 1, 1))
  69. strategy2 = ((1, 1, 1, 1),)
  70. net = Net(_w5, out_channel=8, kernel_size=4, pad_mode="pad", stride=5, pad=(3, 0, 3, 0),
  71. strategy1=strategy1, strategy2=strategy2)
  72. with pytest.raises(RuntimeError):
  73. compile_net(net, _x3)
  74. def test_conv2d_pad_mode():
  75. """
  76. Feature: test conv2d pad mode and overlap is non-negative
  77. Description: shard h/w
  78. Expectation: compile success
  79. """
  80. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  81. strategy1 = ((1, 1, 2, 4), (1, 1, 1, 1))
  82. strategy2 = ((1, 1, 1, 1),)
  83. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=1, pad=(3, 3, 3, 3),
  84. strategy1=strategy1, strategy2=strategy2)
  85. compile_net(net, _x3)
  86. def test_conv2d_valid_mode_output_shape_cannot_div_by_strategy():
  87. """
  88. Feature: test conv2d valid mode, and output shape can not div by strategy
  89. Description: shard w
  90. Expectation: compile failed
  91. """
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  93. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  94. strategy2 = ((1, 1, 1, 1),)
  95. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="valid", stride=4,
  96. strategy1=strategy1, strategy2=strategy2)
  97. with pytest.raises(RuntimeError):
  98. compile_net(net, _x3)
  99. def test_conv2d_data_parallel_invalid_stride():
  100. """
  101. Feature: test conv2d invalid stride
  102. Description: the first two elements of stride must be 1, but set 2
  103. Expectation: compile failed
  104. """
  105. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  106. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  107. strategy2 = ((8, 1, 1, 1),)
  108. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  109. strategy1=strategy1, strategy2=strategy2)
  110. with pytest.raises(RuntimeError):
  111. compile_net(net)
  112. def test_conv2d_data_parallel_dilation():
  113. """
  114. Feature: test conv2d data parallel and dilation is not 1
  115. Description: data parallel and dilation is not 1
  116. Expectation: compile success
  117. """
  118. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  119. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  120. strategy2 = ((8, 1, 1, 1),)
  121. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  122. strategy1=strategy1, strategy2=strategy2)
  123. compile_net(net)
  124. def test_conv2d_data_parallel_group():
  125. """
  126. Feature: test conv2d data parallel and group is not 1
  127. Description: data parallel and group is not 1
  128. Expectation: compile success
  129. """
  130. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  131. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  132. strategy2 = ((8, 1, 1, 1),)
  133. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  134. strategy1=strategy1, strategy2=strategy2)
  135. compile_net(net)
  136. def test_conv2d_model_parallel1():
  137. """
  138. Feature: test conv2d model parallel
  139. Description: split n/c-in/c-out
  140. Expectation: compile success
  141. """
  142. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  143. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  144. strategy2 = ((8, 1, 1, 1),)
  145. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  146. compile_net(net)
  147. def test_conv2d_model_parallel_dilation():
  148. """
  149. Feature: test conv2d model parallel and dilation is not 1
  150. Description: model parallel and dilation is not 1
  151. Expectation: compile success
  152. """
  153. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  154. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  155. strategy2 = ((8, 1, 1, 1),)
  156. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  157. strategy1=strategy1, strategy2=strategy2)
  158. compile_net(net)
  159. def test_conv2d_model_parallel_group():
  160. """
  161. Feature: test conv2d model parallel and group is not 1
  162. Description: split cin and cout, and group is not 1
  163. Expectation: compile failed
  164. """
  165. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  166. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  167. strategy2 = ((8, 1, 1, 1),)
  168. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  169. strategy1=strategy1, strategy2=strategy2)
  170. with pytest.raises(RuntimeError):
  171. compile_net(net)
  172. def test_conv2d_model_parallel_group2():
  173. """
  174. Feature: test conv2d model parallel and group is not 1
  175. Description: has not to split cin and cout, and group is not 1
  176. Expectation: compile success
  177. """
  178. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  179. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  180. strategy2 = ((8, 1, 1, 1),)
  181. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  182. strategy1=strategy1, strategy2=strategy2)
  183. compile_net(net)
  184. def test_conv2d_model_parallel2():
  185. """
  186. Feature: same mode, stride = kernel_size, no need exchange
  187. Description: split n/c-in/c-out/h/w
  188. Expectation: compile success
  189. """
  190. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  191. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  192. strategy2 = ((32, 1, 1, 1),)
  193. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  194. compile_net(net)
  195. def test_conv2d_model_parallel3():
  196. """
  197. Feature: same mode, stride < kernel_size, need exchange
  198. Description: split n/w
  199. Expectation: compile success
  200. """
  201. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  202. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  203. strategy2 = ((2, 1, 1, 4),)
  204. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  205. compile_net(net)
  206. def test_conv2d_auto_parallel():
  207. """
  208. Feature: same mode, auto parallel
  209. Description: generate data parallel strategy
  210. Expectation: compile success
  211. """
  212. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  213. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  214. compile_net(net)
  215. def test_conv2d_model_parallel4():
  216. """
  217. Feature: same mode, stride < kernel_size, need exchange
  218. Description: split n/c-in/c-out/w
  219. Expectation: compile success
  220. """
  221. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  222. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  223. strategy2 = ((2, 2, 1, 4),)
  224. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  225. compile_net(net)
  226. def test_conv2d_left_and_right_no_need_to_send():
  227. """
  228. Feature: same mode, k - s = 1, left pad is 0, single direction exchange
  229. Description: support that the left no need to send
  230. Expectation: compile success
  231. """
  232. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  233. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  234. strategy2 = ((2, 1, 1, 4),)
  235. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  236. compile_net(net)
  237. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  238. """
  239. Feature: same mode, stride < kernel_size, need exchange
  240. Description: split n/c-in/c-out/h
  241. Expectation: compile success
  242. """
  243. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  244. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  245. strategy2 = ((2, 2, 4, 1),)
  246. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  247. compile_net(net)
  248. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  249. """
  250. Feature: valid mode, stride < kernel_size, need exchange
  251. Description: do not support to split w
  252. Expectation: compile failed
  253. """
  254. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  255. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  256. strategy2 = ((2, 1, 1, 4),)
  257. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  258. with pytest.raises(RuntimeError):
  259. compile_net(net)
  260. def test_conv2d_output_can_not_divisible_by_strategy():
  261. """
  262. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  263. Description: split w dimension
  264. Expectation: compile failed
  265. """
  266. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  267. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  268. strategy2 = ((1, 1, 1, 8),)
  269. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  270. with pytest.raises(RuntimeError):
  271. compile_net(net)
  272. def test_conv2d_output_can_not_divisible_by_strategy2():
  273. """
  274. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  275. Description: split h dimension
  276. Expectation: compile failed
  277. """
  278. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  279. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  280. strategy2 = ((1, 1, 1, 8),)
  281. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  282. with pytest.raises(RuntimeError):
  283. compile_net(net)
  284. def test_split_kernel():
  285. """
  286. Feature: split kernel size
  287. Description: do not support to split kernel size
  288. Expectation: compile failed
  289. """
  290. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  291. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  292. strategy2 = ((1, 1, 1, 8),)
  293. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  294. with pytest.raises(RuntimeError):
  295. compile_net(net)
  296. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  297. """
  298. Feature: same mode, slice shape can not be divided by stride
  299. Description: split w
  300. Expectation: compile failed
  301. """
  302. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  303. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  304. strategy2 = ((1, 1, 1, 8),)
  305. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  306. with pytest.raises(RuntimeError):
  307. compile_net(net, _x2)
  308. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  309. """
  310. Feature: valid mode, slice shape can not be divided by stride
  311. Description: split w
  312. Expectation: compile failed
  313. """
  314. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  315. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  316. strategy2 = ((1, 1, 1, 8),)
  317. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  318. with pytest.raises(RuntimeError):
  319. compile_net(net, _x2)
  320. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  321. """
  322. Feature: same mode, slice shape can not be divided by stride
  323. Description: split h
  324. Expectation: compile failed
  325. """
  326. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  327. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  328. strategy2 = ((1, 1, 1, 8),)
  329. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  330. with pytest.raises(RuntimeError):
  331. compile_net(net, _x2)
  332. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  333. """
  334. Feature: valid mode, slice shape can not be divided by stride
  335. Description: split h
  336. Expectation: compile failed
  337. """
  338. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  339. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  340. strategy2 = ((1, 1, 1, 8),)
  341. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  342. with pytest.raises(RuntimeError):
  343. compile_net(net, _x2)
  344. def test_split_h_dimension_and_pad_mode_is_pad():
  345. """
  346. Feature: pad mode
  347. Description: split h
  348. Expectation: compile failed
  349. """
  350. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  351. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  352. strategy2 = ((1, 1, 1, 8),)
  353. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  354. with pytest.raises(RuntimeError):
  355. compile_net(net)
  356. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  357. """
  358. Feature: same mode, input shape can not be divided by stride
  359. Description: split w
  360. Expectation: compile failed
  361. """
  362. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  363. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  364. strategy2 = ((1, 1, 1, 8),)
  365. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  366. with pytest.raises(RuntimeError):
  367. compile_net(net, _x2)
  368. def test_kernel_size_larger_than_stride_and_slice_too_small():
  369. """
  370. Feature: same mode, slice shape is small than overlap shape
  371. Description: split w
  372. Expectation: compile failed
  373. """
  374. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  375. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  376. strategy2 = ((1, 1, 1, 8),)
  377. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  378. with pytest.raises(RuntimeError):
  379. compile_net(net)
  380. def test_conv2d_dilation():
  381. """
  382. Feature: same mode, dilation is 2
  383. Description: split n/h/w
  384. Expectation: compile success
  385. """
  386. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  387. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  388. strategy2 = ((2, 2, 1, 2),)
  389. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, dilation=2, strategy1=strategy1,
  390. strategy2=strategy2)
  391. compile_net(net)
  392. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  393. """
  394. Feature: same mode, slice shape is equal to overlap shape
  395. Description: split w
  396. Expectation: compile success
  397. """
  398. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  399. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  400. strategy2 = ((2, 1, 1, 4),)
  401. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  402. compile_net(net)
  403. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  404. """
  405. Feature: same mode, kernel_size > stride and left pad is 0, single direction exchange
  406. Description: split w
  407. Expectation: compile success
  408. """
  409. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  410. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  411. strategy2 = ((1, 1, 1, 8),)
  412. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  413. compile_net(net)
  414. def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
  415. """
  416. Feature: same mode, stride < kernel_size, need exchange
  417. Description: split n/c-in/c-out/h/w
  418. Expectation: compile success
  419. """
  420. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  421. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  422. strategy2 = ((2, 2, 2, 2),)
  423. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  424. compile_net(net)