You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_reduce_method_info.py 38 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. '''Reduce method ut'''
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore.common.api import _cell_graph_executor
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import operations as P
  23. from tests.ut.python.ops.test_math_ops import VirtualLoss
  24. grad_all = C.GradOperation(get_all=True)
  25. class NetWithLossNoBias(nn.Cell):
  26. def __init__(self, network):
  27. super(NetWithLossNoBias, self).__init__()
  28. self.loss = VirtualLoss()
  29. self.network = network
  30. def construct(self, x, y):
  31. predict = self.network(x, y)
  32. return self.loss(predict)
  33. class NetWithLoss(nn.Cell):
  34. def __init__(self, network):
  35. super(NetWithLoss, self).__init__()
  36. self.loss = VirtualLoss()
  37. self.network = network
  38. def construct(self, x, y, b):
  39. predict = self.network(x, y, b)
  40. return self.loss(predict)
  41. class GradWrapNoBias(nn.Cell):
  42. def __init__(self, network):
  43. super(GradWrapNoBias, self).__init__()
  44. self.network = network
  45. def construct(self, x, y):
  46. return grad_all(self.network)(x, y)
  47. class GradWrap(nn.Cell):
  48. def __init__(self, network):
  49. super(GradWrap, self).__init__()
  50. self.network = network
  51. def construct(self, x, y, b):
  52. return grad_all(self.network)(x, y, b)
def compile_net_no_bias(net, x, y):
    """Compile *net* in training mode with two inputs (no bias tensor)."""
    net.set_auto_parallel()
    net.set_train()
    # Graph compilation is the assertion: a bad shard strategy raises here.
    _cell_graph_executor.compile(net, x, y)
def compile_net(net, x, y, b):
    """Compile *net* in training mode with three inputs (x, y and a bias)."""
    net.set_auto_parallel()
    net.set_train()
    # Graph compilation is the assertion: a bad shard strategy raises here.
    _cell_graph_executor.compile(net, x, y, b)
  61. # model_parallel test
