You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_parallel_moe.py 9.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import pytest
  15. import numpy as np
  16. import mindspore.common.dtype as mstype
  17. import mindspore.nn as nn
  18. import mindspore.ops as P
  19. from mindspore import Tensor
  20. from mindspore.context import set_auto_parallel_context, ParallelMode
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import functional as F
  23. from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig, CrossEntropyLoss
  24. from mindspore.nn.optim import AdamWeightDecay
  25. from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
  26. from mindspore.train import Model
  27. from tests.dataset_mock import MindData
  28. from tests.ut.python.ops.test_math_ops import VirtualLoss
# Gradient operator that returns gradients w.r.t. all network inputs.
# NOTE(review): not referenced by the tests visible in this file — possibly
# kept for parity with sibling parallel test files; confirm before removing.
grad_all = C.GradOperation(get_all=True)
  30. class Dataset(MindData):
  31. def __init__(self, *inputs, length=3):
  32. super(Dataset, self).__init__(size=length)
  33. self.inputs = inputs
  34. self.index = 0
  35. self.length = length
  36. def __iter__(self):
  37. return self
  38. def __next__(self):
  39. if self.index >= self.length:
  40. raise StopIteration
  41. self.index += 1
  42. return self.inputs
  43. def reset(self):
  44. self.index = 0
# Shared parallel configuration for the tests below: 2-way data parallel x
# 8-way model parallel; vocab_emb_dp=False disables data-parallel sharding of
# the vocabulary embedding.
config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, vocab_emb_dp=False)
# MoE configuration: 4 experts, with 3 experts chosen (per the parameter name).
moe_config = MoEConfig(expert_num=4, num_experts_chosen=3)
  47. class NetWithLossFiveInputs(nn.Cell):
  48. def __init__(self, network):
  49. super(NetWithLossFiveInputs, self).__init__()
  50. self.loss = VirtualLoss()
  51. self.network = network
  52. def construct(self, x1, x2, x3, x4, x5):
  53. predict, _, _, _ = self.network(x1, x2, x3, x4, x5)
  54. return self.loss(predict)
  55. def test_transformer_model():
  56. """
  57. Feature: Test Transformer+MoE, with All2All enabled.
  58. Description: 3-dim input.
  59. Expectation: Successful graph compilation with All2All included.
  60. """
  61. set_auto_parallel_context(device_num=16, global_rank=0,
  62. full_batch=True, enable_alltoall=True,
  63. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  64. net = Transformer(encoder_layers=1,
  65. decoder_layers=1,
  66. batch_size=2,
  67. src_seq_length=20,
  68. tgt_seq_length=10,
  69. hidden_size=64,
  70. num_heads=8,
  71. ffn_hidden_size=64,
  72. moe_config=moe_config,
  73. parallel_config=config)
  74. net = _VirtualDatasetCell(net)
  75. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  76. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  77. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  78. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  79. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  80. net = NetWithLossFiveInputs(net)
  81. params = net.trainable_params()
  82. optimizer = AdamWeightDecay(params)
  83. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  84. memory_mask)
  85. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  86. model = Model(net_with_grad)
  87. model.train(1, dataset, dataset_sink_mode=False)
  88. def test_transformer_model_2d():
  89. """
  90. Feature: Test Transformer+MoE, with All2All enabled.
  91. Description: 2-dim input.
  92. Expectation: Successful graph compilation with All2All included.
  93. """
  94. set_auto_parallel_context(device_num=16, global_rank=0,
  95. full_batch=True, enable_alltoall=True,
  96. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  97. net = Transformer(encoder_layers=1,
  98. decoder_layers=1,
  99. batch_size=2,
  100. src_seq_length=20,
  101. tgt_seq_length=10,
  102. hidden_size=64,
  103. num_heads=8,
  104. ffn_hidden_size=64,
  105. moe_config=moe_config,
  106. parallel_config=config)
  107. net = _VirtualDatasetCell(net)
  108. encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
  109. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  110. decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
  111. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  112. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  113. net = NetWithLossFiveInputs(net)
  114. params = net.trainable_params()
  115. optimizer = AdamWeightDecay(params)
  116. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  117. memory_mask)
  118. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  119. model = Model(net_with_grad)
  120. model.train(1, dataset, dataset_sink_mode=False)
  121. class TransformerNet(nn.Cell):
  122. """Transformer with loss"""
  123. def __init__(self, en_layer, de_layer, parallel_config):
  124. super(TransformerNet, self).__init__()
  125. self.network = Transformer(encoder_layers=en_layer,
  126. decoder_layers=de_layer,
  127. batch_size=2,
  128. src_seq_length=20,
  129. tgt_seq_length=10,
  130. hidden_size=64,
  131. num_heads=8,
  132. ffn_hidden_size=64,
  133. moe_config=moe_config,
  134. parallel_config=parallel_config)
  135. self.loss = CrossEntropyLoss(parallel_config=parallel_config.dp_mp_config)
  136. def construct(self, x1, x2, x3, x4, x5, y, mask):
  137. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  138. predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
  139. return self.loss(predict, y, mask)
  140. def moe_with_loss_plus_mutiparallel(local_parallel_config):
  141. set_auto_parallel_context(device_num=16, enable_alltoall=True,
  142. full_batch=True, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  143. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  144. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  145. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  146. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  147. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  148. label = Tensor(np.ones((20,)), mstype.int32)
  149. input_mask = Tensor(np.ones((20,)), mstype.float32)
  150. net = TransformerNet(en_layer=1, de_layer=1, parallel_config=local_parallel_config)
  151. net = _VirtualDatasetCell(net)
  152. params = net.trainable_params()
  153. optimizer = AdamWeightDecay(params)
  154. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  155. memory_mask, label, input_mask)
  156. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  157. model = Model(net_with_grad)
  158. model.train(1, dataset, dataset_sink_mode=False)
  159. def test_moe_expert_parallel1():
  160. """
  161. Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
  162. Description: 3-dim input.
  163. Expectation: Successful graph compilation with All2All included.
  164. """
  165. local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=4, expert_parallel=2)
  166. moe_with_loss_plus_mutiparallel(local_p_config)
  167. def test_moe_expert_parallel2():
  168. """
  169. Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
  170. Description: 3-dim input.
  171. Expectation: Successful graph compilation with All2All included.
  172. """
  173. local_p_config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, expert_parallel=1)
  174. moe_with_loss_plus_mutiparallel(local_p_config)
def test_moe_expert_parallel3():
    """
    Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
    Description: 3-dim input, expert_parallel (2) larger than data_parallel (1).
    Expectation: Raise ValueError.
    """
    # NOTE(review): the original docstring claimed "Successful graph compilation",
    # but the assertion below expects a ValueError — presumably the config is
    # rejected because expert_parallel exceeds data_parallel; confirm against
    # the TransformerOpParallelConfig validation rules.
    local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=2)
    with pytest.raises(ValueError):
        moe_with_loss_plus_mutiparallel(local_p_config)
  184. def test_moe_expert_parallel_exception():
  185. """
  186. Feature: Test Transformer+MoE for data_parallel plus expert_parallel, with All2All enabled.
  187. Description: data_parallel*model_parallel*expert_parallel > device_num
  188. Expectation: Raise ValueError.
  189. """
  190. local_p_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, expert_parallel=4)
  191. with pytest.raises(ValueError):
  192. moe_with_loss_plus_mutiparallel(local_p_config)