You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_pipeline_split.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import shutil
  17. import glob
  18. import numpy as np
  19. import mindspore as ms
  20. import mindspore.nn as nn
  21. from mindspore import context
  22. from mindspore import Tensor
  23. from mindspore.ops import operations as P
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.common.initializer import initializer
  26. from mindspore.train.model import Model
  27. from mindspore.nn.wrap.cell_wrapper import PipelineCell, MicroBatchInterleaved
  28. class DatasetLenet():
  29. def __init__(self, data, label, length=3):
  30. self.data = data
  31. self.label = label
  32. self.index = 1
  33. self.length = length
  34. def __iter__(self):
  35. return self
  36. def __next__(self):
  37. if self.index >= self.length:
  38. raise StopIteration
  39. self.index += 1
  40. return self.data, self.label
  41. def reset(self):
  42. self.index = 0
  43. def get_dataset_size(self):
  44. return 32
  45. def get_repeat_count(self):
  46. return 1
  47. def get_batch_size(self):
  48. return 32
  49. def create_tuple_iterator(self, num_epochs=1, do_copy=True):
  50. return self
class MatMulCell(nn.Cell):
    """Two chained sharded MatMuls over 64x64 zero-initialized weights.

    Args:
        strategy1: shard strategy applied to the first MatMul.
        strategy2: shard strategy applied to the second MatMul.
        param: optional externally created Parameter; when given it replaces
            the locally created ``param`` weight (used to share one weight
            across pipeline stages).
        dtype: compute dtype; inputs and weights are cast to it in construct.
    """

    def __init__(self, strategy1, strategy2, param=None, dtype=ms.float32):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        # Overwrite with the shared parameter if one was supplied.
        if param is not None:
            self.param = param
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)
        self.cast = P.Cast()
        self.dtype = dtype

    def construct(self, x):
        # Cast both operands to the configured dtype before each matmul.
        out = self.matmul(self.cast(x, self.dtype), self.cast(self.param, self.dtype))
        out = self.matmul1(out, self.cast(self.param1, self.dtype))
        return out
class Net(nn.Cell):
    """Sequential stack of two MatMulCells, with cell i assigned to pipeline stage i."""

    def __init__(self, strategy1, strategy2, param=None, dtype=ms.float32):
        super().__init__()
        self.block = nn.CellList()
        for i in range(2):
            cell = MatMulCell(strategy1, strategy2, param, dtype)
            # Map each sub-cell onto its own pipeline stage.
            cell.pipeline_stage = i
            self.block.append(cell)

    def construct(self, x):
        # Feed the output of each stage into the next.
        for i in range(2):
            x = self.block[i](x)
        return x
class PipelineSplit(nn.Cell):
    """Pipeline test net wrapping Net; ``label`` is accepted but unused in construct."""

    def __init__(self, strategy1, strategy2, dtype=ms.float32):
        super().__init__()
        self.cell = Net(strategy1, strategy2, dtype=dtype)
        # Mark the first stage's matmul as the parameter start point.
        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)

    def construct(self, x, label):
        x = self.cell(x)
        return x
class PipelineSplit2(nn.Cell):
    """Like PipelineSplit, but shares one ``param`` weight across both stages of Net."""

    def __init__(self, strategy1, strategy2, dtype=ms.float32):
        super().__init__()
        # Single shared parameter handed to every MatMulCell in Net.
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.cell = Net(strategy1, strategy2, self.param, dtype)
        self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)

    def construct(self, x, label):
        x = self.cell(x)
        return x
def test_pipeline_split_stage0():
    """
    Feature: pipeline parallel split.
    Description: 2-stage pipeline net on 32 devices, global rank 0 (stage 0),
        trained without dataset sinking.
    Expectation: success, and no stage-1 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    # Optimize only the parameters of the stage owned by this rank.
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-1 parameters must have been filtered out on stage 0.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"
def test_pipeline_split_stage1():
    """
    Feature: pipeline parallel split.
    Description: 2-stage pipeline net on 32 devices, global rank 16 (stage 1),
        trained without dataset sinking.
    Expectation: success, and no stage-0 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    # Optimize only the parameters of the stage owned by this rank.
    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-0 parameters must have been filtered out on stage 1.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"
