You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_conv2d_transpose.py 12 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore as ms
  17. from mindspore import context, Tensor, Parameter
  18. from mindspore.common.api import _cell_graph_executor
  19. from mindspore.nn import Cell, TrainOneStepCell, Momentum
  20. from mindspore.ops import operations as P
  21. class Net(Cell):
  22. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
  23. strategy1=None, strategy2=None):
  24. super().__init__()
  25. self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
  26. pad_mode=pad_mode, stride=stride).shard(strategy1)
  27. self.neg = P.Neg().shard(strategy2)
  28. self.weight = Parameter(conv2d_weight, "w1")
  29. self.add = P.Add()
  30. self.add_w = Parameter(Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32), "add_w")
  31. def construct(self, x, b):
  32. out = self.add(x, self.add_w)
  33. out = self.conv2d_transpose(out, self.weight, (32, 16, 8, 8))
  34. out = self.neg(out)
  35. return out
  36. class Net2(Cell):
  37. def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pad=0, group=1, dilation=1,
  38. strategy1=None, strategy2=None):
  39. super().__init__()
  40. self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode,
  41. stride=stride, pad=pad, group=group,
  42. dilation=dilation).shard(strategy1)
  43. self.neg = P.Neg().shard(strategy2)
  44. self.weight = Parameter(conv2d_weight, "w1")
  45. def construct(self, x, b):
  46. out = self.conv2d_transpose(x, self.weight, (32, 16, 16, 16))
  47. out = self.neg(out)
  48. return out
  49. _x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
  50. _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
  51. _w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
  52. _w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32)
  53. _w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
  54. _w5 = Tensor(np.ones([8, 8, 4, 4]), dtype=ms.float32)
  55. _w6 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
  56. _w7 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
  57. _w8 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
  58. _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
  59. def compile_net(net):
  60. optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
  61. train_net = TrainOneStepCell(net, optimizer)
  62. train_net.set_auto_parallel()
  63. train_net.set_train()
  64. _cell_graph_executor.compile(train_net, _x, _b)
  65. context.reset_auto_parallel_context()
  66. def test_conv2d_transpose_data_parallel():
  67. """
  68. Feature: test data parallel strategy
  69. Description: only shard batch dimension
  70. Expectation: compile success
  71. """
  72. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  73. strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
  74. strategy2 = ((8, 1, 1, 1),)
  75. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  76. compile_net(net)
  77. def test_conv2d_transpose_group():
  78. """
  79. Feature: test group is not 1
  80. Description: shard n/h/w, and group is 2
  81. Expectation: compile success
  82. """
  83. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  84. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  85. strategy2 = ((8, 1, 1, 1),)
  86. net = Net2(_w5, out_channel=8, kernel_size=4, pad_mode="same", stride=2, group=2, strategy1=strategy1,
  87. strategy2=strategy2)
  88. compile_net(net)
  89. def test_conv2d_transpose_model_parallel1():
  90. """
  91. Feature: test model parallel strategy
  92. Description: only shard batch dimension and channel dimension
  93. Expectation: compile success
  94. """
  95. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  96. strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
  97. strategy2 = ((8, 1, 1, 1),)
  98. net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
  99. compile_net(net)
  100. def test_conv2d_transpose_model_parallel2():
  101. """
  102. Feature: test model parallel strategy
  103. Description: shard batch dimension and w dimension
  104. Expectation: compile success
  105. """
  106. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  107. strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
  108. strategy2 = ((2, 1, 1, 4),)
  109. net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
  110. strategy1=strategy1, strategy2=strategy2)
  111. compile_net(net)
  112. def test_conv2d_transpose_model_parallel_dilation():
  113. """
  114. Feature: test model parallel strategy and dilation is 2
  115. Description: shard n/h/w
  116. Expectation: compile success
  117. """
  118. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  119. strategy1 = ((2, 1, 2, 2), (1, 1, 1, 1))
  120. strategy2 = ((2, 1, 2, 2),)
  121. net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2, dilation=2,
  122. strategy1=strategy1, strategy2=strategy2)
  123. compile_net(net)
  124. def test_conv2d_transpose_model_parallel3():
  125. """
  126. Feature: test model parallel strategy
  127. Description: shard batch dimension, channel dimension and w dimension
  128. Expectation: compile success
  129. """
  130. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  131. strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
  132. strategy2 = ((2, 2, 1, 4),)
  133. net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
  134. strategy1=strategy1, strategy2=strategy2)
  135. compile_net(net)
  136. def test_conv2d_transpose_model_parallel4():
  137. """
  138. Feature: test model parallel strategy
  139. Description: shard h dimension and w dimension
  140. Expectation: compile success
  141. """
  142. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  143. strategy1 = ((1, 1, 2, 4), (1, 1, 1, 1))
  144. strategy2 = ((2, 2, 1, 4),)
  145. net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
  146. strategy1=strategy1, strategy2=strategy2)
  147. compile_net(net)
  148. def test_conv2d_transpose_all_rank_no_need_overlap():
  149. """
  150. Feature: test model parallel strategy
  151. Description: shard batch dimension, channel dimension and w dimension
  152. Expectation: compile success
  153. """
  154. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  155. strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
  156. strategy2 = ((2, 2, 1, 4),)
  157. net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
  158. strategy1=strategy1, strategy2=strategy2)
  159. compile_net(net)
  160. def test_conv2d_transpose_split_h_or_w_in_pad_mode():
  161. """
  162. Feature: test pad mode
  163. Description: shard batch dimension, channel dimension and w dimension in pad mode
  164. Expectation: compile success
  165. """
  166. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  167. strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
  168. strategy2 = ((2, 2, 1, 4),)
  169. net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="pad", stride=2,
  170. strategy1=strategy1, strategy2=strategy2)
  171. compile_net(net)
  172. def test_conv2d_transpose_split_h_in_same_mode():
  173. """
  174. Feature: test split h dimension
  175. Description: shard h dimension in same mode
  176. Expectation: compile success
  177. """
  178. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
  179. strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
  180. strategy2 = ((2, 2, 4, 1),)
  181. net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
  182. strategy1=strategy1, strategy2=strategy2)
  183. compile_net(net)
  184. def test_conv2d_transpose_overlap_size_too_large():
  185. """
  186. Feature: test overlap size is too large
  187. Description: shard w dimension and overlap size larger than slice shape
  188. Expectation: compile failed
  189. """
  190. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
  191. strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
  192. strategy2 = ((1, 1, 1, 8),)
  193. net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
  194. strategy1=strategy1, strategy2=strategy2)
  195. with pytest.raises(RuntimeError):
  196. compile_net(net)
  197. def test_conv2d_transpose_pad_mode_no_need_exchange():
  198. """
  199. Feature: pad mode, and two direction send, w = 8, o = 16, s = 2, k = 1, n = 8, pad = (0, 0, 0, 0)
  200. Description: shard h and w dimension
  201. Expectation: compile success
  202. """
  203. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=13)
  204. strategy1 = ((1, 1, 8, 8), (1, 1, 1, 1))
  205. strategy2 = ((8, 1, 1, 1),)
  206. net = Net2(_w7, out_channel=8, kernel_size=1, pad_mode="pad", pad=(0, 0, 0, 0), stride=2, strategy1=strategy1,
  207. strategy2=strategy2)
  208. compile_net(net)
  209. def test_conv2d_transpose_pad_mode_two_direction_send_all_slice_pad_different():
  210. """
  211. Feature: pad mode, and two direction send, w = 8, o = 16, s = 2, k = 5, n = 8, pad = (1, 2, 1, 2)
  212. Description: shard h and w dimension
  213. Expectation: compile success
  214. """
  215. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=13)
  216. strategy1 = ((1, 1, 8, 8), (1, 1, 1, 1))
  217. strategy2 = ((8, 1, 1, 1),)
  218. net = Net2(_w6, out_channel=8, kernel_size=5, pad_mode="pad", pad=(1, 2, 1, 2), stride=2, strategy1=strategy1,
  219. strategy2=strategy2)
  220. compile_net(net)
  221. def test_conv2d_transpose_pad_mode_two_direction_send_all_slice():
  222. """
  223. Feature: pad mode, and two direction send, w = 8, o = 16, s = 2, k = 4, n = 8, pad = (1, 1, 1, 1)
  224. Description: shard h and w dimension
  225. Expectation: compile success
  226. """
  227. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=13)
  228. strategy1 = ((1, 1, 8, 8), (1, 1, 1, 1))
  229. strategy2 = ((8, 1, 1, 1),)
  230. net = Net2(_w8, out_channel=8, kernel_size=4, pad_mode="pad", pad=(1, 1, 1, 1), stride=2, strategy1=strategy1,
  231. strategy2=strategy2)
  232. compile_net(net)
  233. def test_conv2d_transpose_pad_mode_single_direction_send():
  234. """
  235. Feature: pad mode, and single direction send, w = 8, o = 16, s = 2, k = 3, n = 8, pad = (0, 1, 0, 1)
  236. Description: shard h and w dimension
  237. Expectation: compile success
  238. """
  239. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=64, global_rank=13)
  240. strategy1 = ((1, 1, 8, 8), (1, 1, 1, 1))
  241. strategy2 = ((8, 1, 1, 1),)
  242. net = Net2(_w4, out_channel=8, kernel_size=3, pad_mode="pad", pad=(0, 1, 0, 1), stride=2, strategy1=strategy1,
  243. strategy2=strategy2)
  244. compile_net(net)