You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_conv2d.py 19 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1, pad=0,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, pad=pad,
  26. stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _x3 = Tensor(np.ones([32, 16, 16, 16]), dtype=ms.float32)
  36. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  37. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  38. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  39. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  40. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  41. _w5 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
  42. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  43. def compile_net(net, input_x=_x):
  44. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  45. train_net = TrainOneStepCell(net, optimizer)
  46. train_net.set_auto_parallel()
  47. train_net.set_train()
  48. _cell_graph_executor.compile(train_net, input_x, _b)
  49. context.reset_auto_parallel_context()
  50. def test_conv2d_data_parallel():
  51. """
  52. Feature: test conv2d data parallel
  53. Description: shard n dimension
  54. Expectation: compile success
  55. """
  56. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  57. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  58. strategy2 = ((8, 1, 1, 1),)
  59. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  60. compile_net(net)
  61. def test_conv2d_pad_mode_overlap_is_negative():
  62. """
  63. Feature: test conv2d pad mode and overlap is negative
  64. Description: shard h/w
  65. Expectation: compile failed
  66. """
  67. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  68. strategy1 = ((1, 1, 4, 4), (1, 1, 1, 1))
  69. strategy2 = ((1, 1, 1, 1),)
  70. net = Net(_w5, out_channel=8, kernel_size=4, pad_mode="pad", stride=5, pad=(3, 0, 3, 0),
  71. strategy1=strategy1, strategy2=strategy2)
  72. with pytest.raises(RuntimeError):
  73. compile_net(net, _x3)
  74. def test_conv2d_pad_mode():
  75. """
  76. Feature: test conv2d pad mode and overlap is non-negative
  77. Description: shard h/w
  78. Expectation: compile success
  79. """
  80. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  81. strategy1 = ((1, 1, 2, 4), (1, 1, 1, 1))
  82. strategy2 = ((1, 1, 1, 1),)
  83. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="pad", stride=1, pad=(3, 3, 3, 3),
  84. strategy1=strategy1, strategy2=strategy2)
  85. compile_net(net, _x3)
  86. def test_conv2d_data_parallel_invalid_stride():
  87. """
  88. Feature: test conv2d invalid stride
  89. Description: the first two elements of stride must be 1, but set 2
  90. Expectation: compile failed
  91. """
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  93. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  94. strategy2 = ((8, 1, 1, 1),)
  95. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  96. strategy1=strategy1, strategy2=strategy2)
  97. with pytest.raises(RuntimeError):
  98. compile_net(net)
  99. def test_conv2d_data_parallel_dilation():
  100. """
  101. Feature: test conv2d data parallel and dilation is not 1
  102. Description: data parallel and dilation is not 1
  103. Expectation: compile success
  104. """
  105. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  106. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  107. strategy2 = ((8, 1, 1, 1),)
  108. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  109. strategy1=strategy1, strategy2=strategy2)
  110. compile_net(net)
  111. def test_conv2d_data_parallel_group():
  112. """
  113. Feature: test conv2d data parallel and group is not 1
  114. Description: data parallel and group is not 1
  115. Expectation: compile success
  116. """
  117. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  118. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  119. strategy2 = ((8, 1, 1, 1),)
  120. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  121. strategy1=strategy1, strategy2=strategy2)
  122. compile_net(net)
  123. def test_conv2d_model_parallel1():
  124. """
  125. Feature: test conv2d model parallel
  126. Description: split n/c-in/c-out
  127. Expectation: compile success
  128. """
  129. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  130. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  131. strategy2 = ((8, 1, 1, 1),)
  132. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  133. compile_net(net)
  134. def test_conv2d_model_parallel_dilation():
  135. """
  136. Feature: test conv2d model parallel and dilation is not 1
  137. Description: model parallel and dilation is not 1
  138. Expectation: compile success
  139. """
  140. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  141. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  142. strategy2 = ((8, 1, 1, 1),)
  143. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  144. strategy1=strategy1, strategy2=strategy2)
  145. compile_net(net)
  146. def test_conv2d_model_parallel_group():
  147. """
  148. Feature: test conv2d model parallel and group is not 1
  149. Description: split cin and cout, and group is not 1
  150. Expectation: compile failed
  151. """
  152. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  153. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  154. strategy2 = ((8, 1, 1, 1),)
  155. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  156. strategy1=strategy1, strategy2=strategy2)
  157. with pytest.raises(RuntimeError):
  158. compile_net(net)
  159. def test_conv2d_model_parallel_group2():
  160. """
  161. Feature: test conv2d model parallel and group is not 1
  162. Description: has not to split cin and cout, and group is not 1
  163. Expectation: compile success
  164. """
  165. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  166. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  167. strategy2 = ((8, 1, 1, 1),)
  168. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  169. strategy1=strategy1, strategy2=strategy2)
  170. compile_net(net)
  171. def test_conv2d_model_parallel2():
  172. """
  173. Feature: same mode, stride = kernel_size, no need exchange
  174. Description: split n/c-in/c-out/h/w
  175. Expectation: compile success
  176. """
  177. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  178. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  179. strategy2 = ((32, 1, 1, 1),)
  180. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  181. compile_net(net)
  182. def test_conv2d_model_parallel3():
  183. """
  184. Feature: same mode, stride < kernel_size, need exchange
  185. Description: split n/w
  186. Expectation: compile success
  187. """
  188. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  189. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  190. strategy2 = ((2, 1, 1, 4),)
  191. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  192. compile_net(net)
  193. def test_conv2d_auto_parallel():
  194. """
  195. Feature: same mode, auto parallel
  196. Description: generate data parallel strategy
  197. Expectation: compile success
  198. """
  199. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  200. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  201. compile_net(net)
  202. def test_conv2d_model_parallel4():
  203. """
  204. Feature: same mode, stride < kernel_size, need exchange
  205. Description: split n/c-in/c-out/w
  206. Expectation: compile success
  207. """
  208. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  209. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  210. strategy2 = ((2, 2, 1, 4),)
  211. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  212. compile_net(net)
  213. def test_conv2d_left_and_right_no_need_to_send():
  214. """
  215. Feature: same mode, k - s = 1, left pad is 0, single direction exchange
  216. Description: support that the left no need to send
  217. Expectation: compile success
  218. """
  219. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  220. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  221. strategy2 = ((2, 1, 1, 4),)
  222. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  223. compile_net(net)
  224. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  225. """
  226. Feature: same mode, stride < kernel_size, need exchange
  227. Description: split n/c-in/c-out/h
  228. Expectation: compile success
  229. """
  230. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  231. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  232. strategy2 = ((2, 2, 4, 1),)
  233. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  234. compile_net(net)
  235. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  236. """
  237. Feature: valid mode, stride < kernel_size, need exchange
  238. Description: do not support to split w
  239. Expectation: compile failed
  240. """
  241. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  242. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  243. strategy2 = ((2, 1, 1, 4),)
  244. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  245. with pytest.raises(RuntimeError):
  246. compile_net(net)
  247. def test_conv2d_output_can_not_divisible_by_strategy():
  248. """
  249. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  250. Description: split w dimension
  251. Expectation: compile failed
  252. """
  253. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  254. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  255. strategy2 = ((1, 1, 1, 8),)
  256. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  257. with pytest.raises(RuntimeError):
  258. compile_net(net)
  259. def test_conv2d_output_can_not_divisible_by_strategy2():
  260. """
  261. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  262. Description: split h dimension
  263. Expectation: compile failed
  264. """
  265. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  266. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  267. strategy2 = ((1, 1, 1, 8),)
  268. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  269. with pytest.raises(RuntimeError):
  270. compile_net(net)
  271. def test_split_kernel():
  272. """
  273. Feature: split kernel size
  274. Description: do not support to split kernel size
  275. Expectation: compile failed
  276. """
  277. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  278. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  279. strategy2 = ((1, 1, 1, 8),)
  280. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  281. with pytest.raises(RuntimeError):
  282. compile_net(net)
  283. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  284. """
  285. Feature: same mode, slice shape can not be divided by stride
  286. Description: split w
  287. Expectation: compile failed
  288. """
  289. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  290. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  291. strategy2 = ((1, 1, 1, 8),)
  292. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  293. with pytest.raises(RuntimeError):
  294. compile_net(net, _x2)
  295. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  296. """
  297. Feature: valid mode, slice shape can not be divided by stride
  298. Description: split w
  299. Expectation: compile failed
  300. """
  301. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  302. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  303. strategy2 = ((1, 1, 1, 8),)
  304. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  305. with pytest.raises(RuntimeError):
  306. compile_net(net, _x2)
  307. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  308. """
  309. Feature: same mode, slice shape can not be divided by stride
  310. Description: split h
  311. Expectation: compile failed
  312. """
  313. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  314. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  315. strategy2 = ((1, 1, 1, 8),)
  316. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  317. with pytest.raises(RuntimeError):
  318. compile_net(net, _x2)
  319. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  320. """
  321. Feature: valid mode, slice shape can not be divided by stride
  322. Description: split h
  323. Expectation: compile failed
  324. """
  325. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  326. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  327. strategy2 = ((1, 1, 1, 8),)
  328. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  329. with pytest.raises(RuntimeError):
  330. compile_net(net, _x2)
  331. def test_split_h_dimension_and_pad_mode_is_pad():
  332. """
  333. Feature: pad mode
  334. Description: split h
  335. Expectation: compile failed
  336. """
  337. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  338. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  339. strategy2 = ((1, 1, 1, 8),)
  340. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  341. with pytest.raises(RuntimeError):
  342. compile_net(net)
  343. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  344. """
  345. Feature: same mode, input shape can not be divided by stride
  346. Description: split w
  347. Expectation: compile failed
  348. """
  349. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  350. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  351. strategy2 = ((1, 1, 1, 8),)
  352. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  353. with pytest.raises(RuntimeError):
  354. compile_net(net, _x2)
  355. def test_kernel_size_larger_than_stride_and_slice_too_small():
  356. """
  357. Feature: same mode, slice shape is small than overlap shape
  358. Description: split w
  359. Expectation: compile failed
  360. """
  361. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  362. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  363. strategy2 = ((1, 1, 1, 8),)
  364. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  365. with pytest.raises(RuntimeError):
  366. compile_net(net)
  367. def test_conv2d_dilation():
  368. """
  369. Feature: same mode, dilation is 2
  370. Description: split n/h/w
  371. Expectation: compile success
  372. """
  373. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  374. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  375. strategy2 = ((2, 2, 1, 2),)
  376. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, dilation=2, strategy1=strategy1,
  377. strategy2=strategy2)
  378. compile_net(net)
  379. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  380. """
  381. Feature: same mode, slice shape is equal to overlap shape
  382. Description: split w
  383. Expectation: compile success
  384. """
  385. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  386. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  387. strategy2 = ((2, 1, 1, 4),)
  388. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  389. compile_net(net)
  390. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  391. """
  392. Feature: same mode, kernel_size > stride and left pad is 0, single direction exchange
  393. Description: split w
  394. Expectation: compile success
  395. """
  396. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  397. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  398. strategy2 = ((1, 1, 1, 8),)
  399. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  400. compile_net(net)
  401. def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
  402. """
  403. Feature: same mode, stride < kernel_size, need exchange
  404. Description: split n/c-in/c-out/h/w
  405. Expectation: compile success
  406. """
  407. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  408. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  409. strategy2 = ((2, 2, 2, 2),)
  410. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  411. compile_net(net)