def test_pipeline_split_shared_parameter_stage0():
    """
    Feature: pipeline parallel split with a parameter shared across stages.
    Description: PipelineSplit2 (shared ``param``) on 32 devices, rank 0 (stage 0).
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage1():
    """
    Feature: pipeline parallel split with a parameter shared across stages.
    Description: PipelineSplit2 (shared ``param``) on 32 devices, rank 16 (stage 1).
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage0_predict():
    """
    Feature: pipeline parallel predict with a shared parameter.
    Description: PipelineSplit2 run through Model.predict on rank 0 (stage 0)
        with full_batch enabled; no PipelineCell wrapper is used for inference.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2, full_batch=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineSplit2(strategy1, strategy2)
    model = Model(net)
    model.predict(data, label)
def test_pipeline_split_shared_parameter_stage1_predict():
    """
    Feature: pipeline parallel predict with a shared parameter.
    Description: PipelineSplit2 run through Model.predict on rank 16 (stage 1)
        with full_batch enabled; no PipelineCell wrapper is used for inference.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2, full_batch=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineSplit2(strategy1, strategy2)
    model = Model(net)
    model.predict(data, label)
def test_pipeline_split_stage0_opt_shard():
    """
    Feature: pipeline parallel split with the parallel optimizer enabled.
    Description: 2-stage pipeline net on 32 devices, rank 0 (stage 0),
        with enable_parallel_optimizer=True.
    Expectation: success, and no stage-1 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-1 parameters must have been filtered out on stage 0.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"
def test_pipeline_split_stage1_opt_shard():
    """
    Feature: pipeline parallel split with the parallel optimizer enabled.
    Description: 2-stage pipeline net on 32 devices, rank 16 (stage 1),
        with enable_parallel_optimizer=True.
    Expectation: success, and no stage-0 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-0 parameters must have been filtered out on stage 1.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"
def test_pipeline_split_shared_parameter_stage0_opt_shard():
    """
    Feature: pipeline parallel split, shared parameter, parallel optimizer on.
    Description: PipelineSplit2 on 32 devices, rank 0 (stage 0),
        with enable_parallel_optimizer=True.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_stage1_opt_shard():
    """
    Feature: pipeline parallel split, shared parameter, parallel optimizer on.
    Description: PipelineSplit2 on 32 devices, rank 16 (stage 1),
        with enable_parallel_optimizer=True.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_with_micro_batch_interleaved_stage0():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel,
        rank 0 (stage 0), interleave factor 2.
    Expectation: success, and no stage-1 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    # Note the extra ``.network`` hop introduced by the MicroBatchInterleaved wrapper.
    params = net.network.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"
def test_pipeline_split_with_micro_batch_interleaved_stage1():
    """
    Feature: test PipelineSplit with MicroBatchInterleaved in auto parallel.
    Description: net with MicroBatchInterleaved in semi auto parallel,
        rank 16 (stage 1), interleave factor 2.
    Expectation: success, and no stage-0 parameter appears in the train network.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit(strategy1, strategy2), micro_batch_interleaved), 4)
    # Note the extra ``.network`` hop introduced by the MicroBatchInterleaved wrapper.
    params = net.network.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"
def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage0_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: PipelineSplit2 with MicroBatchInterleaved in semi auto parallel,
        rank 0 (stage 0), interleave factor 2, parallel optimizer enabled.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit2(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def test_pipeline_split_shared_parameter_with_micro_batch_interleaved_stage1_opt_shard():
    """
    Feature: test PipelineSplitSharedParameter with MicroBatchInterleaved in auto parallel.
    Description: PipelineSplit2 with MicroBatchInterleaved in semi auto parallel,
        rank 16 (stage 1), interleave factor 2, parallel optimizer enabled.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=32, global_rank=16, pipeline_stages=2, enable_parallel_optimizer=True)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((8, 1), (1, 1))
    micro_batch_interleaved = 2
    net = PipelineCell(MicroBatchInterleaved(PipelineSplit2(strategy1, strategy2), micro_batch_interleaved), 4)
    params = net.network.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
def run_pipeline_split_function(pipeline_net, micro_batch_interleaved=1):
    """Train ``pipeline_net`` wrapped in MicroBatchInterleaved and PipelineCell.

    Shared driver for the IR-inspection tests below: builds fixed ones-tensor
    data, takes the parameters belonging to this rank's pipeline stage via
    infer_param_pipeline_stage(), and runs two epochs without dataset sinking.

    :param pipeline_net: the pipeline-split network under test.
    :param micro_batch_interleaved: interleave factor for MicroBatchInterleaved.
    """
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)
    net = PipelineCell(MicroBatchInterleaved(pipeline_net, micro_batch_interleaved), 4)
    params = net.infer_param_pipeline_stage()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
