You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_transformer.py 29 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore.common.dtype as mstype
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore.context import set_auto_parallel_context, ParallelMode
  20. from mindspore import context
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import functional as F
  23. import mindspore.ops as P
  24. from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
  25. VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig, FixedSparseAttention
  26. from mindspore.nn import Dense as Linear
  27. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  28. from mindspore.nn.optim import AdamWeightDecay
  29. from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
  30. from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
  31. from mindspore.train import Model
  32. from mindspore.parallel import set_algo_parameters
  33. from tests.dataset_mock import MindData
  34. from tests.ut.python.ops.test_math_ops import VirtualLoss
# Gradient operator configured with get_all=True (gradients w.r.t. all network inputs).
# NOTE(review): not referenced in the visible part of this file; presumably used further down.
grad_all = C.GradOperation(get_all=True)
  36. class Dataset(MindData):
  37. def __init__(self, *inputs, length=3):
  38. super(Dataset, self).__init__(size=length)
  39. self.inputs = inputs
  40. self.index = 0
  41. self.length = length
  42. def __iter__(self):
  43. return self
  44. def __next__(self):
  45. if self.index >= self.length:
  46. raise StopIteration
  47. self.index += 1
  48. return self.inputs
  49. def reset(self):
  50. self.index = 0
  51. class TransformerNet(nn.Cell):
  52. def __init__(self, en_layer, de_layer, parallel_config):
  53. super(TransformerNet, self).__init__()
  54. self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
  55. parallel_config=config.embedding_dp_mp_config)
  56. self.network = Transformer(encoder_layers=en_layer,
  57. decoder_layers=de_layer,
  58. batch_size=2,
  59. src_seq_length=20,
  60. tgt_seq_length=10,
  61. hidden_size=64,
  62. num_heads=8,
  63. ffn_hidden_size=64,
  64. parallel_config=parallel_config)
  65. self.head = Linear(in_channels=64, out_channels=200)
  66. self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
  67. def construct(self, x1, x2, x3, x4, x5, y, mask):
  68. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  69. predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
  70. return self.loss(predict, y, mask)
  71. config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
  72. pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
  73. micro_batch_num=4, vocab_emb_dp=False)
  74. class NetWithLossFiveInputs(nn.Cell):
  75. def __init__(self, network):
  76. super(NetWithLossFiveInputs, self).__init__()
  77. self.loss = VirtualLoss()
  78. self.network = network
  79. def construct(self, x1, x2, x3, x4, x5):
  80. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  81. return self.loss(predict)
  82. def run_total_transformer_model_head(e_layer,
  83. d_layer,
  84. arg_parallel_config,
  85. mode=ParallelMode.SEMI_AUTO_PARALLEL):
  86. dp = arg_parallel_config.data_parallel
  87. mp = arg_parallel_config.model_parallel
  88. pp = arg_parallel_config.pipeline_stage
  89. if dp * mp * pp != 1:
  90. set_auto_parallel_context(device_num=8,
  91. full_batch=True,
  92. global_rank=0, parallel_mode=mode)
  93. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  94. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  95. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  96. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  97. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  98. seq = 20
  99. if d_layer > 0:
  100. seq = 10
  101. label = Tensor(np.ones((2 * seq,)), mstype.int32)
  102. input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
  103. net = TransformerNet(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
  104. net = _VirtualDatasetCell(net)
  105. params = net.trainable_params()
  106. optimizer = AdamWeightDecay(params)
  107. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  108. memory_mask, label, input_mask)
  109. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  110. model = Model(net_with_grad)
  111. model.train(1, dataset, dataset_sink_mode=False)
  112. def test_transformer_model():
  113. set_auto_parallel_context(device_num=8, global_rank=0,
  114. full_batch=True,
  115. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  116. net = Transformer(encoder_layers=1,
  117. decoder_layers=2,
  118. batch_size=2,
  119. src_seq_length=20,
  120. tgt_seq_length=10,
  121. hidden_size=64,
  122. num_heads=8,
  123. ffn_hidden_size=64,
  124. parallel_config=config)
  125. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  126. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  127. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  128. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  129. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  130. net = NetWithLossFiveInputs(net)
  131. net = _VirtualDatasetCell(net)
  132. params = net.trainable_params()
  133. optimizer = AdamWeightDecay(params)
  134. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  135. memory_mask)
  136. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  137. model = Model(net_with_grad)
  138. model.train(1, dataset, dataset_sink_mode=False)
  139. def test_transformer_model_2d_inputs():
  140. set_auto_parallel_context(device_num=8, global_rank=0,
  141. full_batch=True,
  142. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  143. net = Transformer(encoder_layers=1,
  144. decoder_layers=2,
  145. batch_size=2,
  146. src_seq_length=20,
  147. tgt_seq_length=10,
  148. hidden_size=64,
  149. num_heads=8,
  150. ffn_hidden_size=64,
  151. parallel_config=config)
  152. encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
  153. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  154. decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
  155. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  156. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  157. net = NetWithLossFiveInputs(net)
  158. net = _VirtualDatasetCell(net)
  159. params = net.trainable_params()
  160. optimizer = AdamWeightDecay(params)
  161. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  162. memory_mask)
  163. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  164. model = Model(net_with_grad)
  165. model.train(1, dataset, dataset_sink_mode=False)
  166. def test_transformer_model_int64_inputs():
  167. set_auto_parallel_context(device_num=8, global_rank=0,
  168. full_batch=True,
  169. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  170. net = Transformer(encoder_layers=1,
  171. decoder_layers=2,
  172. batch_size=2,
  173. src_seq_length=20,
  174. tgt_seq_length=10,
  175. hidden_size=64,
  176. num_heads=8,
  177. ffn_hidden_size=64,
  178. parallel_config=config)
  179. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.int64)
  180. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  181. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  182. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  183. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  184. net = NetWithLossFiveInputs(net)
  185. net = _VirtualDatasetCell(net)
  186. params = net.trainable_params()
  187. optimizer = AdamWeightDecay(params)
  188. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  189. memory_mask)
  190. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  191. model = Model(net_with_grad)
  192. with pytest.raises(TypeError):
  193. model.train(1, dataset, dataset_sink_mode=False)
  194. def test_transformer_model_head_parallel_only_encoder():
  195. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  196. run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
  197. def test_transformer_model_head_parallel():
  198. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  199. run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)
  200. def test_transformer_model_head_parallel_decoder():
  201. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  202. with pytest.raises(ValueError):
  203. run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)
  204. def test_transformer_model_head_stand_alone():
  205. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
  206. run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)
  207. def test_transformer_model_auto_parallel_no_support():
  208. local_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1)
  209. with pytest.raises(RuntimeError):
  210. run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config,
  211. mode=ParallelMode.AUTO_PARALLEL)
  212. def pipeline_single_transformer(grad_accumulation_shard=False):
  213. """
  214. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  215. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
  216. Expectation: The compile passed
  217. """
  218. set_auto_parallel_context(device_num=32,
  219. full_batch=True,
  220. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  221. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  222. context.set_auto_parallel_context(parallel_optimizer_config=
  223. {"gradient_accumulation_shard": grad_accumulation_shard})
  224. net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
  225. src_seq_length=20,
  226. tgt_seq_length=10,
  227. encoder_layers=2,
  228. decoder_layers=2,
  229. hidden_size=64,
  230. num_heads=8,
  231. ffn_hidden_size=64,
  232. parallel_config=pipeline_config)
  233. encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
  234. encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
  235. decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
  236. decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
  237. memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
  238. net = NetWithLossFiveInputs(net)
  239. net = PipelineCell(net, pipeline_config.micro_batch_num)
  240. net = _VirtualDatasetCell(net)
  241. params = net.infer_param_pipeline_stage()
  242. optimizer = AdamWeightDecay(params)
  243. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  244. memory_mask)
  245. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
  246. net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
  247. scale_sense=update_cell)
  248. model = Model(net_with_grad)
  249. model.train(1, dataset, dataset_sink_mode=False)
  250. def test_pipeline_transformer_gradient_shard_true():
  251. """
  252. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  253. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard True
  254. Expectation: The compile passed
  255. """
  256. pipeline_single_transformer(grad_accumulation_shard=True)
  257. def test_pipeline_transformer_gradient_shard_false():
  258. """
  259. Feature: Gradient Accumulation Shard for Pipeline and Gradient Accumulation
  260. Description: Test a single transformer model with pipeline parallel with grad_accumulation_shard False
  261. Expectation: The compile passed
  262. """
  263. pipeline_single_transformer(grad_accumulation_shard=False)
  264. def test_transformer_wrong_head():
  265. set_auto_parallel_context(device_num=32,
  266. full_batch=True,
  267. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  268. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  269. error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
  270. with pytest.raises(ValueError):
  271. net = Transformer(batch_size=4,
  272. src_seq_length=20,
  273. tgt_seq_length=10,
  274. encoder_layers=2,
  275. decoder_layers=2,
  276. hidden_size=64,
  277. num_heads=7,
  278. ffn_hidden_size=64,
  279. parallel_config=error_test_config)
  280. with pytest.raises(ValueError):
  281. net = Transformer(batch_size=4,
  282. src_seq_length=20,
  283. tgt_seq_length=10,
  284. encoder_layers=2,
  285. decoder_layers=2,
  286. hidden_size=63,
  287. num_heads=7,
  288. ffn_hidden_size=64,
  289. parallel_config=error_test_config)
  290. del net
  291. def test_transformer_wrong_dp_no_error():
  292. set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.DATA_PARALLEL,
  293. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
  294. check_config = TransformerOpParallelConfig(data_parallel=8, model_parallel=1, vocab_emb_dp=False)
  295. net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
  296. decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
  297. parallel_config=check_config)
  298. del net
  299. def test_transformer_wrong_semi_auto_dp_error():
  300. set_auto_parallel_context(device_num=32, full_batch=False, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
  301. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0)
  302. check_config = TransformerOpParallelConfig(data_parallel=16, model_parallel=1, vocab_emb_dp=False)
  303. with pytest.raises(ValueError):
  304. net = Transformer(batch_size=4, src_seq_length=20, tgt_seq_length=10, encoder_layers=2,
  305. decoder_layers=2, hidden_size=64, num_heads=2, ffn_hidden_size=64,
  306. parallel_config=check_config)
  307. del net
  308. def test_encoder():
  309. class NetWithLoss(nn.Cell):
  310. def __init__(self, network):
  311. super(NetWithLoss, self).__init__()
  312. self.loss = VirtualLoss()
  313. self.network = network
  314. def construct(self, x1, x2):
  315. predict, _ = self.network(x1, x2)
  316. return self.loss(predict)
  317. set_auto_parallel_context(device_num=8,
  318. full_batch=True,
  319. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  320. net = TransformerEncoder(num_layers=2,
  321. batch_size=2,
  322. seq_length=16,
  323. hidden_size=8,
  324. ffn_hidden_size=64,
  325. num_heads=8,
  326. parallel_config=config)
  327. encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
  328. encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
  329. net = NetWithLoss(net)
  330. net = _VirtualDatasetCell(net)
  331. dataset = Dataset(encoder_input_value, encoder_input_mask)
  332. model = Model(net)
  333. model.train(1, dataset, dataset_sink_mode=False)
  334. def test_decoder():
  335. class NetWithLoss(nn.Cell):
  336. def __init__(self, network):
  337. super(NetWithLoss, self).__init__()
  338. self.loss = VirtualLoss()
  339. self.network = network
  340. def construct(self, x1, x2, x3, x4):
  341. predict, _, _ = self.network(x1, x2, x3, x4)
  342. return self.loss(predict)
  343. set_auto_parallel_context(device_num=8,
  344. full_batch=True,
  345. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  346. net = TransformerDecoder(num_layers=1,
  347. batch_size=8,
  348. hidden_size=16,
  349. ffn_hidden_size=8,
  350. num_heads=8,
  351. src_seq_length=20,
  352. tgt_seq_length=10,
  353. parallel_config=config)
  354. encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
  355. decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
  356. decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
  357. memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
  358. net = NetWithLoss(net)
  359. net = _VirtualDatasetCell(net)
  360. dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  361. model = Model(net)
  362. model.train(1, dataset, dataset_sink_mode=False)
  363. def test_vocabembedding_dp_true():
  364. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  365. class NetWithLoss(nn.Cell):
  366. def __init__(self, network):
  367. super(NetWithLoss, self).__init__()
  368. self.loss = VirtualLoss()
  369. self.network = network
  370. def construct(self, x1):
  371. predict, _ = self.network(x1)
  372. return self.loss(predict)
  373. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  374. net = NetWithLoss(net)
  375. net = _VirtualDatasetCell(net)
  376. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  377. dataset = Dataset(encoder_input_value)
  378. model = Model(net)
  379. model.train(1, dataset, dataset_sink_mode=False)
  380. def test_vocabembedding_dp_false():
  381. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  382. class NetWithLoss(nn.Cell):
  383. def __init__(self, network):
  384. super(NetWithLoss, self).__init__()
  385. self.loss = VirtualLoss()
  386. self.network = network
  387. def construct(self, x1):
  388. predict, _ = self.network(x1)
  389. return self.loss(predict)
  390. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  391. net = NetWithLoss(net)
  392. net = _VirtualDatasetCell(net)
  393. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  394. dataset = Dataset(encoder_input_value)
  395. model = Model(net)
  396. model.train(1, dataset, dataset_sink_mode=False)
  397. def test_sparse_attention_parallel_mp():
  398. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  399. set_algo_parameters(fully_use_devices=False)
  400. sparse_attention_config = OpParallelConfig(model_parallel=8)
  401. net = FixedSparseAttention(batch_size=16,
  402. seq_length=1024,
  403. size_per_head=64,
  404. num_heads=8,
  405. block_size=64,
  406. parallel_config=sparse_attention_config)
  407. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  408. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  409. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  410. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  411. dataset = Dataset(q, k, v, mask)
  412. model = Model(net)
  413. model.train(1, dataset, dataset_sink_mode=False)
  414. def test_sparse_attention_parallel_mix():
  415. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  416. set_algo_parameters(fully_use_devices=False)
  417. sparse_attention_config = OpParallelConfig(data_parallel=2, model_parallel=4)
  418. net = FixedSparseAttention(batch_size=16,
  419. seq_length=1024,
  420. size_per_head=64,
  421. num_heads=8,
  422. block_size=64,
  423. parallel_config=sparse_attention_config)
  424. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  425. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  426. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  427. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  428. dataset = Dataset(q, k, v, mask)
  429. model = Model(net)
  430. model.train(1, dataset, dataset_sink_mode=False)
  431. def test_sparse_attention_parallel_mix1():
  432. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  433. set_algo_parameters(fully_use_devices=False)
  434. sparse_attention_config = OpParallelConfig(data_parallel=4, model_parallel=2)
  435. net = FixedSparseAttention(batch_size=16,
  436. seq_length=1024,
  437. size_per_head=64,
  438. num_heads=8,
  439. block_size=64,
  440. parallel_config=sparse_attention_config)
  441. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  442. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  443. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  444. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  445. dataset = Dataset(q, k, v, mask)
  446. model = Model(net)
  447. model.train(1, dataset, dataset_sink_mode=False)
  448. def test_sparse_attention_parallel_dp():
  449. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  450. set_algo_parameters(fully_use_devices=False)
  451. sparse_attention_config = OpParallelConfig(data_parallel=8, model_parallel=1)
  452. net = FixedSparseAttention(batch_size=16,
  453. seq_length=1024,
  454. size_per_head=64,
  455. num_heads=8,
  456. block_size=64,
  457. parallel_config=sparse_attention_config)
  458. net = _VirtualDatasetCell(net)
  459. q = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  460. k = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  461. v = Tensor(np.ones((2, 1024, 512)), mstype.float16)
  462. mask = Tensor(np.ones((2, 1024, 1024)), mstype.float32)
  463. dataset = Dataset(q, k, v, mask)
  464. model = Model(net)
  465. model.train(1, dataset, dataset_sink_mode=False)
  466. def test_parallel_cross_entroy_loss_semi_auto_parallel():
  467. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.AUTO_PARALLEL)
  468. class NetWithLoss(nn.Cell):
  469. def __init__(self, network, config_setting):
  470. super(NetWithLoss, self).__init__()
  471. self.loss = CrossEntropyLoss(config_setting)
  472. self.network = network
  473. def construct(self, x1, x2, x3):
  474. predict, _ = self.network(x1)
  475. predict = P.Reshape()(predict, (-1, 16))
  476. return self.loss(predict, x2, x3)
  477. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  478. net = NetWithLoss(net, config.dp_mp_config)
  479. net = _VirtualDatasetCell(net)
  480. embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
  481. labels = Tensor(np.ones((2 * 64,)), mstype.int32)
  482. input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
  483. dataset = Dataset(embed_ids, labels, input_mask)
  484. model = Model(net)
  485. model.train(1, dataset, dataset_sink_mode=False)
  486. def test_transformer_args():
  487. with pytest.raises(TypeError):
  488. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  489. tgt_seq_length=20, decoder_layers="aa")
  490. with pytest.raises(TypeError):
  491. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  492. tgt_seq_length="a")
  493. with pytest.raises(TypeError):
  494. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  495. tgt_seq_length=20, softmax_compute_type=mstype.int64)
  496. with pytest.raises(TypeError):
  497. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  498. tgt_seq_length=20, layernorm_compute_type=mstype.int64)
  499. with pytest.raises(TypeError):
  500. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  501. tgt_seq_length=20, param_init_type=mstype.int64)
  502. with pytest.raises(TypeError):
  503. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  504. tgt_seq_length=20, hidden_dropout_rate=mstype.int64)
  505. Transformer(hidden_size=10, batch_size=2, ffn_hidden_size=20, src_seq_length=10,
  506. tgt_seq_length=20, softmax_compute_type=mstype.float16)
  507. def test_transformer_parallel_config():
  508. parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)
  509. with pytest.raises(TypeError):
  510. parallel_test_config.data_parallel = False
  511. with pytest.raises(ValueError):
  512. parallel_test_config.data_parallel = 0
  513. with pytest.raises(TypeError):
  514. parallel_test_config.model_parallel = False
  515. with pytest.raises(ValueError):
  516. parallel_test_config.model_parallel = 0
  517. with pytest.raises(TypeError):
  518. parallel_test_config.pipeline_stage = False
  519. with pytest.raises(ValueError):
  520. parallel_test_config.pipeline_stage = 0
  521. with pytest.raises(TypeError):
  522. parallel_test_config.micro_batch_num = False
  523. with pytest.raises(ValueError):
  524. parallel_test_config.micro_batch_num = 0
  525. with pytest.raises(TypeError):
  526. parallel_test_config.gradient_aggregation_group = False
  527. with pytest.raises(ValueError):
  528. parallel_test_config.gradient_aggregation_group = 0
  529. with pytest.raises(TypeError):
  530. parallel_test_config.recompute = 1
  531. parallel_test_config.recompute.recompute = False
  532. assert not parallel_test_config.recompute.recompute
  533. def test_parallel_config():
  534. parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)
  535. with pytest.raises(ValueError):
  536. parallel_test_config.data_parallel = 0
  537. with pytest.raises(TypeError):
  538. parallel_test_config.model_parallel = False
  539. with pytest.raises(ValueError):
  540. parallel_test_config.model_parallel = 0
  541. assert parallel_test_config.model_parallel == 3
  542. def test_embedding_parallel_config():
  543. parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)
  544. with pytest.raises(ValueError):
  545. parallel_test_config.data_parallel = 0
  546. with pytest.raises(TypeError):
  547. parallel_test_config.model_parallel = False
  548. with pytest.raises(ValueError):
  549. parallel_test_config.model_parallel = 0
  550. with pytest.raises(TypeError):
  551. parallel_test_config.vocab_emb_dp = 0
  552. assert not parallel_test_config.vocab_emb_dp