You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_conv2d.py 17 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.conv2d_weight = Parameter(conv2d_weight, "w1")
  29. def construct(self, x, b):
  30. out = self.conv2d(x, self.conv2d_weight)
  31. out = self.neg(out)
  32. return out
  33. _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  34. _x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
  35. _w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  36. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  37. _w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  38. _w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  39. _w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
  40. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  41. def compile_net(net, input_x=_x):
  42. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  43. train_net = TrainOneStepCell(net, optimizer)
  44. train_net.set_auto_parallel()
  45. train_net.set_train()
  46. _cell_graph_executor.compile(train_net, input_x, _b)
  47. context.reset_auto_parallel_context()
  48. def test_conv2d_data_parallel():
  49. """
  50. Feature: test conv2d data parallel
  51. Description: shard n dimension
  52. Expectation: compile success
  53. """
  54. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  55. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  56. strategy2 = ((8, 1, 1, 1),)
  57. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  58. compile_net(net)
  59. def test_conv2d_data_parallel_invalid_stride():
  60. """
  61. Feature: test conv2d invalid stride
  62. Description: the first two elements of stride must be 1, but set 2
  63. Expectation: compile success
  64. """
  65. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  66. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  67. strategy2 = ((8, 1, 1, 1),)
  68. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
  69. strategy1=strategy1, strategy2=strategy2)
  70. with pytest.raises(RuntimeError):
  71. compile_net(net)
  72. def test_conv2d_data_parallel_dilation():
  73. """
  74. Feature: test conv2d data parallel and dilation is not 1
  75. Description: data parallel and dilation is not 1
  76. Expectation: compile success
  77. """
  78. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  79. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  80. strategy2 = ((8, 1, 1, 1),)
  81. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  82. strategy1=strategy1, strategy2=strategy2)
  83. compile_net(net)
  84. def test_conv2d_data_parallel_group():
  85. """
  86. Feature: test conv2d data parallel and group is not 1
  87. Description: data parallel and group is not 1
  88. Expectation: compile success
  89. """
  90. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  91. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  92. strategy2 = ((8, 1, 1, 1),)
  93. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  94. strategy1=strategy1, strategy2=strategy2)
  95. compile_net(net)
  96. def test_conv2d_model_parallel1():
  97. """
  98. Feature: test conv2d model parallel
  99. Description: split n/c-in/c-out
  100. Expectation: compile success
  101. """
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  103. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  104. strategy2 = ((8, 1, 1, 1),)
  105. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  106. compile_net(net)
  107. def test_conv2d_model_parallel_dilation():
  108. """
  109. Feature: test conv2d model parallel and dilation is not 1
  110. Description: model parallel and dilation is not 1
  111. Expectation: compile failed
  112. """
  113. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  114. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  115. strategy2 = ((8, 1, 1, 1),)
  116. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
  117. strategy1=strategy1, strategy2=strategy2)
  118. with pytest.raises(RuntimeError):
  119. compile_net(net)
  120. def test_conv2d_model_parallel_group():
  121. """
  122. Feature: test conv2d model parallel and group is not 1
  123. Description: model parallel and group is not 1
  124. Expectation: compile failed
  125. """
  126. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  127. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  128. strategy2 = ((8, 1, 1, 1),)
  129. net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
  130. strategy1=strategy1, strategy2=strategy2)
  131. with pytest.raises(RuntimeError):
  132. compile_net(net)
  133. def test_conv2d_model_parallel2():
  134. """
  135. Feature: same mode, stride = kernel_size, no need exchange
  136. Description: split n/c-in/c-out/h/w
  137. Expectation: compile success
  138. """
  139. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  140. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  141. strategy2 = ((32, 1, 1, 1),)
  142. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  143. compile_net(net)
  144. def test_conv2d_model_parallel3():
  145. """
  146. Feature: same mode, stride < kernel_size, need exchange
  147. Description: split n/w
  148. Expectation: compile success
  149. """
  150. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  151. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  152. strategy2 = ((2, 1, 1, 4),)
  153. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  154. compile_net(net)
  155. def test_conv2d_auto_parallel():
  156. """
  157. Feature: same mode, auto parallel
  158. Description: generate data parallel strategy
  159. Expectation: compile success
  160. """
  161. context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
  162. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
  163. compile_net(net)
  164. def test_conv2d_model_parallel4():
  165. """
  166. Feature: same mode, stride < kernel_size, need exchange
  167. Description: split n/c-in/c-out/w
  168. Expectation: compile success
  169. """
  170. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  171. strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
  172. strategy2 = ((2, 2, 1, 4),)
  173. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  174. compile_net(net)
  175. def test_conv2d_left_and_right_no_need_to_send():
  176. """
  177. Feature: same mode, k - s = 1, left pad is 0, single direction exchange
  178. Description: support that the left no need to send
  179. Expectation: compile success
  180. """
  181. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  182. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  183. strategy2 = ((2, 1, 1, 4),)
  184. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  185. compile_net(net)
  186. def test_conv2d_kernel_size_larger_than_stride_and_split_h():
  187. """
  188. Feature: same mode, stride < kernel_size, need exchange
  189. Description: split n/c-in/c-out/h
  190. Expectation: compile success
  191. """
  192. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  193. strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
  194. strategy2 = ((2, 2, 4, 1),)
  195. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  196. compile_net(net)
  197. def test_conv2d_valid_mode_kernel_size_larger_than_stride():
  198. """
  199. Feature: valid mode, stride < kernel_size, need exchange
  200. Description: do not support to split w
  201. Expectation: compile failed
  202. """
  203. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  204. strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
  205. strategy2 = ((2, 1, 1, 4),)
  206. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
  207. with pytest.raises(RuntimeError):
  208. compile_net(net)
  209. def test_conv2d_output_can_not_divisible_by_strategy():
  210. """
  211. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  212. Description: split w dimension
  213. Expectation: compile failed
  214. """
  215. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  216. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  217. strategy2 = ((1, 1, 1, 8),)
  218. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  219. with pytest.raises(RuntimeError):
  220. compile_net(net)
  221. def test_conv2d_output_can_not_divisible_by_strategy2():
  222. """
  223. Feature: same mode, stride = kernel_size, but output shape can not be divided by strategy
  224. Description: split h dimension
  225. Expectation: compile failed
  226. """
  227. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  228. strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
  229. strategy2 = ((1, 1, 1, 8),)
  230. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  231. with pytest.raises(RuntimeError):
  232. compile_net(net)
  233. def test_split_kernel():
  234. """
  235. Feature: split kernel size
  236. Description: do not support to split kernel size
  237. Expectation: compile failed
  238. """
  239. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  240. strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
  241. strategy2 = ((1, 1, 1, 8),)
  242. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
  243. with pytest.raises(RuntimeError):
  244. compile_net(net)
  245. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
  246. """
  247. Feature: same mode, slice shape can not be divided by stride
  248. Description: split w
  249. Expectation: compile failed
  250. """
  251. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  252. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  253. strategy2 = ((1, 1, 1, 8),)
  254. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  255. with pytest.raises(RuntimeError):
  256. compile_net(net, _x2)
  257. def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  258. """
  259. Feature: valid mode, slice shape can not be divided by stride
  260. Description: split w
  261. Expectation: compile failed
  262. """
  263. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  264. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  265. strategy2 = ((1, 1, 1, 8),)
  266. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  267. with pytest.raises(RuntimeError):
  268. compile_net(net, _x2)
  269. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
  270. """
  271. Feature: same mode, slice shape can not be divided by stride
  272. Description: split h
  273. Expectation: compile failed
  274. """
  275. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  276. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  277. strategy2 = ((1, 1, 1, 8),)
  278. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  279. with pytest.raises(RuntimeError):
  280. compile_net(net, _x2)
  281. def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
  282. """
  283. Feature: valid mode, slice shape can not be divided by stride
  284. Description: split h
  285. Expectation: compile failed
  286. """
  287. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  288. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  289. strategy2 = ((1, 1, 1, 8),)
  290. net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
  291. with pytest.raises(RuntimeError):
  292. compile_net(net, _x2)
  293. def test_split_h_dimension_and_pad_mode_is_pad():
  294. """
  295. Feature: pad mode
  296. Description: split h
  297. Expectation: compile failed
  298. """
  299. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  300. strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
  301. strategy2 = ((1, 1, 1, 8),)
  302. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
  303. with pytest.raises(RuntimeError):
  304. compile_net(net)
  305. def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
  306. """
  307. Feature: same mode, input shape can not be divided by stride
  308. Description: split w
  309. Expectation: compile failed
  310. """
  311. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  312. strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
  313. strategy2 = ((1, 1, 1, 8),)
  314. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
  315. with pytest.raises(RuntimeError):
  316. compile_net(net, _x2)
  317. def test_kernel_size_larger_than_stride_and_slice_too_small():
  318. """
  319. Feature: same mode, slice shape is small than overlap shape
  320. Description: split w
  321. Expectation: compile failed
  322. """
  323. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  324. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  325. strategy2 = ((1, 1, 1, 8),)
  326. net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  327. with pytest.raises(RuntimeError):
  328. compile_net(net)
  329. def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
  330. """
  331. Feature: same mode, slice shape is equal to overlap shape
  332. Description: split w
  333. Expectation: compile failed
  334. """
  335. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  336. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  337. strategy2 = ((2, 1, 1, 4),)
  338. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  339. with pytest.raises(RuntimeError):
  340. compile_net(net)
  341. def test_kernel_size_larger_than_stride_and_left_pad_is_0():
  342. """
  343. Feature: same mode, kernel_size > stride and left pad is 0, single direction exchange
  344. Description: split w
  345. Expectation: compile success
  346. """
  347. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  348. strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
  349. strategy2 = ((1, 1, 1, 8),)
  350. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  351. compile_net(net)
  352. def test_conv2d_kernel_size_larger_than_stride_and_split_nchw():
  353. """
  354. Feature: same mode, stride < kernel_size, need exchange
  355. Description: split n/c-in/c-out/h/w
  356. Expectation: compile success
  357. """
  358. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
  359. strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
  360. strategy2 = ((2, 2, 2, 2),)
  361. net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  362. compile_net(net)