def test_sum_mul():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the non-reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Sum over axis 1: (128, 32, 64) -> (128, 64), matching b's shape.
            out = self.reduce_sum(out, (1,))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 8), (1, 1, 8))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_sum_mul2():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Sum over axes (0, 1): (128, 128, 64, 64) -> (64, 64), matching b.
            out = self.reduce_sum(out, (0, 1))
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 1, 4, 2), (1, 1, 4, 2))
    strategy2 = ((2, 4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 128, 64, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_sum_mul3():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the non-reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Sum over the last axis: (128, 32, 64) -> (128, 32), matching b.
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 2, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_sum_mul4():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the reduced axes, keep_dims is True
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # keep_dims=True retains the last axis: (128, 32, 64) -> (128, 32, 1).
            out = self.reduce_sum(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((2, 2, 2),)
    strategy3 = ((4, 2, 1), (4, 2, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32, 1]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_sum_mul5():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the reduced axes, keep_dims is True
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            # Reduce axis 0 with keep_dims: (128, 32, 64) -> (1, 32, 64).
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_sum_mul6():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the non-reduced axes, keep_dims is True
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            # Reduce axis 1 with keep_dims: (128, 32, 64) -> (128, 1, 64).
            out = self.reduce_sum(out, 1)
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 1, 4),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_sum_mul7():
    """
    Feature: test ReduceSum model parallel strategy
    Description: partition the reduced axes, keep_dims is True
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=True).shard(strategy2)

        def construct(self, x, y):
            out = self.mul1(x, y)
            # Reduce axes (0, 1) with keep_dims: (128, 32, 64) -> (1, 1, 64).
            out = self.reduce_sum(out, (0, 1))
            return out

    context.set_auto_parallel_context(device_num=64, global_rank=0)
    strategy1 = ((1, 8, 8), (1, 8, 8))
    strategy2 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_max_mul():
    """
    Feature: test ReduceMax model parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Max over the last axis: (128, 32, 64) -> (128, 32), matching b.
            out = self.reduce_max(out, -1)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_min_mul():
    """
    Feature: test ReduceMin model parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_min = P.ReduceMin(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Min over axis 0: (128, 32, 64) -> (32, 64), matching b.
            out = self.reduce_min(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
def test_reduce_mean_mul_float32():
    """
    Feature: test ReduceMean model parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            self.mul2 = P.Mul().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul1(x, y)
            # Mean over axis 0: (128, 32, 64) -> (32, 64), matching b.
            out = self.reduce_mean(out, 0)
            out = self.mul2(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)
  320. class ArgMaxWithValueNet(nn.Cell):
  321. def __init__(self, strategy1, strategy2, strategy3):
  322. super(ArgMaxWithValueNet, self).__init__()
  323. self.mul1 = P.Mul().shard(strategy1)
  324. self.arg_max_with_value = P.ArgMaxWithValue(keep_dims=False, axis=-1).shard(strategy2)
  325. self.mul2 = P.Mul().shard(strategy3)
  326. def construct(self, x, y, b):
  327. out = self.mul1(x, y)
  328. _, out = self.arg_max_with_value(out)
  329. out = self.mul2(out, b)
  330. return out
  331. class ArgMinWithValueNet(nn.Cell):
  332. def __init__(self, strategy1, strategy2, strategy3):
  333. super(ArgMinWithValueNet, self).__init__()
  334. self.mul1 = P.Mul().shard(strategy1)
  335. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=False, axis=-1).shard(strategy2)
  336. self.mul2 = P.Mul().shard(strategy3)
  337. def construct(self, x, y, b):
  338. out = self.mul1(x, y)
  339. _, out = self.arg_min_with_value(out)
  340. out = self.mul2(out, b)
  341. return out
  342. class ArgMaxNet(nn.Cell):
  343. def __init__(self, strategy1, strategy2):
  344. super(ArgMaxNet, self).__init__()
  345. self.mul1 = P.Mul().shard(strategy1)
  346. self.arg_max = P.Argmax(axis=-1).shard(strategy2)
  347. def construct(self, x, y):
  348. out = self.mul1(x, y)
  349. out = self.arg_max(out)
  350. return out
  351. class ArgMinNet(nn.Cell):
  352. def __init__(self, strategy1, strategy2):
  353. super(ArgMinNet, self).__init__()
  354. self.mul1 = P.Mul().shard(strategy1)
  355. self.arg_min = P.Argmin(axis=-1).shard(strategy2)
  356. def construct(self, x, y):
  357. out = self.mul1(x, y)
  358. out = self.arg_min(out)
  359. return out
  360. def gen_inputs_and_compile_net(net):
  361. x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  362. y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  363. b = Tensor(np.ones([128, 64]), dtype=ms.float32)
  364. compile_net(net, x, y, b)
  365. def gen_inputs_and_compile_net_no_bias(net):
  366. x = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  367. y = Tensor(np.ones([128, 64, 64]), dtype=ms.float32)
  368. compile_net_no_bias(net, x, y)
def tobefixed_test_arg_max_with_value_mul_semi_axis_parallel():
    """
    Feature: test ArgMaxWithValue semi parallel strategy, reduced axis partitioned.

    NOTE(review): the 'tobefixed_' prefix keeps pytest from collecting this
    test; rename it back to 'test_...' once the underlying issue is fixed.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_max_with_value_mul_semi():
    """
    Feature: test ArgMaxWithValue semi parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # The reduced (last) axis is left unsplit here, unlike the axis-parallel case.
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_max_with_value_mul_auto():
    """
    Feature: test ArgMaxWithValue auto parallel strategy
    Description: don't set the strategy, keep_dims is False
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # No explicit strategies: auto_parallel must derive them itself.
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMaxWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_min_with_value_mul_semi_axis_parallel():
    """
    Feature: test ArgMinWithValue semi parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # Last-axis split of 2 partitions the reduced axis.
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_min_with_value_mul_semi():
    """
    Feature: test ArgMinWithValue model parallel strategy
    Description: partition the non-reduced axes, keep_dims is False
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # The reduced (last) axis is left unsplit here.
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4), (2, 4))
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_min_with_value_mul_auto():
    """
    Feature: test ArgMinWithValue auto parallel strategy
    Description: don't set the strategy, keep_dims is False
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # No explicit strategies: auto_parallel must derive them itself.
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrap(NetWithLoss(ArgMinWithValueNet(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net(net)
def test_arg_max_semi_axis_parallel():
    """
    Feature: test Argmax semi parallel strategy
    Description: partition the reduced axes
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # Last-axis split of 2 partitions the reduced axis.
    strategy2 = ((4, 1, 2),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_max_mul_semi():
    """
    Feature: test Argmax model parallel strategy
    Description: partition the non-reduced axes
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # Only the first two (non-reduced) axes are split.
    strategy2 = ((4, 2, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_max_mul_auto():
    """
    Feature: test Argmax auto parallel strategy
    Description: don't set the strategy
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # No explicit strategies: auto_parallel must derive them itself.
    strategy1 = None
    strategy2 = None
    net = GradWrapNoBias(NetWithLossNoBias(ArgMaxNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_semi_axis_parallel():
    """
    Feature: test Argmin semi parallel strategy
    Description: partition the reduced axes
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # Last-axis split of 2 partitions the reduced axis.
    strategy2 = ((4, 1, 2),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_mul_semi():
    """
    Feature: test Argmin model parallel strategy
    Description: partition the non-reduced axes
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # Only the first two (non-reduced) axes are split.
    strategy2 = ((4, 2, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_mul_auto():
    """
    Feature: test Argmin auto parallel strategy
    Description: don't set the strategy
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # No explicit strategies: auto_parallel must derive them itself.
    strategy1 = None
    strategy2 = None
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinNet(strategy1, strategy2)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
  514. class ArgMinWithValueNet2(nn.Cell):
  515. def __init__(self, strategy1, strategy2, strategy3):
  516. super(ArgMinWithValueNet2, self).__init__()
  517. self.mul1 = P.Mul().shard(strategy1)
  518. self.arg_min_with_value = P.ArgMinWithValue(keep_dims=True, axis=-1).shard(strategy2)
  519. self.relu = P.ReLU().shard(strategy3)
  520. def construct(self, x, y):
  521. out = self.mul1(x, y)
  522. _, out = self.arg_min_with_value(out)
  523. out = self.relu(out)
  524. return out
def tobefixed_test_arg_min_with_value_mul_semi_axis_parallel2():
    """
    Feature: test ArgMinWithValue semi parallel strategy, reduced axis partitioned,
    keep_dims is True.

    NOTE(review): the 'tobefixed_' prefix keeps pytest from collecting this
    test; rename it back to 'test_...' once the underlying issue is fixed.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_with_value_mul_semi2():
    """
    Feature: test ArgMinWithValue semi parallel strategy
    Description: partition the non-reduced axes, keep_dims is True
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    # The reduced (last) axis is left unsplit here.
    strategy2 = ((4, 1, 1),)
    strategy3 = ((2, 4, 1),)
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_arg_min_with_value_mul_auto2():
    """
    Feature: test ArgMinWithValue auto parallel strategy
    Description: don't set the strategy, keep_dims is True
    Expectation: compile success
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    # No explicit strategies: auto_parallel must derive them itself.
    strategy1 = None
    strategy2 = None
    strategy3 = None
    net = GradWrapNoBias(NetWithLossNoBias(ArgMinWithValueNet2(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    gen_inputs_and_compile_net_no_bias(net)
def test_cross_batch():
    """
    Feature: test ReduceMean semi parallel strategy with cross_batch
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy2)
            # "cross_batch" marks the mean so it is reduced across the batch dim.
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy3) \
                .add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            # (32, 64) -> sum over last axis -> (32,) -> mean over axis 0 -> scalar.
            out = self.reduce_sum(out, -1)
            out = self.reduce_mean(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_cross_batch2():
    """
    Feature: test ReduceSum semi parallel strategy with cross_batch
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul1 = P.Mul().shard(strategy1)
            self.reduce_mean = P.ReduceMean(keep_dims=False).shard(strategy2)
            # "cross_batch" marks the sum so it is reduced across the batch dim.
            self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy3) \
                .add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            # (32, 64) -> mean over last axis -> (32,) -> sum over axis 0 -> scalar.
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((4, 2), (4, 2))
    strategy2 = ((2, 1),)
    strategy3 = ((8,),)
    net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_cross_batch_auto():
    """
    Feature: test ReduceSum auto parallel strategy with cross_batch
    Description: don't set the strategy, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.mul1 = P.Mul()
            self.reduce_mean = P.ReduceMean(keep_dims=False)
            # No shard strategies at all; only the cross_batch attribute is set.
            self.reduce_sum = P.ReduceSum(keep_dims=False).add_prim_attr("cross_batch", True)

        def construct(self, x, y):
            out = self.mul1(x, y)
            out = self.reduce_mean(out, -1)
            out = self.reduce_sum(out, 0)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = GradWrapNoBias(NetWithLossNoBias(Net()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    x = Tensor(np.ones([32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    compile_net_no_bias(net, x, y)
def test_max_empty_tuple():
    """
    Feature: test ReduceMax semi parallel strategy
    Description: partition the reduced axes, keep_dims is False
    Expectation: compile success
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super(Net, self).__init__()
            self.mul = P.Mul().shard(strategy1)
            self.reduce_max = P.ReduceMax(keep_dims=False).shard(strategy2)
            self.add = P.Add().shard(strategy3)

        def construct(self, x, y, b):
            out = self.mul(x, y)
            # No axis argument: the default empty tuple reduces over all axes,
            # which is why strategy3's first slot is the empty tuple ().
            out = self.reduce_max(out)
            out = self.add(out, b)
            return out

    context.set_auto_parallel_context(device_num=8, global_rank=0)
    strategy1 = ((1, 4, 2), (1, 4, 2))
    strategy2 = ((4, 1, 2),)
    strategy3 = ((), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
    compile_net(net, x, y, b)
  663. def test_any_mul():
  664. """
  665. Feature: test ReduceAny semi parallel strategy
  666. Description: partition the reduced axes, keep_dims is False
  667. Expectation: compile success
  668. """
  669. class Net(nn.Cell):
  670. def __init__(self, strategy1, strategy2):
  671. super(Net, self).__init__()
  672. self.mul1 = P.Mul().shard(strategy1)
  673. self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
  674. self.cast = P.Cast()
  675. def construct(self, x, y):
  676. out = self.mul1(x, y)
  677. out = self.cast(out, ms.bool_)
  678. out = self.reduce_any(out, 1)
  679. return out
  680. context.set_auto_parallel_context(device_num=64, global_rank=0)
  681. strategy1 = ((1, 8, 1), (1, 8, 1))
  682. strategy2 = ((1, 8, 1),)
  683. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  684. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  685. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  686. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  687. compile_net_no_bias(net, x, y)
  688. def test_any_mul2():
  689. """
  690. Feature: test ReduceAny semi parallel strategy
  691. Description: partition the non-reduced axes, keep_dims is False
  692. Expectation: compile success
  693. """
  694. class Net(nn.Cell):
  695. def __init__(self, strategy1, strategy2):
  696. super(Net, self).__init__()
  697. self.mul1 = P.Mul().shard(strategy1)
  698. self.reduce_any = P.ReduceAny(keep_dims=False).shard(strategy2)
  699. self.cast = P.Cast()
  700. def construct(self, x, y):
  701. out = self.mul1(x, y)
  702. out = self.cast(out, ms.bool_)
  703. out = self.reduce_any(out, -1)
  704. return out
  705. context.set_auto_parallel_context(device_num=64, global_rank=0)
  706. strategy1 = ((8, 1, 1), (8, 1, 1))
  707. strategy2 = ((8, 1, 1),)
  708. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  709. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  710. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  711. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  712. compile_net_no_bias(net, x, y)
  713. def test_all_mul():
  714. """
  715. Feature: test ReduceAll semi parallel strategy
  716. Description: partition the reduced axes, keep_dims is False
  717. Expectation: compile success
  718. """
  719. class Net(nn.Cell):
  720. def __init__(self, strategy1, strategy2):
  721. super(Net, self).__init__()
  722. self.mul1 = P.Mul().shard(strategy1)
  723. self.reduce_all = P.ReduceAll(keep_dims=False).shard(strategy2)
  724. self.cast = P.Cast()
  725. def construct(self, x, y):
  726. out = self.mul1(x, y)
  727. out = self.cast(out, ms.bool_)
  728. out = self.reduce_all(out, 1)
  729. return out
  730. context.set_auto_parallel_context(device_num=8, global_rank=0)
  731. strategy1 = ((1, 8, 1), (1, 8, 1))
  732. strategy2 = ((1, 8, 1),)
  733. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  734. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  735. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  736. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  737. compile_net_no_bias(net, x, y)
  738. def test_all_mul2():
  739. """
  740. Feature: test ReduceAll semi parallel strategy
  741. Description: partition the non-reduced axes, keep_dims is False
  742. Expectation: compile success
  743. """
  744. class Net(nn.Cell):
  745. def __init__(self, strategy1, strategy2):
  746. super(Net, self).__init__()
  747. self.mul1 = P.Mul().shard(strategy1)
  748. self.reduce_all = P.ReduceAll(keep_dims=False).shard(strategy2)
  749. self.cast = P.Cast()
  750. def construct(self, x, y):
  751. out = self.mul1(x, y)
  752. out = self.cast(out, ms.bool_)
  753. out = self.reduce_all(out, -1)
  754. return out
  755. context.set_auto_parallel_context(device_num=8, global_rank=0)
  756. strategy1 = ((8, 1, 1), (8, 1, 1))
  757. strategy2 = ((8, 1, 1),)
  758. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  759. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  760. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  761. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  762. compile_net_no_bias(net, x, y)
  763. def test_prod_mul():
  764. """
  765. Feature: test ReduceProd model parallel strategy
  766. Description: partition the reduced axes, keep_dims is False
  767. Expectation: compile success
  768. """
  769. class Net(nn.Cell):
  770. def __init__(self, strategy1, strategy2):
  771. super(Net, self).__init__()
  772. self.mul1 = P.Mul().shard(strategy1)
  773. self.reduce_prod = P.ReduceProd(keep_dims=False).shard(strategy2)
  774. def construct(self, x, y):
  775. out = self.mul1(x, y)
  776. out = self.reduce_prod(out, 0)
  777. return out
  778. context.set_auto_parallel_context(device_num=8, global_rank=0)
  779. strategy1 = ((1, 1, 8), (1, 1, 8))
  780. strategy2 = ((2, 4, 1),)
  781. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  782. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  783. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  784. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  785. compile_net_no_bias(net, x, y)
  786. def test_prod_mul2():
  787. """
  788. Feature: test ReduceProd model parallel strategy
  789. Description: partition the non-reduced axes, keep_dims is False
  790. Expectation: compile success
  791. """
  792. class Net(nn.Cell):
  793. def __init__(self, strategy1, strategy2):
  794. super(Net, self).__init__()
  795. self.mul1 = P.Mul().shard(strategy1)
  796. self.reduce_prod = P.ReduceProd(keep_dims=False).shard(strategy2)
  797. def construct(self, x, y):
  798. out = self.mul1(x, y)
  799. out = self.reduce_prod(out, -1)
  800. return out
  801. context.set_auto_parallel_context(device_num=8, global_rank=0)
  802. strategy1 = ((1, 8, 1), (1, 8, 1))
  803. strategy2 = ((2, 4, 1),)
  804. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  805. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  806. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  807. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  808. compile_net_no_bias(net, x, y)
  809. def test_prod_mul3():
  810. """
  811. Feature: test ReduceProd model parallel strategy
  812. Description: partition the reduced axes, keep_dims is True
  813. Expectation: compile success
  814. """
  815. class Net(nn.Cell):
  816. def __init__(self, stra_mul, stra_prod):
  817. super(Net, self).__init__()
  818. self.mul = P.Mul().shard(stra_mul)
  819. self.reduce_prod = P.ReduceProd(keep_dims=True).shard(stra_prod)
  820. def construct(self, x, y):
  821. out = self.mul(x, y)
  822. out = self.reduce_prod(out, 0)
  823. return out
  824. context.set_auto_parallel_context(device_num=8, global_rank=0)
  825. strategy1 = ((1, 1, 8), (1, 1, 8))
  826. strategy2 = ((8, 1, 1),)
  827. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  828. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  829. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  830. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  831. compile_net_no_bias(net, x, y)
  832. def test_prod_mul_auto():
  833. """
  834. Feature: test ReduceProd auto parallel strategy
  835. Description: don't set the strategy, keep_dims is True
  836. Expectation: compile success
  837. """
  838. class Net(nn.Cell):
  839. def __init__(self, strategy1, strategy2):
  840. super(Net, self).__init__()
  841. self.mul1 = P.Mul().shard(strategy1)
  842. self.reduce_prod = P.ReduceProd(keep_dims=True).shard(strategy2)
  843. def construct(self, x, y):
  844. out = self.mul1(x, y)
  845. out = self.reduce_prod(out, 0)
  846. return out
  847. context.set_auto_parallel_context(device_num=8, global_rank=0)
  848. strategy1 = None
  849. strategy2 = None
  850. net = GradWrapNoBias(NetWithLossNoBias(Net(strategy1, strategy2)))
  851. context.set_auto_parallel_context(parallel_mode="auto_parallel")
  852. gen_inputs_and_compile_net_no_bias(net)
  853. def test_square_sum_all_mul():
  854. """
  855. Feature: test SquareSumAll model parallel strategy
  856. Description: partition the reduced axes
  857. Expectation: compile success
  858. """
  859. class Net(nn.Cell):
  860. def __init__(self, strategy1, strategy2):
  861. super(Net, self).__init__()
  862. self.mul1 = P.Mul().shard(strategy1)
  863. self.square_sum_all = P.SquareSumAll().shard(strategy2)
  864. def construct(self, x, y):
  865. out = self.mul1(x, y)
  866. out = self.square_sum_all(out, out)
  867. return out
  868. context.set_auto_parallel_context(device_num=8, global_rank=0)
  869. strategy1 = ((1, 1, 8), (1, 1, 8))
  870. strategy2 = ((2, 4, 1), (2, 4, 1))
  871. net = Net(strategy1, strategy2)
  872. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  873. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  874. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  875. compile_net_no_bias(net, x, y)
  876. def test_square_sum_all_mul2():
  877. """
  878. Feature: test SquareSumAll model parallel strategy
  879. Description: partition the reduced axes
  880. Expectation: compile success
  881. """
  882. class Net(nn.Cell):
  883. def __init__(self, stra_mul, stra_prod):
  884. super(Net, self).__init__()
  885. self.mul = P.Mul().shard(stra_mul)
  886. self.square_sum_all = P.SquareSumAll().shard(stra_prod)
  887. def construct(self, x, y):
  888. out = self.mul(x, y)
  889. out = self.square_sum_all(out, out)
  890. return out
  891. context.set_auto_parallel_context(device_num=8, global_rank=0)
  892. strategy1 = ((1, 1, 8), (1, 1, 8))
  893. strategy2 = ((8, 1, 1), (8, 1, 1))
  894. net = Net(strategy1, strategy2)
  895. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  896. x = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  897. y = Tensor(np.ones([128, 32, 64]), dtype=ms.float32)
  898. compile_net_no_bias(net, x, y)