You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_moe.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore.common.dtype as mstype
  16. import mindspore.nn as nn
  17. from mindspore import Tensor
  18. from mindspore.context import set_auto_parallel_context, ParallelMode
  19. from mindspore.ops import composite as C
  20. from mindspore.parallel.nn import Transformer, TransformerOpParallelConfig, MoEConfig
  21. from mindspore.nn.optim import AdamWeightDecay
  22. from mindspore.nn.wrap.cell_wrapper import TrainOneStepCell, _VirtualDatasetCell
  23. from mindspore.train import Model
  24. from tests.dataset_mock import MindData
  25. from tests.ut.python.ops.test_math_ops import VirtualLoss
# Module-level gradient operator built with get_all=True (presumably gradients
# w.r.t. all inputs of the differentiated function).
# NOTE(review): not referenced by the tests visible in this chunk; confirm it is
# unused elsewhere in the file before removing it.
grad_all = C.GradOperation(get_all=True)
  27. class Dataset(MindData):
  28. def __init__(self, *inputs, length=3):
  29. super(Dataset, self).__init__(size=length)
  30. self.inputs = inputs
  31. self.index = 0
  32. self.length = length
  33. def __iter__(self):
  34. return self
  35. def __next__(self):
  36. if self.index >= self.length:
  37. raise StopIteration
  38. self.index += 1
  39. return self.inputs
  40. def reset(self):
  41. self.index = 0
  42. config = TransformerOpParallelConfig(data_parallel=2, model_parallel=8, vocab_emb_dp=False)
  43. moe_config = MoEConfig(expert_num=4)
  44. class NetWithLossFiveInputs(nn.Cell):
  45. def __init__(self, network):
  46. super(NetWithLossFiveInputs, self).__init__()
  47. self.loss = VirtualLoss()
  48. self.network = network
  49. def construct(self, x1, x2, x3, x4, x5):
  50. predict, _, _, _ = self.network(x1, x2, x3, x4, x5)
  51. return self.loss(predict)
  52. def test_transformer_model():
  53. set_auto_parallel_context(device_num=16, global_rank=0,
  54. full_batch=True, enable_alltoall=True,
  55. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  56. net = Transformer(encoder_layers=1,
  57. decoder_layers=1,
  58. batch_size=2,
  59. src_seq_length=20,
  60. tgt_seq_length=10,
  61. hidden_size=64,
  62. num_heads=8,
  63. ffn_hidden_size=64,
  64. moe_config=moe_config,
  65. parallel_config=config)
  66. net = _VirtualDatasetCell(net)
  67. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  68. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  69. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  70. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  71. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  72. net = NetWithLossFiveInputs(net)
  73. params = net.trainable_params()
  74. optimizer = AdamWeightDecay(params)
  75. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  76. memory_mask)
  77. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  78. model = Model(net_with_grad)
  79. model.train(1, dataset, dataset_sink_mode=False)
  80. def test_transformer_model_2d():
  81. set_auto_parallel_context(device_num=16, global_rank=0,
  82. full_batch=True, enable_alltoall=True,
  83. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  84. net = Transformer(encoder_layers=1,
  85. decoder_layers=1,
  86. batch_size=2,
  87. src_seq_length=20,
  88. tgt_seq_length=10,
  89. hidden_size=64,
  90. num_heads=8,
  91. ffn_hidden_size=64,
  92. moe_config=moe_config,
  93. parallel_config=config)
  94. net = _VirtualDatasetCell(net)
  95. encoder_input_value = Tensor(np.ones((40, 64)), mstype.float32)
  96. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  97. decoder_input_value = Tensor(np.ones((20, 64)), mstype.float32)
  98. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  99. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  100. net = NetWithLossFiveInputs(net)
  101. params = net.trainable_params()
  102. optimizer = AdamWeightDecay(params)
  103. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  104. memory_mask)
  105. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  106. model = Model(net_with_grad)
  107. model.train(1, dataset, dataset_sink_mode=False)