You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_reduce_method_info.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from mindspore import context
  16. import mindspore.nn as nn
  17. from mindspore.ops import operations as P
  18. from mindspore import Tensor
  19. from tests.ut.python.ops.test_math_ops import VirtualLoss
  20. import mindspore as ms
  21. from mindspore.common.api import _executor
  22. from mindspore.ops import composite as C
  23. class NetWithLoss(nn.Cell):
  24. def __init__(self, network):
  25. super(NetWithLoss, self).__init__()
  26. self.loss = VirtualLoss()
  27. self.network = network
  28. def construct(self, x, y, b):
  29. predict = self.network(x, y, b)
  30. return self.loss(predict)
  31. class GradWrap(nn.Cell):
  32. def __init__(self, network):
  33. super(GradWrap, self).__init__()
  34. self.network = network
  35. def construct(self, x, y, b):
  36. return C.grad_all(self.network)(x, y, b)
def compile(net, x, y, b):
    """Mark *net* for auto-parallel and compile it with the inputs (x, y, b)."""
    # NOTE(review): the name shadows the builtin `compile`; kept as-is because
    # every test in this file calls it by this name.
    net.set_auto_parallel()
    _executor.compile(net, x, y, b)
  40. # model_parallel test
  41. def test_sum_mul():
  42. class Net(nn.Cell):
  43. def __init__(self, strategy1, strategy2, strategy3):
  44. super().__init__()
  45. self.mul1 = P.Mul().set_strategy(strategy1)
  46. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  47. self.mul2 = P.Mul().set_strategy(strategy3)
  48. def construct(self, x, y, b):
  49. out = self.mul1(x, y)
  50. out = self.reduce_sum(out, (1,))
  51. out = self.mul2(out, b)
  52. return out
  53. context.set_auto_parallel_context(device_num=8, global_rank=0)
  54. strategy1 = ((1, 1, 8), (1, 1, 8))
  55. strategy2 = ((4, 1, 2), )
  56. strategy3 = ((2, 4), (2, 4))
  57. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  58. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  59. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  60. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  61. b = Tensor(np.ones([128, 64]), dtype=ms.float32)
  62. compile(net, x, y, b)
  63. def test_sum_mul2():
  64. class Net(nn.Cell):
  65. def __init__(self, strategy1, strategy2, strategy3):
  66. super().__init__()
  67. self.mul1 = P.Mul().set_strategy(strategy1)
  68. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  69. self.mul2 = P.Mul().set_strategy(strategy3)
  70. def construct(self, x, y, b):
  71. out = self.mul1(x, y)
  72. out = self.reduce_sum(out, (0, 1))
  73. out = self.mul2(out, b)
  74. return out
  75. context.set_auto_parallel_context(device_num=8, global_rank=0)
  76. strategy1 = ((1, 1, 4, 2), (1, 1, 4, 2))
  77. strategy2 = ((2, 4, 1, 1), )
  78. strategy3 = ((2, 4), (2, 4))
  79. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  80. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  81. x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
  82. y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
  83. b = Tensor(np.ones([64, 64]), dtype=ms.float32)
  84. compile(net, x, y, b)
  85. def test_sum_mul3():
  86. class Net(nn.Cell):
  87. def __init__(self, strategy1, strategy2, strategy3):
  88. super().__init__()
  89. self.mul1 = P.Mul().set_strategy(strategy1)
  90. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  91. self.mul2 = P.Mul().set_strategy(strategy3)
  92. def construct(self, x, y, b):
  93. out = self.mul1(x, y)
  94. out = self.reduce_sum(out, -1)
  95. out = self.mul2(out, b)
  96. return out
  97. context.set_auto_parallel_context(device_num=8, global_rank=0)
  98. strategy1 = ((1, 4, 2), (1, 4, 2))
  99. strategy2 = ((4, 2, 1), )
  100. strategy3 = ((2, 4), (2, 4))
  101. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  103. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  104. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  105. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  106. compile(net, x, y, b)
  107. def test_sum_mul4():
  108. class Net(nn.Cell):
  109. def __init__(self, strategy1, strategy2, strategy3):
  110. super().__init__()
  111. self.mul1 = P.Mul().set_strategy(strategy1)
  112. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  113. self.mul2 = P.Mul().set_strategy(strategy3)
  114. def construct(self, x, y, b):
  115. out = self.mul1(x, y)
  116. out = self.reduce_sum(out, -1)
  117. out = self.mul2(out, b)
  118. return out
  119. context.set_auto_parallel_context(device_num=8, global_rank=0)
  120. strategy1 = ((1, 4, 2), (1, 4, 2))
  121. strategy2 = ((2, 2, 2), )
  122. strategy3 = ((4, 2, 1), (4, 2, 1))
  123. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  124. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  125. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  126. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  127. b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
  128. compile(net, x, y, b)
  129. def test_sum_mul5():
  130. class Net(nn.Cell):
  131. def __init__(self, strategy1, strategy2):
  132. super().__init__()
  133. self.mul1 = P.Mul().set_strategy(strategy1)
  134. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  135. def construct(self, x, y, b):
  136. out = self.mul1(x, y)
  137. out = self.reduce_sum(out, 0)
  138. return out
  139. context.set_auto_parallel_context(device_num=64, global_rank=0)
  140. strategy1 = ((1, 8, 8), (1, 8, 8))
  141. strategy2 = ((2, 4, 1), )
  142. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  143. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  144. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  145. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  146. b = Tensor(np.ones([1, 32, 64]), dtype=ms.float32)
  147. compile(net, x, y, b)
  148. def test_sum_mul6():
  149. class Net(nn.Cell):
  150. def __init__(self, strategy1, strategy2):
  151. super().__init__()
  152. self.mul1 = P.Mul().set_strategy(strategy1)
  153. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  154. def construct(self, x, y, b):
  155. out = self.mul1(x, y)
  156. out = self.reduce_sum(out, 1)
  157. return out
  158. context.set_auto_parallel_context(device_num=64, global_rank=0)
  159. strategy1 = ((1, 8, 8), (1, 8, 8))
  160. strategy2 = ((2, 1, 4), )
  161. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  162. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  163. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  164. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  165. b = Tensor(np.ones([128, 1, 64]), dtype=ms.float32)
  166. compile(net, x, y, b)
  167. def test_sum_mul7():
  168. class Net(nn.Cell):
  169. def __init__(self, strategy1, strategy2):
  170. super().__init__()
  171. self.mul1 = P.Mul().set_strategy(strategy1)
  172. self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy2)
  173. def construct(self, x, y, b):
  174. out = self.mul1(x, y)
  175. out = self.reduce_sum(out, (0, 1))
  176. return out
  177. context.set_auto_parallel_context(device_num=64, global_rank=0)
  178. strategy1 = ((1, 8, 8), (1, 8, 8))
  179. strategy2 = ((2, 4, 1), )
  180. net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
  181. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  182. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  183. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  184. b = Tensor(np.ones([1, 64]), dtype=ms.float32)
  185. compile(net, x, y, b)
  186. def test_max_mul():
  187. class Net(nn.Cell):
  188. def __init__(self, strategy1, strategy2, strategy3):
  189. super().__init__()
  190. self.mul1 = P.Mul().set_strategy(strategy1)
  191. self.reduce_max = P.ReduceMax(keep_dims=False).set_strategy(strategy2)
  192. self.mul2 = P.Mul().set_strategy(strategy3)
  193. def construct(self, x, y, b):
  194. out = self.mul1(x, y)
  195. out = self.reduce_max(out, -1)
  196. out = self.mul2(out, b)
  197. return out
  198. context.set_auto_parallel_context(device_num=8, global_rank=0)
  199. strategy1 = ((1, 4, 2), (1, 4, 2))
  200. strategy2 = ((4, 1, 2), )
  201. strategy3 = ((2, 4), (2, 4))
  202. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  203. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  204. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  205. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  206. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  207. compile(net, x, y, b)
  208. def test_min_mul():
  209. class Net(nn.Cell):
  210. def __init__(self, strategy1, strategy2, strategy3):
  211. super().__init__()
  212. self.mul1 = P.Mul().set_strategy(strategy1)
  213. self.reduce_min = P.ReduceMin(keep_dims=False).set_strategy(strategy2)
  214. self.mul2 = P.Mul().set_strategy(strategy3)
  215. def construct(self, x, y, b):
  216. out = self.mul1(x, y)
  217. out = self.reduce_min(out, 0)
  218. out = self.mul2(out, b)
  219. return out
  220. context.set_auto_parallel_context(device_num=8, global_rank=0)
  221. strategy1 = ((1, 4, 2), (1, 4, 2))
  222. strategy2 = ((4, 1, 2), )
  223. strategy3 = ((2, 4), (2, 4))
  224. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  225. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  226. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  227. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  228. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  229. compile(net, x, y, b)
  230. def test_reduce_mean_mul_float32():
  231. class Net(nn.Cell):
  232. def __init__(self, strategy1, strategy2, strategy3):
  233. super().__init__()
  234. self.mul1 = P.Mul().set_strategy(strategy1)
  235. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy2)
  236. self.mul2 = P.Mul().set_strategy(strategy3)
  237. def construct(self, x, y, b):
  238. out = self.mul1(x, y)
  239. out = self.reduce_mean(out, 0)
  240. out = self.mul2(out, b)
  241. return out
  242. context.set_auto_parallel_context(device_num=8, global_rank=0)
  243. strategy1 = ((1, 4, 2), (1, 4, 2))
  244. strategy2 = ((4, 1, 2), )
  245. strategy3 = ((2, 4), (2, 4))
  246. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  247. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  248. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  249. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  250. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  251. compile(net, x, y, b)
  252. class ArgMaxWithValueNet(nn.Cell):
  253. def __init__(self, strategy1, strategy2, strategy3):
  254. super().__init__()
  255. self.mul1 = P.Mul().set_strategy(strategy1)
  256. self.arg_max_with_value = P.ArgMaxWithValue(keep_dims=False, axis=-1).set_strategy(strategy2)
  257. self.mul2 = P.Mul().set_strategy(strategy3)
  258. def construct(self, x, y, b):
  259. out = self.mul1(x, y)
  260. index, out = self.arg_max_with_value(out)
  261. out = self.mul2(out, b)
  262. return out
  263. class ArgMinWithValueNet(nn.Cell):
  264. def __init__(self, strategy1, strategy2, strategy3):
  265. super().__init__()
  266. self.mul1 = P.Mul().set_strategy(strategy1)
  267. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=False, axis=-1).set_strategy(strategy2)
  268. self.mul2 = P.Mul().set_strategy(strategy3)
  269. def construct(self, x, y, b):
  270. out = self.mul1(x, y)
  271. index, out = self.arg_min_with_value(out)
  272. out = self.mul2(out, b)
  273. return out
  274. def gen_inputs_and_compile(net):
  275. x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  276. y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  277. b = Tensor(np.ones([128, 64]), dtype=ms.float32)
  278. compile(net, x, y, b)
  279. def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():
  280. context.set_auto_parallel_context(device_num=8, global_rank=0)
  281. strategy1 = ((1, 4, 2), (1, 4, 2))
  282. strategy2 = ((4, 1, 2), )
  283. strategy3 = ((2, 4), (2, 4))
  284. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  285. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  286. gen_inputs_and_compile(net)
  287. def test_arg_max_with_value_mul_semi():
  288. context.set_auto_parallel_context(device_num=8, global_rank=0)
  289. strategy1 = ((1, 4, 2), (1, 4, 2))
  290. strategy2 = ((4, 1, 1), )
  291. strategy3 = ((2, 4), (2, 4))
  292. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  293. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  294. gen_inputs_and_compile(net)
  295. def test_arg_max_with_value_mul_auto():
  296. context.set_auto_parallel_context(device_num=8, global_rank=0)
  297. strategy1 = None
  298. strategy2 = None
  299. strategy3 = None
  300. net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
  301. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  302. gen_inputs_and_compile(net)
  303. def test_arg_min_with_value_mul_semi_axis_parallel():
  304. context.set_auto_parallel_context(device_num=8, global_rank=0)
  305. strategy1 = ((1, 4, 2), (1, 4, 2))
  306. strategy2 = ((4, 1, 2), )
  307. strategy3 = ((2, 4), (2, 4))
  308. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  309. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  310. gen_inputs_and_compile(net)
  311. def test_arg_min_with_value_mul_semi():
  312. context.set_auto_parallel_context(device_num=8, global_rank=0)
  313. strategy1 = ((1, 4, 2), (1, 4, 2))
  314. strategy2 = ((4, 1, 1), )
  315. strategy3 = ((2, 4), (2, 4))
  316. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  317. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  318. gen_inputs_and_compile(net)
  319. def test_arg_min_with_value_mul_auto():
  320. context.set_auto_parallel_context(device_num=8, global_rank=0)
  321. strategy1 = None
  322. strategy2 = None
  323. strategy3 = None
  324. net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
  325. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  326. gen_inputs_and_compile(net)
  327. class ArgMinWithValueNet2(nn.Cell):
  328. def __init__(self, strategy1, strategy2, strategy3):
  329. super().__init__()
  330. self.mul1 = P.Mul().set_strategy(strategy1)
  331. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=True, axis=-1).set_strategy(strategy2)
  332. self.relu = P.ReLU().set_strategy(strategy3)
  333. def construct(self, x, y, b):
  334. out = self.mul1(x, y)
  335. index, out = self.arg_min_with_value(out)
  336. out = self.relu(out)
  337. return out
  338. def tobefixed_test_arg_min_with_value_mul_semi_axis_parallel2():
  339. context.set_auto_parallel_context(device_num=8, global_rank=0)
  340. strategy1 = ((1, 4, 2), (1, 4, 2))
  341. strategy2 = ((4, 1, 2), )
  342. strategy3 = ((2, 4, 1), )
  343. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  344. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  345. gen_inputs_and_compile(net)
  346. def test_arg_min_with_value_mul_semi2():
  347. context.set_auto_parallel_context(device_num=8, global_rank=0)
  348. strategy1 = ((1, 4, 2), (1, 4, 2))
  349. strategy2 = ((4, 1, 1), )
  350. strategy3 = ((2, 4, 1), )
  351. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  352. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  353. gen_inputs_and_compile(net)
  354. def test_arg_min_with_value_mul_auto2():
  355. context.set_auto_parallel_context(device_num=8, global_rank=0)
  356. strategy1 = None
  357. strategy2 = None
  358. strategy3 = None
  359. net = GradWrap(NetWithLoss(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
  360. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  361. gen_inputs_and_compile(net)
  362. def test_cross_batch():
  363. class Net(nn.Cell):
  364. def __init__(self, strategy1, strategy2, strategy3):
  365. super().__init__()
  366. self.mul1 = P.Mul().set_strategy(strategy1)
  367. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy2)
  368. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy3).add_prim_attr("cross_batch", True)
  369. def construct(self, x, y, b):
  370. out = self.mul1(x, y)
  371. out = self.reduce_sum(out, -1)
  372. out = self.reduce_mean(out, 0)
  373. return out
  374. context.set_auto_parallel_context(device_num=8, global_rank=0)
  375. strategy1 = ((4, 2), (4, 2))
  376. strategy2 = ((2, 1), )
  377. strategy3 = ((8, ), )
  378. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  379. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  380. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  381. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  382. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  383. compile(net, x, y, b)
  384. def test_cross_batch2():
  385. class Net(nn.Cell):
  386. def __init__(self, strategy1, strategy2, strategy3):
  387. super().__init__()
  388. self.mul1 = P.Mul().set_strategy(strategy1)
  389. self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy2)
  390. self.reduce_sum = P.ReduceSum(keep_dims=False).set_strategy(strategy3).add_prim_attr("cross_batch", True)
  391. def construct(self, x, y, b):
  392. out = self.mul1(x, y)
  393. out = self.reduce_mean(out, -1)
  394. out = self.reduce_sum(out, 0)
  395. return out
  396. context.set_auto_parallel_context(device_num=8, global_rank=0)
  397. strategy1 = ((4, 2), (4, 2))
  398. strategy2 = ((2, 1), )
  399. strategy3 = ((8, ), )
  400. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  401. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  402. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  403. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  404. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  405. compile(net, x, y, b)
  406. def test_cross_batch_auto():
  407. class Net(nn.Cell):
  408. def __init__(self):
  409. super().__init__()
  410. self.mul1 = P.Mul()
  411. self.reduce_mean = P.ReduceMean(keep_dims=False)
  412. self.reduce_sum = P.ReduceSum(keep_dims=False).add_prim_attr("cross_batch", True)
  413. def construct(self, x, y, b):
  414. out = self.mul1(x, y)
  415. out = self.reduce_mean(out, -1)
  416. out = self.reduce_sum(out, 0)
  417. return out
  418. context.set_auto_parallel_context(device_num=8, global_rank=0)
  419. net = GradWrap(NetWithLoss(Net()))
  420. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  421. x = Tensor(np.ones([32, 64]), dtype=ms.float32)
  422. y = Tensor(np.ones([32, 64]), dtype=ms.float32)
  423. b = Tensor(np.ones([32, 64]), dtype=ms.float32)
  424. compile(net, x, y, b)
  425. def test_max_empty_tuple():
  426. class Net(nn.Cell):
  427. def __init__(self, strategy1, strategy2, strategy3):
  428. super().__init__()
  429. self.mul = P.Mul().set_strategy(strategy1)
  430. self.reduce_max = P.ReduceMax(keep_dims=False).set_strategy(strategy2)
  431. self.add = P.TensorAdd().set_strategy(strategy3)
  432. def construct(self, x, y, b):
  433. out = self.mul(x, y)
  434. out = self.reduce_max(out)
  435. out = self.add(out, b)
  436. return out
  437. context.set_auto_parallel_context(device_num=8, global_rank=0)
  438. strategy1 = ((1, 4, 2), (1, 4, 2))
  439. strategy2 = ((4, 1, 2), )
  440. strategy3 = ((), (1, 1))
  441. net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
  442. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  443. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  444. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  445. b = Tensor(np.ones([128, 32]), dtype=ms.float32)
  446. compile(net, x, y, b)