You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_transformer.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore.common.dtype as mstype
  16. import mindspore.nn as nn
  17. from mindspore import Tensor
  18. from mindspore.context import set_auto_parallel_context, ParallelMode
  19. from mindspore.ops import composite as C
  20. from mindspore.nn.parallel import TransformerEncoder, TransformerDecoder, Transformer, TransformerParallelConfig,\
  21. VocabEmbedding
  22. from mindspore.train import Model
  23. from tests.dataset_mock import MindData
  24. from tests.ut.python.ops.test_math_ops import VirtualLoss
  25. grad_all = C.GradOperation(get_all=True)
  26. class Dataset(MindData):
  27. def __init__(self, *inputs, length=3):
  28. super(Dataset, self).__init__(size=length)
  29. self.inputs = inputs
  30. self.index = 0
  31. self.length = length
  32. def __iter__(self):
  33. return self
  34. def __next__(self):
  35. if self.index >= self.length:
  36. raise StopIteration
  37. self.index += 1
  38. return self.inputs
  39. def reset(self):
  40. self.index = 0
  41. def test_transformer_model():
  42. class NetWithLoss(nn.Cell):
  43. def __init__(self, network):
  44. super(NetWithLoss, self).__init__()
  45. self.loss = VirtualLoss()
  46. self.network = network
  47. def construct(self, x1, x2, x3, x4, x5):
  48. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  49. return self.loss(predict)
  50. config = TransformerParallelConfig(dp=1, mp=8)
  51. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  52. net = Transformer(encoder_layers=1,
  53. decoder_layers=2,
  54. hidden_size=64,
  55. num_heads=8,
  56. ffn_hidden_size=64,
  57. src_seq_length=20,
  58. tgt_seq_length=20,
  59. parallel_config=config)
  60. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  61. encoder_input_mask = Tensor(np.ones((2, 1, 20, 20)), mstype.float16)
  62. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  63. decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
  64. memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
  65. net = NetWithLoss(net)
  66. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  67. memory_mask)
  68. model = Model(net)
  69. model.train(1, dataset, dataset_sink_mode=False)
  70. def test_encoder():
  71. class NetWithLoss(nn.Cell):
  72. def __init__(self, network):
  73. super(NetWithLoss, self).__init__()
  74. self.loss = VirtualLoss()
  75. self.network = network
  76. def construct(self, x1, x2):
  77. predict, _ = self.network(x1, x2)
  78. return self.loss(predict)
  79. config = TransformerParallelConfig(dp=1, mp=8)
  80. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  81. net = TransformerEncoder(num_layers=2,
  82. hidden_size=8,
  83. ffn_hidden_size=64,
  84. seq_length=16,
  85. num_heads=8,
  86. parallel_config=config)
  87. encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
  88. encoder_input_mask = Tensor(np.ones((2, 1, 16, 16)), mstype.float16)
  89. net = NetWithLoss(net)
  90. dataset = Dataset(encoder_input_value, encoder_input_mask)
  91. model = Model(net)
  92. model.train(1, dataset, dataset_sink_mode=False)
  93. def test_decoder():
  94. class NetWithLoss(nn.Cell):
  95. def __init__(self, network):
  96. super(NetWithLoss, self).__init__()
  97. self.loss = VirtualLoss()
  98. self.network = network
  99. def construct(self, x1, x2, x3, x4):
  100. predict, _, _ = self.network(x1, x2, x3, x4)
  101. return self.loss(predict)
  102. config = TransformerParallelConfig(dp=1, mp=8)
  103. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  104. net = TransformerDecoder(num_layers=1,
  105. hidden_size=16,
  106. ffn_hidden_size=8,
  107. num_heads=8,
  108. seq_length=10,
  109. parallel_config=config)
  110. encoder_input_value = Tensor(np.ones((2, 20, 16)), mstype.float32)
  111. decoder_input_value = Tensor(np.ones((2, 10, 16)), mstype.float32)
  112. decoder_input_mask = Tensor(np.ones((2, 1, 10, 10)), mstype.float16)
  113. memory_mask = Tensor(np.ones((2, 1, 10, 20)), mstype.float16)
  114. net = NetWithLoss(net)
  115. dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  116. model = Model(net)
  117. model.train(1, dataset, dataset_sink_mode=False)
  118. def test_vocabembedding_dp_true():
  119. config = TransformerParallelConfig(dp=1, mp=8)
  120. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  121. class NetWithLoss(nn.Cell):
  122. def __init__(self, network):
  123. super(NetWithLoss, self).__init__()
  124. self.loss = VirtualLoss()
  125. self.network = network
  126. def construct(self, x1):
  127. predict, _ = self.network(x1)
  128. return self.loss(predict)
  129. class GradWrap(nn.Cell):
  130. def __init__(self, network):
  131. super(GradWrap, self).__init__()
  132. self.network = network
  133. def construct(self, x1):
  134. return grad_all(self.network)(x1)
  135. net = VocabEmbedding(vocab_size=100, embedding_size=16, parallel_config=config)
  136. net = NetWithLoss(net)
  137. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  138. dataset = Dataset(encoder_input_value)
  139. model = Model(net)
  140. model.train(1, dataset, dataset_sink_mode=False)
  141. def test_vocabembedding_dp_false():
  142. config = TransformerParallelConfig(dp=1, mp=8, vocab_emb_dp=False)
  143. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  144. class NetWithLoss(nn.Cell):
  145. def __init__(self, network):
  146. super(NetWithLoss, self).__init__()
  147. self.loss = VirtualLoss()
  148. self.network = network
  149. def construct(self, x1):
  150. predict, _ = self.network(x1)
  151. return self.loss(predict)
  152. class GradWrap(nn.Cell):
  153. def __init__(self, network):
  154. super(GradWrap, self).__init__()
  155. self.network = network
  156. def construct(self, x1):
  157. return grad_all(self.network)(x1)
  158. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config)
  159. net = NetWithLoss(net)
  160. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  161. dataset = Dataset(encoder_input_value)
  162. model = Model(net)
  163. model.train(1, dataset, dataset_sink_mode=False)