You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_transformer.py 33 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shutil
  16. import glob
  17. import numpy as np
  18. import pytest
  19. import mindspore.common.dtype as mstype
  20. import mindspore.nn as nn
  21. from mindspore import Tensor
  22. from mindspore.context import set_auto_parallel_context, ParallelMode
  23. from mindspore import context
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import functional as F
  26. import mindspore.ops as P
  27. from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
  28. VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig, FixedSparseAttention
  29. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  30. from mindspore.nn.optim import AdamWeightDecay
  31. from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
  32. from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
  33. from mindspore.train import Model
  34. from mindspore.parallel import set_algo_parameters
  35. from tests.dataset_mock import MindData
  36. from tests.ut.python.ops.test_math_ops import VirtualLoss
# GradOperation instance created with get_all=True; it is not referenced anywhere else in the visible part of this file.
grad_all = C.GradOperation(get_all=True)
  38. class Dataset(MindData):
  39. def __init__(self, *inputs, length=3):
  40. super(Dataset, self).__init__(size=length)
  41. self.inputs = inputs
  42. self.index = 0
  43. self.length = length
  44. def __iter__(self):
  45. return self
  46. def __next__(self):
  47. if self.index >= self.length:
  48. raise StopIteration
  49. self.index += 1
  50. return self.inputs
  51. def reset(self):
  52. self.index = 0
  53. class TransformerNet(nn.Cell):
  54. def __init__(self, en_layer, de_layer, parallel_config):
  55. super(TransformerNet, self).__init__()
  56. self.network = Transformer(encoder_layers=en_layer,
  57. decoder_layers=de_layer,
  58. batch_size=2,
  59. src_seq_length=20,
  60. tgt_seq_length=10,
  61. hidden_size=64,
  62. num_heads=8,
  63. ffn_hidden_size=64,
  64. parallel_config=parallel_config)
  65. self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
  66. def construct(self, x1, x2, x3, x4, x5, y, mask):
  67. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  68. predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
  69. return self.loss(predict, y, mask)
  70. class TransformerEncoderNet(nn.Cell):
  71. def __init__(self, batch_size, en_layer, de_layer, parallel_config):
  72. super(TransformerEncoderNet, self).__init__()
  73. self.embedding = VocabEmbedding(vocab_size=240, embedding_size=64,
  74. parallel_config=parallel_config.embedding_dp_mp_config)
  75. self.network = Transformer(encoder_layers=en_layer,
  76. decoder_layers=de_layer,
  77. batch_size=batch_size,
  78. src_seq_length=20,
  79. tgt_seq_length=10,
  80. hidden_size=64,
  81. num_heads=8,
  82. ffn_hifloat16dden_size=64,
  83. parallel_config=parallel_config)
  84. self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
  85. def construct(self, x, encoder_mask, label, input_mask):
  86. embedded, _ = self.embedding(x)
  87. logits, _, = self.network(embedded, encoder_mask)
  88. logits = P.Reshape()(logits, (-1, F.shape(logits)[-1]))
  89. label = P.Reshape()(label, (-1,))
  90. input_mask = P.Reshape()(input_mask, (-1,))
  91. return self.loss(logits, label, input_mask)
