You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_conv2d.py 18 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  36. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  37. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  38. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  39. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  40. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  41. def compile_net(net, input_x=_x):
  42. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  43. train_net = TrainOneStepCell(net, optimizer)
  44. train_net.set_auto_parallel()
  45. train_net.set_train()
  46. _cell_graph_executor.compile(train_net, input_x, _b)
  47. context.reset_auto_parallel_context()
  48. def test_conv2d_data_parallel():
  49. """
  50. Feature: test conv2d data parallel
  51. Description: shard n dimension
  52. Expectation: compile success
  53. """
  54. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  55. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  56. strategy2 = ((8, 1, 1, 1),)
  57. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  58. compile_net(net)
  59. def test_conv2d_data_parallel_invalid_stride():
  60. """
  61. Feature: test conv2d invalid stride
  62. Description: the first two elements of stride must be 1, but set 2
  63. Expectation: compile success
  64. """
  65. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  66. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  67. strategy2 = ((8, 1, 1, 1),)
  68. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  69. strategy1=strategy1, strategy2=strategy2)
  70. with pytest.raises(RuntimeError):
  71. compile_net(net)
  72. def test_conv2d_data_parallel_dilation():
  73. """
  74. Feature: test conv2d data parallel and dilation is not 1
  75. Description: data parallel and dilation is not 1
  76. Expectation: compile success
  77. """
  78. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  79. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  80. strategy2 = ((8, 1, 1, 1),)
  81. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  82. strategy1=strategy1, strategy2=strategy2)
  83. compile_net(net)
  84. def test_conv2d_data_parallel_group():
  85. """
  86. Feature: test conv2d data parallel and group is not 1
  87. Description: data parallel and group is not 1
  88. Expectation: compile success
  89. """
  90. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  91. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  92. strategy2 = ((8, 1, 1, 1),)
  93. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  94. strategy1=strategy1, strategy2=strategy2)
  95. compile_net(net)
  96. def test_conv2d_model_parallel1():
  97. """
  98. Feature: test conv2d model parallel
  99. Description: split n/c-in/c-out
  100. Expectation: compile success
  101. """
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  103. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  104. strategy2 = ((8, 1, 1, 1),)
  105. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  106. compile_net(net)
  107. def test_conv2d_model_parallel_dilation():
  108. """
  109. Feature: test conv2d model parallel and dilation is not 1
  110. Description: model parallel and dilation is not 1
  111. Expectation: compile success
  112. """
  113. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  114. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  115. strategy2 = ((8, 1, 1, 1),)
  116. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  117. strategy1=strategy1, strategy2=strategy2)
  118. compile_net(net)
  119. def test_conv2d_model_parallel_group():
  120. """
  121. Feature: test conv2d model parallel and group is not 1
  122. Description: split cin and cout, and group is not 1
  123. Expectation: compile failed
  124. """
  125. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  126. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  127. strategy2 = ((8, 1, 1, 1),)
  128. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  129. strategy1=strategy1, strategy2=strategy2)
  130. with pytest.raises(RuntimeError):
  131. compile_net(net)
  132. def test_conv2d_model_parallel_group2():
  133. """
  134. Feature: test conv2d model parallel and group is not 1
  135. Description: has not to split cin and cout, and group is not 1
  136. Expectation: compile success
  137. """
  138. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  139. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  140. strategy2 = ((8, 1, 1, 1),)
  141. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  142. strategy1=strategy1, strategy2=strategy2)
  143. compile_net(net)
  144. def test_conv2d_model_parallel2():
  145. """
  146. Feature: same mode, stride = kernel_size, no need exchange
  147. Description: split n/c-in/c-out/h/w
  148. Expectation: compile success
  149. """
  150. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  151. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  152. strategy2 = ((32, 1, 1, 1),)
  153. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  154. compile_net(net)
  155. def test_conv2d_model_parallel3():
  156. """
  157. Feature: same mode, stride < kernel_size, need exchange
  158. Description: split n/w
  159. Expectation: compile success
  160. """
  161. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  162. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  163. strategy2 = ((2, 1, 1, 4),)
  164. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  165. compile_net(net)
  166. def test_conv2d_auto_parallel():
  167. """
  168. Feature: same mode, auto parallel
  169. Description: generate data parallel strategy
  170. Expectation: compile success
  171. """
  172. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  173. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  174. compile_net(net)
  175. def test_conv2d_model_parallel4():
  176. """
  177. Feature: same mode, stride < kernel_size, need exchange
  178. Description: split n/c-in/c-out/w
  179. Expectation: compile success
  180. """
  181. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  182. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  183. strategy2 = ((2, 2, 1, 4),)
  184. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  185. compile_net(net)
  186. def test_conv2d_left_and_right_no_need_to_send():
  187. """
  188. Feature: same mode, k - s = 1, left pad is 0, single direction exchange
  189. Description: support that the left no need to send
  190. Expectation: compile success
  191. """
  192. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  193. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  194. strategy2 = ((2, 1, 1, 4),)
  195. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  196. compile_net(net)
  197. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  198. """
  199. Feature: same mode, stride < kernel_size, need exchange
  200. Description: split n/c-in/c-out/h
  201. Expectation: compile success
  202. """
  203. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  204. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  205. strategy2 = ((2, 2, 4, 1),)
  206. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  207. compile_net(net)
  208. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  209. """
  210. Feature: valid mode, stride < kernel_size, need exchange
  211. Description: do not support to split w
  212. Expectation: compile failed
  213. """
  214. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  215. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  216. strategy2 = ((2, 1, 1, 4),)
  217. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  218. with pytest.raises(RuntimeError):
  219. compile_net(net)
  220. def test_conv2d_output_can_not_divisible_by_strategy():
  221. """
  222. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  223. Description: split w dimension
  224. Expectation: compile failed
  225. """
  226. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  227. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  228. strategy2 = ((1, 1, 1, 8),)
  229. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  230. with pytest.raises(RuntimeError):
  231. compile_net(net)
  232. def test_conv2d_output_can_not_divisible_by_strategy2():
  233. """
  234. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  235. Description: split h dimension
  236. Expectation: compile failed
  237. """
  238. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  239. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  240. strategy2 = ((1, 1, 1, 8),)
  241. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  242. with pytest.raises(RuntimeError):
  243. compile_net(net)
  244. def test_split_kernel():
  245. """
  246. Feature: split kernel size
  247. Description: do not support to split kernel size
  248. Expectation: compile failed
  249. """
  250. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  251. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  252. strategy2 = ((1, 1, 1, 8),)
  253. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  254. with pytest.raises(RuntimeError):
  255. compile_net(net)
  256. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  257. """
  258. Feature: same mode, slice shape can not be divided by stride
  259. Description: split w
  260. Expectation: compile failed
  261. """
  262. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  263. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  264. strategy2 = ((1, 1, 1, 8),)
  265. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  266. with pytest.raises(RuntimeError):
  267. compile_net(net, _x2)
  268. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  269. """
  270. Feature: valid mode, slice shape can not be divided by stride
  271. Description: split w
  272. Expectation: compile failed
  273. """
  274. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  275. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  276. strategy2 = ((1, 1, 1, 8),)
  277. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  278. with pytest.raises(RuntimeError):
  279. compile_net(net, _x2)
  280. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  281. """
  282. Feature: same mode, slice shape can not be divided by stride
  283. Description: split h
  284. Expectation: compile failed
  285. """
  286. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  287. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  288. strategy2 = ((1, 1, 1, 8),)
  289. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  290. with pytest.raises(RuntimeError):
  291. compile_net(net, _x2)
  292. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  293. """
  294. Feature: valid mode, slice shape can not be divided by stride
  295. Description: split h
  296. Expectation: compile failed
  297. """
  298. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  299. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  300. strategy2 = ((1, 1, 1, 8),)
  301. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  302. with pytest.raises(RuntimeError):
  303. compile_net(net, _x2)
  304. def test_split_h_dimension_and_pad_mode_is_pad():
  305. """
  306. Feature: pad mode
  307. Description: split h
  308. Expectation: compile failed
  309. """
  310. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  311. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  312. strategy2 = ((1, 1, 1, 8),)
  313. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  314. with pytest.raises(RuntimeError):
  315. compile_net(net)
  316. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  317. """
  318. Feature: same mode, input shape can not be divided by stride
  319. Description: split w
  320. Expectation: compile failed
  321. """
  322. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  323. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  324. strategy2 = ((1, 1, 1, 8),)
  325. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  326. with pytest.raises(RuntimeError):
  327. compile_net(net, _x2)
  328. def test_kernel_size_larger_than_stride_and_slice_too_small():
  329. """
  330. Feature: same mode, slice shape is small than overlap shape
  331. Description: split w
  332. Expectation: compile failed
  333. """
  334. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  335. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  336. strategy2 = ((1, 1, 1, 8),)
  337. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  338. with pytest.raises(RuntimeError):
  339. compile_net(net)
  340. def test_conv2d_dilation():
  341. """
  342. Feature: same mode, dilation is 2
  343. Description: split n/h/w
  344. Expectation: compile success
  345. """
  346. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  347. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  348. strategy2 = ((2, 2, 1, 2),)
  349. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, dilation=2, strategy1=strategy1,
  350. strategy2=strategy2)
  351. compile_net(net)
  352. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  353. """
  354. Feature: same mode, slice shape is equal to overlap shape
  355. Description: split w
  356. Expectation: compile failed
  357. """
  358. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  359. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  360. strategy2 = ((2, 1, 1, 4),)
  361. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  362. with pytest.raises(RuntimeError):
  363. compile_net(net)
  364. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  365. """
  366. Feature: same mode, kernel_size > stride and left pad is 0, single direction exchange
  367. Description: split w
  368. Expectation: compile success
  369. """
  370. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  371. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  372. strategy2 = ((1, 1, 1, 8),)
  373. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  374. compile_net(net)
  375. def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
  376. """
  377. Feature: same mode, stride < kernel_size, need exchange
  378. Description: split n/c-in/c-out/h/w
  379. Expectation: compile success
  380. """
  381. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  382. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  383. strategy2 = ((2, 2, 2, 2),)
  384. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  385. compile_net(net)