You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_bert_train.py 8.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Bert test."""
  16. # pylint: disable=missing-docstring, arguments-differ, W0612
  17. import os
  18. import mindspore.common.dtype as mstype
  19. import mindspore.context as context
  20. from mindspore import Tensor
  21. from model_zoo.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
  22. from mindspore.nn.optim import AdamWeightDecayDynamicLR
  23. from mindspore.train.loss_scale_manager import DynamicLossScaleManager
  24. from ...dataset_mock import MindData
  25. from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph
  # Test-data directory, expressed relative to the resolved location of this file.
  _current_dir = os.path.split(os.path.realpath(__file__))[0] + "/../python/test_data"
  # Select graph mode at import time.
  context.set_context(mode=context.GRAPH_MODE)
  def get_dataset(batch_size=1):
      """Build the mock dataset used by the BERT training tests.

      Args:
          batch_size: Leading dimension of every output shape; also forwarded
              to ``MindData`` as its ``batch_size``.

      Returns:
          A ``MindData`` of size 2 describing seven ``np.int32`` columns.
      """
      # Trailing dimension of each of the seven columns, in order.
      column_widths = (128, 128, 128, 1, 20, 20, 20)
      return MindData(
          size=2,
          batch_size=batch_size,
          np_types=(np.int32,) * len(column_widths),
          output_shapes=tuple((batch_size, width) for width in column_widths),
          input_indexs=(0, 1),
      )
  def load_test_data(batch_size=1):
      """Return one item of the mock dataset after ``batch_tuple_tensor`` conversion.

      Args:
          batch_size: Forwarded to both ``get_dataset`` and
              ``batch_tuple_tensor``.

      Returns:
          Whatever ``batch_tuple_tensor`` produces for the first item yielded
          by the dataset's ``next()``.
      """
      first_item = get_dataset(batch_size).next()
      return batch_tuple_tensor(first_item, batch_size)
  def get_config(version='base', batch_size=1):
      """Return the ``BertConfig`` for the requested model version.

      ``'base'`` and ``'large'`` share every setting except the four size
      parameters selected below. Any other ``version`` yields a ``BertConfig``
      constructed with ``batch_size`` only.

      Args:
          version: ``'base'``, ``'large'``, or anything else for the fallback.
          batch_size: Forwarded to ``BertConfig`` in every case.

      Returns:
          The constructed ``BertConfig``.
      """
      # (hidden_size, num_hidden_layers, num_attention_heads, intermediate_size)
      if version == 'base':
          sizes = (768, 12, 12, 3072)
      elif version == 'large':
          sizes = (1024, 24, 16, 4096)
      else:
          return BertConfig(batch_size=batch_size)

      hidden_size, num_hidden_layers, num_attention_heads, intermediate_size = sizes
      return BertConfig(
          batch_size=batch_size,
          seq_length=128,
          vocab_size=21128,
          hidden_size=hidden_size,
          num_hidden_layers=num_hidden_layers,
          num_attention_heads=num_attention_heads,
          intermediate_size=intermediate_size,
          hidden_act="gelu",
          hidden_dropout_prob=0.1,
          attention_probs_dropout_prob=0.1,
          max_position_embeddings=512,
          type_vocab_size=2,
          initializer_range=0.02,
          use_relative_positions=True,
          input_mask_from_dataset=True,
          token_type_ids_from_dataset=True,
          dtype=mstype.float32,
          compute_type=mstype.float32,
      )
  def test_bert_train():
      """Build the construct graph for a plain BERT training step.

      The model version comes from the ``VERSION`` environment variable
      (default ``'large'``) and the batch size from ``BATCH_SIZE`` (default
      ``1``). The graph is built with ``execute=False``.
      """

      class ModelBert(nn.Cell):
          """Cell forwarding its seven inputs to a ``BertTrainOneStepCell``."""

          def __init__(self, network, optimizer=None):
              super(ModelBert, self).__init__()
              self.optimizer = optimizer
              self.train_network = BertTrainOneStepCell(network, self.optimizer)
              self.train_network.set_train()

          def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
              return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

      model_version = os.getenv('VERSION', 'large')
      batch = int(os.getenv('BATCH_SIZE', '1'))
      sample_inputs = load_test_data(batch)

      loss_net = BertNetworkWithLoss(get_config(version=model_version, batch_size=batch), True)
      adam = AdamWeightDecayDynamicLR(loss_net.trainable_params(), 10)

      model = ModelBert(loss_net, optimizer=adam)
      model.set_train()
      build_construct_graph(model, *sample_inputs, execute=False)
  def test_bert_withlossscale_train():
      """Build and execute the graph for a BERT training step with loss scaling.

      The model version comes from the ``VERSION`` environment variable
      (default ``'base'``) and the batch size from ``BATCH_SIZE`` (default
      ``1``). A one-element float32 tensor of ones is appended to the dataset
      inputs as the eighth argument.
      """

      class ModelBert(nn.Cell):
          """Cell forwarding its eight inputs to a ``BertTrainOneStepWithLossScaleCell``."""

          def __init__(self, network, optimizer=None):
              super(ModelBert, self).__init__()
              self.optimizer = optimizer
              self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer)
              self.train_network.set_train()

          def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
              return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

      model_version = os.getenv('VERSION', 'base')
      batch = int(os.getenv('BATCH_SIZE', '1'))

      sens = Tensor(np.ones([1]).astype(np.float32))
      sample_inputs = load_test_data(batch) + (sens,)

      loss_net = BertNetworkWithLoss(get_config(version=model_version, batch_size=batch), True)
      adam = AdamWeightDecayDynamicLR(loss_net.trainable_params(), 10)

      model = ModelBert(loss_net, optimizer=adam)
      model.set_train()
      build_construct_graph(model, *sample_inputs, execute=True)
  def bert_withlossscale_manager_train():
      """Build and execute a BERT training step driven by a dynamic loss-scale manager.

      The model version comes from the ``VERSION`` environment variable
      (default ``'base'``) and the batch size from ``BATCH_SIZE`` (default
      ``1``). Only the seven dataset inputs are fed; no explicit scaling
      tensor is passed.

      The name has no ``test_`` prefix, so pytest's default collection rules
      will not pick it up; it has to be called explicitly.
      """

      class ModelBert(nn.Cell):
          """Cell forwarding its seven inputs to a loss-scaled training cell."""

          def __init__(self, network, optimizer=None):
              super(ModelBert, self).__init__()
              self.optimizer = optimizer
              manager = DynamicLossScaleManager()
              # Ask the manager for its update cell instead of naming an
              # update-cell class that this module does not import.
              update_cell = manager.get_update_cell()
              self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                     scale_update_cell=update_cell)
              self.train_network.set_train()

          def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
              return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

      version = os.getenv('VERSION', 'base')
      batch_size = int(os.getenv('BATCH_SIZE', '1'))
      inputs = load_test_data(batch_size)

      config = get_config(version=version, batch_size=batch_size)
      netwithloss = BertNetworkWithLoss(config, True)
      optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10)

      net = ModelBert(netwithloss, optimizer=optimizer)
      net.set_train()
      build_construct_graph(net, *inputs, execute=True)
  def bert_withlossscale_manager_train_feed():
      """Build and execute a manager-driven loss-scaled BERT step with a fed scaling tensor.

      The model version comes from the ``VERSION`` environment variable
      (default ``'base'``) and the batch size from ``BATCH_SIZE`` (default
      ``1``). A one-element float32 tensor of ones is appended to the dataset
      inputs as the eighth argument.

      The name has no ``test_`` prefix, so pytest's default collection rules
      will not pick it up; it has to be called explicitly.
      """

      class ModelBert(nn.Cell):
          """Cell forwarding its eight inputs to a loss-scaled training cell."""

          def __init__(self, network, optimizer=None):
              super(ModelBert, self).__init__()
              self.optimizer = optimizer
              manager = DynamicLossScaleManager()
              # Ask the manager for its update cell instead of naming an
              # update-cell class that this module does not import.
              update_cell = manager.get_update_cell()
              self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                     scale_update_cell=update_cell)
              self.train_network.set_train()

          def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
              return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

      version = os.getenv('VERSION', 'base')
      batch_size = int(os.getenv('BATCH_SIZE', '1'))

      scaling_sens = Tensor(np.ones([1]).astype(np.float32))
      inputs = load_test_data(batch_size) + (scaling_sens,)

      config = get_config(version=version, batch_size=batch_size)
      netwithloss = BertNetworkWithLoss(config, True)
      optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(), 10)

      net = ModelBert(netwithloss, optimizer=optimizer)
      net.set_train()
      build_construct_graph(net, *inputs, execute=True)