# Module-level parallel configuration shared by the non-pipeline tests and by the loss heads above.
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
# Parallel configuration with four pipeline stages and four micro batches, used by the pipeline tests.
pipeline_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, pipeline_stage=4,
                                              micro_batch_num=4, vocab_emb_dp=False)
  95. class NetWithLossFiveInputs(nn.Cell):
  96. def __init__(self, network):
  97. super(NetWithLossFiveInputs, self).__init__()
  98. self.loss = VirtualLoss()
  99. self.network = network
  100. def construct(self, x1, x2, x3, x4, x5):
  101. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  102. return self.loss(predict)
  103. def run_network_function(dataset, pipeline_net):
  104. """
  105. Feature: Test transformer embedding shared.
  106. Description: a basic function for test compiling.
  107. Expectation: success.
  108. """
  109. params = pipeline_net.trainable_params()
  110. optimizer = nn.Lamb(params, learning_rate=0.01)
  111. model = Model(pipeline_net, optimizer=optimizer)
  112. model.train(2, dataset, dataset_sink_mode=False)
  113. def run_total_transformer_model_head(e_layer,
  114. d_layer,
  115. arg_parallel_config,
  116. mode=ParallelMode.SEMI_AUTO_PARALLEL):
  117. dp = arg_parallel_config.data_parallel
  118. mp = arg_parallel_config.model_parallel
  119. pp = arg_parallel_config.pipeline_stage
  120. if dp * mp * pp != 1:
  121. set_auto_parallel_context(device_num=8,
  122. full_batch=True,
  123. global_rank=0, parallel_mode=mode)
  124. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  125. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  126. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  127. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  128. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  129. seq = 20
  130. if d_layer > 0:
  131. seq = 10
  132. label = Tensor(np.ones((2 * seq,)), mstype.int32)
  133. input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
  134. net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
  135. net = _VirtualDatasetCell(net)
  136. params = net.trainable_params()
  137. optimizer = AdamWeightDecay(params)
  138. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  139. memory_mask, label, input_mask)
  140. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  141. model = Model(net_with_grad)
  142. model.train(1, dataset, dataset_sink_mode=False)
  143. def test_transformer_model():
  144. set_auto_parallel_context(device_num=8, global_rank=0,
  145. full_batch=True,
  146. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  147. net = Transformer(encoder_layers=1,
  148. decoder_layers=2,
  149. batch_size=2,
  150. src_seq_length=20,
  151. tgt_seq_length=10,
  152. hidden_size=64,
  153. num_heads=8,
  154. ffn_hidden_size=64,
  155. parallel_config=config)
  156. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  157. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  158. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  159. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  160. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  161. net = NetWithLossFiveInputs(net)
  162. net = _VirtualDatasetCell(net)
  163. params = net.trainable_params()
  164. optimizer = AdamWeightDecay(params)
  165. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  166. memory_mask)
  167. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  168. model = Model(net_with_grad)
  169. model.train(1, dataset, dataset_sink_mode=False)
  170. def test_transformer_model_2d_inputs():
  171. set_auto_parallel_context(device_num=8, global_rank=0,
  172. full_batch=True,
  173. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  174. net = Transformer(encoder_layers=1,
  175. decoder_layers=2,
  176. batch_size=2,
  177. src_seq_length=20,
  178. tgt_seq_length=10,
  179. hidden_size=64,
  180. num_heads=8,
  181. ffn_hidden_size=64,
  182. parallel_config=config)
  183. encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
  184. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  185. decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
  186. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  187. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  188. net = NetWithLossFiveInputs(net)
  189. net = _VirtualDatasetCell(net)
  190. params = net.trainable_params()
  191. optimizer = AdamWeightDecay(params)
  192. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  193. memory_mask)
  194. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  195. model = Model(net_with_grad)
  196. model.train(1, dataset, dataset_sink_mode=False)
  197. class TestTransformerEmbeddingHead:
  198. def __init__(self):
  199. self.output_path = None
  200. def setup_method(self):
  201. self.output_path = './graphs' + self.__str__()
  202. context.set_context(save_graphs=True,
  203. save_graphs_path=self.output_path)
  204. def teardown_method(self):
  205. shutil.rmtree(self.output_path)
  206. def virtual_assign_add_from_ir(self, pattern, target_count):
  207. """
  208. This function will check the assign aa count with the golden one.
  209. :param pattern: The match pattern for the specific count
  210. :param target_count: The gold float16 count in the Ir files
  211. """
  212. ir_files = glob.glob(os.path.join(self.output_path, 'rank_0', '*_validate*.ir'))
  213. assert len(ir_files) == 1
  214. appear_count = 0
  215. with open(ir_files[0], 'r') as fp:
  216. for line in fp:
  217. if pattern in line:
  218. appear_count += 1
  219. assert appear_count == target_count
  220. def test_pipeline_with_embedding(self):
  221. """
  222. Feature: Test Transformer with embedding as shared
  223. Description: When do pipeline training and applied optimzier shard, the embedding which is model parallel will
  224. raise the shape error. This test cast is ensure there is no error raised.
  225. Expectation: The number of AssignAdd is not as expected.
  226. """
  227. bs = 16
  228. pp = 2
  229. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=pp,
  230. full_batch=True,
  231. enable_parallel_optimizer=True)
  232. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  233. cf = TransformerOpParallelConfig(data_parallel=1, model_parallel=4, pipeline_stage=pp, vocab_emb_dp=False)
  234. pipeline_net = TransformerEncoderNet(batch_size=bs // pp,
  235. en_layer=2, de_layer=0, parallel_config=cf)
  236. pipeline_net.embedding.pipeline_stage = 0
  237. pipeline_net.network.encoder.blocks[0].pipeline_stage = 0
  238. pipeline_net.network.encoder.blocks[1].pipeline_stage = 1
  239. pipeline_cell_net = PipelineCell(pipeline_net, 2)
  240. encoder_input_value = Tensor(np.ones((bs, 20)), mstype.int32)
  241. encoder_input_mask = Tensor(np.ones((bs, 20, 20)), mstype.float16)
  242. label = Tensor(np.ones((bs, 20)), mstype.int32)
  243. mask = Tensor(np.ones((bs, 20)), mstype.float32)
  244. dataset = Dataset(encoder_input_value, encoder_input_mask, label, mask)
  245. run_network_function(dataset, pipeline_cell_net)
  246. self.virtual_assign_add_from_ir(pattern=r'AssignAdd(', target_count=35)
  247. def test_transformer_model_int64_inputs():
  248. set_auto_parallel_context(device_num=8, global_rank=0,
  249. full_batch=True,
  250. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  251. net = Transformer(encoder_layers=1,
  252. decoder_layers=2,
  253. batch_size=2,
  254. src_seq_length=20,
  255. tgt_seq_length=10,
  256. hidden_size=64,
  257. num_heads=8,
  258. ffn_hidden_size=64,
  259. parallel_config=config)
  260. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
  261. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  262. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  263. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  264. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  265. net = NetWithLossFiveInputs(net)
  266. net = _VirtualDatasetCell(net)
  267. params = net.trainable_params()
  268. optimizer = AdamWeightDecay(params)
  269. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  270. memory_mask)
  271. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  272. model = Model(net_with_grad)
  273. with pytest.raises(TypeError):
  274. model.train(1, dataset, dataset_sink_mode=False)
  275. def test_transformer_model_head_parallel_only_encoder():
  276. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  277. run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
  278. def test_transformer_model_head_parallel():
  279. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  280. run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)
  281. def test_transformer_model_head_parallel_decoder():
  282. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  283. with pytest.raises(ValueError):
  284. run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)
  285. def test_transformer_model_head_stand_alone():
  286. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
  287. run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)
  288. def test_transformer_model_auto_parallel_no_support():
  289. local_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1)
  290. with pytest.raises(RuntimeError):
  291. run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config,
  292. mode=ParallelMode.AUTO_PARALLEL)
  293. def pipeline_single_transformer(grad_accumulation_shard=False):
  294. """
  295. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  296. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
  297. Expectation: The compile passed
  298. """
  299. set_auto_parallel_context(device_num=64,
  300. full_batch=True,
  301. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  302. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  303. context.set_auto_parallel_context(parallel_optimizer_config=
  304. {"gradient_accumulation_shard": grad_accumulation_shard})
  305. net = Transformer(batch_size=8 // pipeline_config.micro_batch_num,
  306. src_seq_length=20,
  307. tgt_seq_length=10,
  308. encoder_layers=2,
  309. decoder_layers=2,
  310. hidden_size=64,
  311. num_heads=8,
  312. ffn_hidden_size=64,
  313. parallel_config=pipeline_config)
  314. encoder_input_value = Tensor(np.ones((8, 20, 64)), mstype.float32)
  315. encoder_input_mask = Tensor(np.ones((8, 20, 20)), mstype.float16)
  316. decoder_input_value = Tensor(np.ones((8, 10, 64)), mstype.float32)
  317. decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
  318. memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
  319. net = NetWithLossFiveInputs(net)
  320. net = PipelineCell(net, pipeline_config.micro_batch_num)
  321. net = _VirtualDatasetCell(net)
  322. params = net.infer_param_pipeline_stage()
  323. optimizer = AdamWeightDecay(params)
  324. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  325. memory_mask)
  326. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
  327. net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
  328. scale_sense=update_cell)
  329. model = Model(net_with_grad)
  330. model.train(1, dataset, dataset_sink_mode=False)
  331. def test_pipeline_transformer_gradient_shard_true():
  332. """
  333. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  334. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
  335. Expectation: The compile passed
  336. """
  337. pipeline_single_transformer(grad_accumulation_shard=True)
  338. def test_pipeline_transformer_gradient_shard_false():
  339. """
  340. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  341. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
  342. Expectation: The compile passed
  343. """
  344. pipeline_single_transformer(grad_accumulation_shard=False)
  345. def test_transformer_wrong_head():
  346. set_auto_parallel_context(device_num=64,
  347. full_batch=True,
  348. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  349. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  350. error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
  351. with pytest.raises(ValueError):
  352. net = Transformer(batch_size=4,
  353. src_seq_length=20,
  354. tgt_seq_length=10,
  355. encoder_layers=2,
  356. decoder_layers=2,
  357. hidden_size=64,
  358. num_heads=7,
  359. ffn_hidden_size=64,
  360. parallel_config=error_test_config)
  361. with pytest.raises(ValueError):
  362. net = Transformer(batch_size=4,
  363. src_seq_length=20,
  364. tgt_seq_length=10,
  365. encoder_layers=2,
  366. decoder_layers=2,
  367. hidden_size=63,
  368. num_heads=7,
  369. ffn_hidden_size=64,
  370. parallel_config=error_test_config)
  371. del net
  372. def test_transformer_wrong_dp_no_error():
  373. set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
  374. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
  375. check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
  376. net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
  377. decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
  378. parallel_config=check_config)
  379. del net
  380. def test_transformer_wrong_semi_auto_dp_error():
  381. set_auto_parallel_context(device_num=64, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
  382. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
  383. check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
  384. with pytest.raises(ValueError):
  385. net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
  386. decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
  387. parallel_config=check_config)
  388. del net
  389. def test_encoder():
  390. class NetWithLoss(nn.Cell):
  391. def __init__(self, network):
  392. super(NetWithLoss, self).__init__()
  393. self.loss = VirtualLoss()
  394. self.network = network
  395. def construct(self, x1, x2):
  396. predict, _ = self.network(x1, x2)
  397. return self.loss(predict)
  398. set_auto_parallel_context(device_num=8,
  399. full_batch=True,
  400. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  401. net = TransformerEncoder(num_layers=2,
  402. batch_size=2,
  403. seq_length=16,
  404. hidden_size=8,
  405. ffn_hidden_size=64,
  406. num_heads=8,
  407. parallel_config=config)
  408. encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
  409. encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
  410. net = NetWithLoss(net)
  411. net = _VirtualDatasetCell(net)
  412. dataset = Dataset(encoder_input_value, encoder_input_mask)
  413. model = Model(net)
  414. model.train(1, dataset, dataset_sink_mode=False)
  415. def test_decoder():
  416. class NetWithLoss(nn.Cell):
  417. def __init__(self, network):
  418. super(NetWithLoss, self).__init__()
  419. self.loss = VirtualLoss()
  420. self.network = network
  421. def construct(self, x1, x2, x3, x4):
  422. predict, _, _ = self.network(x1, x2, x3, x4)
  423. return self.loss(predict)
  424. set_auto_parallel_context(device_num=8,
  425. full_batch=True,
  426. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  427. net = TransformerDecoder(num_layers=1,
  428. batch_size=8,
  429. hidden_size=16,
  430. ffn_hidden_size=8,
  431. num_heads=8,
  432. src_seq_length=20,
  433. tgt_seq_length=10,
  434. parallel_config=config)
  435. encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
  436. decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
  437. decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
  438. memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
  439. net = NetWithLoss(net)
  440. net = _VirtualDatasetCell(net)
  441. dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  442. model = Model(net)
  443. model.train(1, dataset, dataset_sink_mode=False)
  444. def test_vocabembedding_dp_true():
  445. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  446. class NetWithLoss(nn.Cell):
  447. def __init__(self, network):
  448. super(NetWithLoss, self).__init__()
  449. self.loss = VirtualLoss()
  450. self.network = network
  451. def construct(self, x1):
  452. predict, _ = self.network(x1)
  453. return self.loss(predict)
  454. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  455. net = NetWithLoss(net)
  456. net = _VirtualDatasetCell(net)
  457. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  458. dataset = Dataset(encoder_input_value)
  459. model = Model(net)
  460. model.train(1, dataset, dataset_sink_mode=False)
  461. def test_vocabembedding_dp_false():
  462. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  463. class NetWithLoss(nn.Cell):
  464. def __init__(self, network):
  465. super(NetWithLoss, self).__init__()
  466. self.loss = VirtualLoss()
  467. self.network = network
  468. def construct(self, x1):
  469. predict, _ = self.network(x1)
  470. return self.loss(predict)
  471. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  472. net = NetWithLoss(net)
  473. net = _VirtualDatasetCell(net)
  474. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  475. dataset = Dataset(encoder_input_value)
  476. model = Model(net)
  477. model.train(1, dataset, dataset_sink_mode=False)
  478. def test_sparse_attention_parallel_mp():
  479. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  480. set_algo_parameters(fully_use_devices=False)
  481. sparse_attention_config = OpParallelConfig(model_parallel=8)
  482. net = FixedSparseAttention(batch_size=16,
  483. seq_length=1024,
  484. size_per_head=64,
  485. num_heads=8,
  486. block_size=64,
  487. parallel_config=sparse_attention_config)
  488. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  489. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  490. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  491. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  492. dataset = Dataset(q, k, v, mask)
  493. model = Model(net)
  494. model.train(1, dataset, dataset_sink_mode=False)
  495. def test_sparse_attention_parallel_mix():
  496. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  497. set_algo_parameters(fully_use_devices=False)
  498. sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
  499. net = FixedSparseAttention(batch_size=16,
  500. seq_length=1024,
  501. size_per_head=64,
  502. num_heads=8,
  503. block_size=64,
  504. parallel_config=sparse_attention_config)
  505. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  506. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  507. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  508. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  509. dataset = Dataset(q, k, v, mask)
  510. model = Model(net)
  511. model.train(1, dataset, dataset_sink_mode=False)
  512. def test_sparse_attention_parallel_mix1():
  513. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  514. set_algo_parameters(fully_use_devices=False)
  515. sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
  516. net = FixedSparseAttention(batch_size=16,
  517. seq_length=1024,
  518. size_per_head=64,
  519. num_heads=8,
  520. block_size=64,
  521. parallel_config=sparse_attention_config)
  522. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  523. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  524. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  525. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  526. dataset = Dataset(q, k, v, mask)
  527. model = Model(net)
  528. model.train(1, dataset, dataset_sink_mode=False)
  529. def test_sparse_attention_parallel_dp():
  530. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  531. set_algo_parameters(fully_use_devices=False)
  532. sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
  533. net = FixedSparseAttention(batch_size=16,
  534. seq_length=1024,
  535. size_per_head=64,
  536. num_heads=8,
  537. block_size=64,
  538. parallel_config=sparse_attention_config)
  539. net = _VirtualDatasetCell(net)
  540. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  541. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  542. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  543. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  544. dataset = Dataset(q, k, v, mask)
  545. model = Model(net)
  546. model.train(1, dataset, dataset_sink_mode=False)
  547. def test_parallel_cross_entroy_loss_semi_auto_parallel():
  548. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  549. class NetWithLoss(nn.Cell):
  550. def __init__(self, network, config_setting):
  551. super(NetWithLoss, self).__init__()
  552. self.loss = CrossEntropyLoss(config_setting)
  553. self.network = network
  554. def construct(self, x1, x2, x3):
  555. predict, _ = self.network(x1)
  556. predict = P.Reshape()(predict, (-1, 16))
  557. return self.loss(predict, x2, x3)
  558. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  559. net = NetWithLoss(net, config.dp_mp_config)
  560. net = _VirtualDatasetCell(net)
  561. embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
  562. labels = Tensor(np.ones((2 * 64,)), mstype.int32)
  563. input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
  564. dataset = Dataset(embed_ids, labels, input_mask)
  565. model = Model(net)
  566. model.train(1, dataset, dataset_sink_mode=False)
  567. def test_transformer_args():
  568. with pytest.raises(TypeError):
  569. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  570. tgt_seq_length=20, decoder_layers="aa")
  571. with pytest.raises(TypeError):
  572. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  573. tgt_seq_length="a")
  574. with pytest.raises(TypeError):
  575. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  576. tgt_seq_length=20, softmax_compute_type=mstype.int64)
  577. with pytest.raises(TypeError):
  578. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  579. tgt_seq_length=20, layernorm_compute_type=mstype.int64)
  580. with pytest.raises(TypeError):
  581. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  582. tgt_seq_length=20, param_init_type=mstype.int64)
  583. with pytest.raises(TypeError):
  584. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  585. tgt_seq_length=20, hidden_dropout_rate=mstype.int64)
  586. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  587. tgt_seq_length=20, softmax_compute_type=mstype.float16)
  588. def test_transformer_parallel_config():
  589. parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)
  590. with pytest.raises(TypeError):
  591. parallel_test_config.data_parallel = False
  592. with pytest.raises(ValueError):
  593. parallel_test_config.data_parallel = 0
  594. with pytest.raises(TypeError):
  595. parallel_test_config.model_parallel = False
  596. with pytest.raises(ValueError):
  597. parallel_test_config.model_parallel = 0
  598. with pytest.raises(TypeError):
  599. parallel_test_config.pipeline_stage = False
  600. with pytest.raises(ValueError):
  601. parallel_test_config.pipeline_stage = 0
  602. with pytest.raises(TypeError):
  603. parallel_test_config.micro_batch_num = False
  604. with pytest.raises(ValueError):
  605. parallel_test_config.micro_batch_num = 0
  606. with pytest.raises(TypeError):
  607. parallel_test_config.gradient_aggregation_group = False
  608. with pytest.raises(ValueError):
  609. parallel_test_config.gradient_aggregation_group = 0
  610. with pytest.raises(TypeError):
  611. parallel_test_config.recompute = 1
  612. parallel_test_config.recompute.recompute = False
  613. assert not parallel_test_config.recompute.recompute
  614. def test_parallel_config():
  615. parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)
  616. with pytest.raises(ValueError):
  617. parallel_test_config.data_parallel = 0
  618. with pytest.raises(TypeError):
  619. parallel_test_config.model_parallel = False
  620. with pytest.raises(ValueError):
  621. parallel_test_config.model_parallel = 0
  622. assert parallel_test_config.model_parallel == 3
  623. def test_embedding_parallel_config():
  624. parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)
  625. with pytest.raises(ValueError):
  626. parallel_test_config.data_parallel = 0
  627. with pytest.raises(TypeError):
  628. parallel_test_config.model_parallel = False
  629. with pytest.raises(ValueError):
  630. parallel_test_config.model_parallel = 0
  631. with pytest.raises(TypeError):
  632. parallel_test_config.vocab_emb_dp = 0
  633. assert not parallel_test_config.vocab_emb_dp