You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_reduce_method_info.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. import mindspore.nn as nn
  17. from mindspore.ops import operations as P
  18. from mindspore import Tensor
  19. from tests.ut.python.ops.test_math_ops import VirtualLoss
  20. import mindspore as ms
  21. from mindspore.common.api import _executor
  22. from mindspore.ops import composite as C
  23. class NetWithLoss(nn.Cell):
  24. def __init__(self, network):
  25. super(NetWithLoss, self).__init__()
  26. self.loss = VirtualLoss()
  27. self.network = network
  28. def construct(self, x, y, b):
  29. predict = self.network(x, y, b)
  30. return self.loss(predict)
  31. class GradWrap(nn.Cell):
  32. def __init__(self, network):
  33. super(GradWrap, self).__init__()
  34. self.network = network
  35. def construct(self, x, y, b):
  36. return C.grad_all(self.network)(x, y, b)
  37. # model_parallel test
  38. def test_sum_mul():
  39. class Net(nn.Cell):
  40. def __init__(self, strategy1, strategy2, strategy3):
  41. super().__init__()
  42. self.mul1 = P.Mul().set_strategy(strategy1)
  43. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  44. self.mul2 = P.Mul().set_strategy(strategy3)
  45. def construct(self, x, y, b):
  46. out = self.mul1(x, y)
  47. out = self.reduce_sum(out, (1,))
  48. out = self.mul2(out, b)
  49. return out
  50. context.set_auto_parallel_context(device_num=8, global_rank=0)
  51. strategy1 = ((1, 1, 8), (1, 1, 8))
  52. strategy2 = ((4, 1, 2), )
  53. strategy3 = ((2, 4), (2, 4))
  54. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  55. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  56. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  57. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  58. b = Tensor(np.ones([128, 64]), dtype=ms.float32)
  59. _executor.compile(net, x, y, b)
  60. def test_sum_mul2():
  61. class Net(nn.Cell):
  62. def __init__(self, strategy1, strategy2, strategy3):
  63. super().__init__()
  64. self.mul1 = P.Mul().set_strategy(strategy1)
  65. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  66. self.mul2 = P.Mul().set_strategy(strategy3)
  67. def construct(self, x, y, b):
  68. out = self.mul1(x, y)
  69. out = self.reduce_sum(out, (0, 1))
  70. out = self.mul2(out, b)
  71. return out
  72. context.set_auto_parallel_context(device_num=8, global_rank=0)
  73. strategy1 = ((1, 1, 4, 2), (1, 1, 4, 2))
  74. strategy2 = ((2, 4, 1, 1), )
  75. strategy3 = ((2, 4), (2, 4))
  76. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  77. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  78. x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
  79. y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
  80. b = Tensor(np.ones([64, 64]), dtype=ms.float32)
  81. _executor.compile(net, x, y, b)
  82. def test_sum_mul3():
  83. class Net(nn.Cell):
  84. def __init__(self, strategy1, strategy2, strategy3):
  85. super().__init__()
  86. self.mul1 = P.Mul().set_strategy(strategy1)
  87. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  88. self.mul2 = P.Mul().set_strategy(strategy3)
  89. def construct(self, x, y, b):
  90. out = self.mul1(x, y)
  91. out = self.reduce_sum(out, -1)
  92. out = self.mul2(out, b)
  93. return out
  94. context.set_auto_parallel_context(device_num=8, global_rank=0)
  95. strategy1 = ((1, 4, 2), (1, 4, 2))
  96. strategy2 = ((4, 2, 1), )
  97. strategy3 = ((2, 4), (2, 4))
  98. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  99. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  100. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  101. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  102. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  103. _executor.compile(net, x, y, b)
  104. def test_sum_mul4():
  105. class Net(nn.Cell):
  106. def __init__(self, strategy1, strategy2, strategy3):
  107. super().__init__()
  108. self.mul1 = P.Mul().set_strategy(strategy1)
  109. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  110. self.mul2 = P.Mul().set_strategy(strategy3)
  111. def construct(self, x, y, b):
  112. out = self.mul1(x, y)
  113. out = self.reduce_sum(out, -1)
  114. out = self.mul2(out, b)
  115. return out
  116. context.set_auto_parallel_context(device_num=8, global_rank=0)
  117. strategy1 = ((1, 4, 2), (1, 4, 2))
  118. strategy2 = ((2, 2, 2), )
  119. strategy3 = ((4, 2, 1), (4, 2, 1))
  120. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  121. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  122. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  123. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  124. b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
  125. _executor.compile(net, x, y, b)
  126. def test_sum_mul5():
  127. class Net(nn.Cell):
  128. def __init__(self, strategy1, strategy2):
  129. super().__init__()
  130. self.mul1 = P.Mul().set_strategy(strategy1)
  131. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  132. def construct(self, x, y, b):
  133. out = self.mul1(x, y)
  134. out = self.reduce_sum(out, 0)
  135. return out
  136. context.set_auto_parallel_context(device_num=64, global_rank=0)
  137. strategy1 = ((1, 8, 8), (1, 8, 8))
  138. strategy2 = ((2, 4, 1), )
  139. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  140. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  141. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  142. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  143. b = Tensor(np.ones([1, 32, 64]), dtype=ms.float32)
  144. _executor.compile(net, x, y, b)
  145. def test_sum_mul6():
  146. class Net(nn.Cell):
  147. def __init__(self, strategy1, strategy2):
  148. super().__init__()
  149. self.mul1 = P.Mul().set_strategy(strategy1)
  150. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  151. def construct(self, x, y, b):
  152. out = self.mul1(x, y)
  153. out = self.reduce_sum(out, 1)
  154. return out
  155. context.set_auto_parallel_context(device_num=64, global_rank=0)
  156. strategy1 = ((1, 8, 8), (1, 8, 8))
  157. strategy2 = ((2, 1, 4), )
  158. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  159. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  160. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  161. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  162. b = Tensor(np.ones([128, 1, 64]), dtype=ms.float32)
  163. _executor.compile(net, x, y, b)
  164. def test_sum_mul7():
  165. class Net(nn.Cell):
  166. def __init__(self, strategy1, strategy2):
  167. super().__init__()
  168. self.mul1 = P.Mul().set_strategy(strategy1)
  169. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  170. def construct(self, x, y, b):
  171. out = self.mul1(x, y)
  172. out = self.reduce_sum(out, (0, 1))
  173. return out
  174. context.set_auto_parallel_context(device_num=64, global_rank=0)
  175. strategy1 = ((1, 8, 8), (1, 8, 8))
  176. strategy2 = ((2, 4, 1), )
  177. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  178. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  179. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  180. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  181. b = Tensor(np.ones([1, 64]), dtype=ms.float32)
  182. _executor.compile(net, x, y, b)
  183. def test_max_mul():
  184. class Net(nn.Cell):
  185. def __init__(self, strategy1, strategy2, strategy3):
  186. super().__init__()
  187. self.mul1 = P.Mul().set_strategy(strategy1)
  188. self.reduce_max = P.ReduceMax(keep_dims=False).set_strategy(strategy2)
  189. self.mul2 = P.Mul().set_strategy(strategy3)
  190. def construct(self, x, y, b):
  191. out = self.mul1(x, y)
  192. out = self.reduce_max(out, -1)
  193. out = self.mul2(out, b)
  194. return out
  195. context.set_auto_parallel_context(device_num=8, global_rank=0)
  196. strategy1 = ((1, 4, 2), (1, 4, 2))
  197. strategy2 = ((4, 1, 2), )
  198. strategy3 = ((2, 4), (2, 4))
  199. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  200. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  201. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  202. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  203. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  204. _executor.compile(net, x, y, b)
  205. def test_min_mul():
  206. class Net(nn.Cell):
  207. def __init__(self, strategy1, strategy2, strategy3):
  208. super().__init__()
  209. self.mul1 = P.Mul().set_strategy(strategy1)
  210. self.reduce_min = P.ReduceMin(keep_dims=False).set_strategy(strategy2)
  211. self.mul2 = P.Mul().set_strategy(strategy3)
  212. def construct(self, x, y, b):
  213. out = self.mul1(x, y)
  214. out = self.reduce_min(out, 0)
  215. out = self.mul2(out, b)
  216. return out
  217. context.set_auto_parallel_context(device_num=8, global_rank=0)
  218. strategy1 = ((1, 4, 2), (1, 4, 2))
  219. strategy2 = ((4, 1, 2), )
  220. strategy3 = ((2, 4), (2, 4))
  221. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  222. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  223. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  224. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  225. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  226. _executor.compile(net, x, y, b)
  227. def test_reduce_mean_mul_float32():
  228. class Net(nn.Cell):
  229. def __init__(self, strategy1, strategy2, strategy3):
  230. super().__init__()
  231. self.mul1 = P.Mul().set_strategy(strategy1)
  232. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy2)
  233. self.mul2 = P.Mul().set_strategy(strategy3)
  234. def construct(self, x, y, b):
  235. out = self.mul1(x, y)
  236. out = self.reduce_mean(out, 0)
  237. out = self.mul2(out, b)
  238. return out
  239. context.set_auto_parallel_context(device_num=8, global_rank=0)
  240. strategy1 = ((1, 4, 2), (1, 4, 2))
  241. strategy2 = ((4, 1, 2), )
  242. strategy3 = ((2, 4), (2, 4))
  243. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  244. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  245. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  246. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  247. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  248. _executor.compile(net, x, y, b)
  249. class ArgMaxWithValueNet(nn.Cell):
  250. def __init__(self, strategy1, strategy2, strategy3):
  251. super().__init__()
  252. self.mul1 = P.Mul().set_strategy(strategy1)
  253. self.arg_max_with_value = P.ArgMaxWithValue(keep_dims=False, axis=-1).set_strategy(strategy2)
  254. self.mul2 = P.Mul().set_strategy(strategy3)
  255. def construct(self, x, y, b):
  256. out = self.mul1(x, y)
  257. index, out = self.arg_max_with_value(out)
  258. out = self.mul2(out, b)
  259. return out
  260. class ArgMinWithValueNet(nn.Cell):
  261. def __init__(self, strategy1, strategy2, strategy3):
  262. super().__init__()
  263. self.mul1 = P.Mul().set_strategy(strategy1)
  264. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=False, axis=-1).set_strategy(strategy2)
  265. self.mul2 = P.Mul().set_strategy(strategy3)
  266. def construct(self, x, y, b):
  267. out = self.mul1(x, y)
  268. index, out = self.arg_min_with_value(out)
  269. out = self.mul2(out, b)
  270. return out
  271. def gen_inputs_and_compile(net):
  272. x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  273. y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  274. b = Tensor(np.ones([128, 64]), dtype=ms.float32)
  275. _executor.compile(net, x, y, b)
  276. def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():
  277. context.set_auto_parallel_context(device_num=8, global_rank=0)
  278. strategy1 = ((1, 4, 2), (1, 4, 2))
  279. strategy2 = ((4, 1, 2), )
  280. strategy3 = ((2, 4), (2, 4))
  281. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  282. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  283. gen_inputs_and_compile(net)
  284. def test_arg_max_with_value_mul_semi():
  285. context.set_auto_parallel_context(device_num=8, global_rank=0)
  286. strategy1 = ((1, 4, 2), (1, 4, 2))
  287. strategy2 = ((4, 1, 1), )
  288. strategy3 = ((2, 4), (2, 4))
  289. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  290. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  291. gen_inputs_and_compile(net)
  292. def test_arg_max_with_value_mul_auto():
  293. context.set_auto_parallel_context(device_num=8, global_rank=0)
  294. strategy1 = None
  295. strategy2 = None
  296. strategy3 = None
  297. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  298. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  299. gen_inputs_and_compile(net)
  300. def test_arg_min_with_value_mul_semi_axis_parallel():
  301. context.set_auto_parallel_context(device_num=8, global_rank=0)
  302. strategy1 = ((1, 4, 2), (1, 4, 2))
  303. strategy2 = ((4, 1, 2), )
  304. strategy3 = ((2, 4), (2, 4))
  305. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  306. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  307. gen_inputs_and_compile(net)
  308. def test_arg_min_with_value_mul_semi():
  309. context.set_auto_parallel_context(device_num=8, global_rank=0)
  310. strategy1 = ((1, 4, 2), (1, 4, 2))
  311. strategy2 = ((4, 1, 1), )
  312. strategy3 = ((2, 4), (2, 4))
  313. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  314. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  315. gen_inputs_and_compile(net)
  316. def test_arg_min_with_value_mul_auto():
  317. context.set_auto_parallel_context(device_num=8, global_rank=0)
  318. strategy1 = None
  319. strategy2 = None
  320. strategy3 = None
  321. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  322. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  323. gen_inputs_and_compile(net)
  324. class ArgMinWithValueNet2(nn.Cell):
  325. def __init__(self, strategy1, strategy2, strategy3):
  326. super().__init__()
  327. self.mul1 = P.Mul().set_strategy(strategy1)
  328. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=True, axis=-1).set_strategy(strategy2)
  329. self.relu = P.ReLU().set_strategy(strategy3)
  330. def construct(self, x, y, b):
  331. out = self.mul1(x, y)
  332. index, out = self.arg_min_with_value(out)
  333. out = self.relu(out)
  334. return out
  335. def tobefixed_test_arg_min_with_value_mul_semi_axis_parallel2():
  336. context.set_auto_parallel_context(device_num=8, global_rank=0)
  337. strategy1 = ((1, 4, 2), (1, 4, 2))
  338. strategy2 = ((4, 1, 2), )
  339. strategy3 = ((2, 4, 1), )
  340. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  341. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  342. gen_inputs_and_compile(net)
  343. def test_arg_min_with_value_mul_semi2():
  344. context.set_auto_parallel_context(device_num=8, global_rank=0)
  345. strategy1 = ((1, 4, 2), (1, 4, 2))
  346. strategy2 = ((4, 1, 1), )
  347. strategy3 = ((2, 4, 1), )
  348. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  349. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  350. gen_inputs_and_compile(net)
  351. def test_arg_min_with_value_mul_auto2():
  352. context.set_auto_parallel_context(device_num=8, global_rank=0)
  353. strategy1 = None
  354. strategy2 = None
  355. strategy3 = None
  356. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  357. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  358. gen_inputs_and_compile(net)
  359. def test_cross_batch():
  360. class Net(nn.Cell):
  361. def __init__(self, strategy1, strategy2, strategy3):
  362. super().__init__()
  363. self.mul1 = P.Mul().set_strategy(strategy1)
  364. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  365. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy3).add_prim_attr("cross_batch", True)
  366. def construct(self, x, y, b):
  367. out = self.mul1(x, y)
  368. out = self.reduce_sum(out, -1)
  369. out = self.reduce_mean(out, 0)
  370. return out
  371. context.set_auto_parallel_context(device_num=8, global_rank=0)
  372. strategy1 = ((4, 2), (4, 2))
  373. strategy2 = ((2, 1), )
  374. strategy3 = ((8, ), )
  375. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  376. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  377. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  378. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  379. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  380. _executor.compile(net, x, y, b)
  381. def test_cross_batch2():
  382. class Net(nn.Cell):
  383. def __init__(self, strategy1, strategy2, strategy3):
  384. super().__init__()
  385. self.mul1 = P.Mul().set_strategy(strategy1)
  386. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy2)
  387. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy3).add_prim_attr("cross_batch", True)
  388. def construct(self, x, y, b):
  389. out = self.mul1(x, y)
  390. out = self.reduce_mean(out, -1)
  391. out = self.reduce_sum(out, 0)
  392. return out
  393. context.set_auto_parallel_context(device_num=8, global_rank=0)
  394. strategy1 = ((4, 2), (4, 2))
  395. strategy2 = ((2, 1), )
  396. strategy3 = ((8, ), )
  397. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  398. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  399. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  400. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  401. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  402. _executor.compile(net, x, y, b)
  403. def test_cross_batch_auto():
  404. class Net(nn.Cell):
  405. def __init__(self):
  406. super().__init__()
  407. self.mul1 = P.Mul()
  408. self.reduce_mean = P.ReduceMean(keep_dims=False)
  409. self.reduce_sum = P.ReduceSum(keep_dims=False).add_prim_attr("cross_batch", True)
  410. def construct(self, x, y, b):
  411. out = self.mul1(x, y)
  412. out = self.reduce_mean(out, -1)
  413. out = self.reduce_sum(out, 0)
  414. return out
  415. context.set_auto_parallel_context(device_num=8, global_rank=0)
  416. net = GradWrap(NetWithLoss(Net()))
  417. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  418. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  419. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  420. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  421. _executor.compile(net, x, y, b)
  422. def test_max_empty_tuple():
  423. class Net(nn.Cell):
  424. def __init__(self, strategy1, strategy2, strategy3):
  425. super().__init__()
  426. self.mul = P.Mul().set_strategy(strategy1)
  427. self.reduce_max = P.ReduceMax(keep_dims=False).set_strategy(strategy2)
  428. self.add = P.TensorAdd().set_strategy(strategy3)
  429. def construct(self, x, y, b):
  430. out = self.mul(x, y)
  431. out = self.reduce_max(out)
  432. out = self.add(out, b)
  433. return out
  434. context.set_auto_parallel_context(device_num=8, global_rank=0)
  435. strategy1 = ((1, 4, 2), (1, 4, 2))
  436. strategy2 = ((4, 1, 2), )
  437. strategy3 = ((), (1, 1))
  438. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  439. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  440. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  441. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  442. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  443. _executor.compile(net, x, y, b)