class TestPipelineSplitWithNoOptimizer:
    """Pipeline-split tests that save IR graphs and count patterns in them
    (parallel optimizer disabled)."""

    def setup_method(self):
        # Per-test graph dump directory so cat_fp16_from_ir can scan it.
        self.output_path = './graphs' + self.__str__()
        context.set_context(save_graphs=True,
                            save_graphs_path=self.output_path)

    def teardown_method(self):
        # Remove the dumped graphs after each test.
        shutil.rmtree(self.output_path)

    def cat_fp16_from_ir(self, pattern, target_count):
        """
        This function will check the float16 count with the golden one.
        :param pattern: The match pattern for the specific count
        :param target_count: The golden float16 count in the IR files
        """
        ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', '*_validate*.ir'))
        # Exactly one validate-phase IR file is expected per run.
        assert len(ir_files) == 1
        appear_count = 0
        with open(ir_files[0], 'r') as fp:
            for line in fp:
                if pattern in line:
                    appear_count += 1
        assert appear_count == target_count

    def test_pipeline_with_no_parallel_optimizer_and_micro(self):
        """
        Feature: Test Pipeline with Mirror Operator.
        Description: When using fp16 computation, there should be only one mirror operator for one parameter.
        Expectation: the number of the float16 tensor is not equal to 16, 16 is obtained by manually checked graph.
        the number of the Mirror is not equal to 2, 2 is obtained by manually checked graph.
        """
        context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2,
                                          enable_parallel_optimizer=False)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        strategy1 = ((16, 1), (1, 1))
        strategy2 = ((8, 1), (1, 1))
        pipeline_net = PipelineSplit(strategy1, strategy2, dtype=ms.float16)
        run_pipeline_split_function(pipeline_net, micro_batch_interleaved=1)
        # Golden counts obtained from a manually checked graph dump.
        self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator',
                              target_count=2)
        self.cat_fp16_from_ir(pattern='Cast(',
                              target_count=15)

    def test_pipeline_with_micro_batch_no_parallel_optimizer(self):
        """
        Feature: Test Pipeline with Mirror Operator, when enabled the micro batch interleave.
        Description: When using fp16 computation, there should be only one mirror operator for one parameter.
        Expectation: the number of the float16 tensor is not equal to 16, 16 is obtained by manually checked graph.
        the number of the Mirror is not equal to 2, 2 is obtained by manually checked graph.
        """
        context.set_auto_parallel_context(device_num=32, global_rank=0, pipeline_stages=2,
                                          enable_parallel_optimizer=False)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        strategy1 = ((16, 1), (1, 1))
        strategy2 = ((8, 1), (1, 1))
        pipeline_net = PipelineSplit(strategy1, strategy2, dtype=ms.float16)
        run_pipeline_split_function(pipeline_net, micro_batch_interleaved=2)
        # Golden counts obtained from a manually checked graph dump.
        self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator',
                              target_count=2)
        self.cat_fp16_from_ir(pattern='Cast(',
                              target_count=27)
def test_pipeline_split_stage0_device_num_48():
    """
    Feature: test PipelineSplit with 48 devices in auto parallel.
    Description: net with pipeline parallel in auto parallel mode using 48 devices
        (a non-power-of-two device count), stage0.
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=48, global_rank=0, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_context(device_target="Ascend")
    data = Tensor(np.ones([32 * 6, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64 * 6, 64]), dtype=ms.float32)
    # Strategies tailored to 24 devices per stage (3*8 and 24*1).
    strategy1 = ((3, 8), (8, 1))
    strategy2 = ((24, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.network.cell.block[0].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-1 parameters must have been filtered out on stage 0.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.1.param"
        assert param.name != "cell.block.1.param1"
def test_pipeline_split_stage1_device_num_48():
    """
    Feature: test PipelineSplit with 48 devices in auto parallel.
    Description: net with pipeline parallel in auto parallel mode using 48 devices
        (a non-power-of-two device count), stage1 (global rank 24).
    Expectation: success.
    """
    context.set_auto_parallel_context(device_num=48, global_rank=24, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_context(device_target="Ascend")
    data = Tensor(np.ones([32 * 6, 64]), dtype=ms.float32)
    label = Tensor(np.ones([64 * 6, 64]), dtype=ms.float32)
    # Strategies tailored to 24 devices per stage (3*8 and 24*1).
    strategy1 = ((3, 8), (8, 1))
    strategy2 = ((24, 1), (1, 1))
    net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
    params = net.network.cell.block[1].trainable_params()
    dataset = DatasetLenet(data, label, 3)
    optimizer = nn.Lamb(params, learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)
    # Stage-0 parameters must have been filtered out on stage 1.
    for _, param in model._train_network.parameters_and_names():
        assert param.name != "cell.block.0.param"
        assert param.name != "cell.block.0.param1"