
test_bert_train.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bert test."""

# pylint: disable=missing-docstring, arguments-differ, W0612
import os

import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore import Tensor
from mindspore.nn.optim import AdamWeightDecay
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from mindspore.nn import learning_rate_schedule as lr_schedules
from mindspore.ops import operations as P
from model_zoo.official.nlp.bert.src import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell, \
    BertTrainOneStepWithLossScaleCell

from ...dataset_mock import MindData
from ...ops_common import nn, np, batch_tuple_tensor, build_construct_graph

_current_dir = os.path.dirname(os.path.realpath(__file__)) + "/../python/test_data"

context.set_context(mode=context.GRAPH_MODE)
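

# get_dataset returns a MindData mock whose seven int32 outputs match the shapes of
# the standard BERT pre-training inputs: input_ids, input_mask and token_type_ids of
# shape (batch_size, 128), next_sentence_labels of shape (batch_size, 1), and
# masked-LM positions/ids/weights of shape (batch_size, 20). The field names are
# inferred from those shapes; the mock itself only fixes types and shapes.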
def get_dataset(batch_size=1):
    dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
    dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
                      (batch_size, 20), (batch_size, 20), (batch_size, 20))

    dataset = MindData(size=2, batch_size=batch_size,
                       np_types=dataset_types,
                       output_shapes=dataset_shapes,
                       input_indexs=(0, 1))
    return dataset


def load_test_data(batch_size=1):
    dataset = get_dataset(batch_size)
    ret = dataset.next()
    ret = batch_tuple_tensor(ret, batch_size)
    return ret
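

# Both named configs use seq_length=128, vocab_size=21128 (the Chinese BERT
# vocabulary size), relative position embeddings and float32 compute; 'base' and
# 'large' differ only in hidden size, depth, head count and intermediate size.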
def get_config(version='base', batch_size=1):
    """
    Return a BertConfig for the requested model size: 'base', 'large', or the
    library defaults for any other value.
    """
    if version == 'base':
        return BertConfig(
            batch_size=batch_size,
            seq_length=128,
            vocab_size=21128,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            use_relative_positions=True,
            input_mask_from_dataset=True,
            token_type_ids_from_dataset=True,
            dtype=mstype.float32,
            compute_type=mstype.float32)
    if version == 'large':
        return BertConfig(
            batch_size=batch_size,
            seq_length=128,
            vocab_size=21128,
            hidden_size=1024,
            num_hidden_layers=24,
            num_attention_heads=16,
            intermediate_size=4096,
            hidden_act="gelu",
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
            max_position_embeddings=512,
            type_vocab_size=2,
            initializer_range=0.02,
            use_relative_positions=True,
            input_mask_from_dataset=True,
            token_type_ids_from_dataset=True,
            dtype=mstype.float32,
            compute_type=mstype.float32)
    return BertConfig(batch_size=batch_size)
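

# BertLearningRate blends linear warm-up with polynomial decay. is_warmup is a
# 1.0/0.0 tensor mask, so the schedule switch is pure arithmetic; this avoids
# branching on a tensor value inside construct(), which graph mode does not allow.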
class BertLearningRate(lr_schedules.LearningRateSchedule):
    def __init__(self, decay_steps, warmup_steps=100, learning_rate=0.1, end_learning_rate=0.0001, power=1.0):
        super(BertLearningRate, self).__init__()
        self.warmup_lr = lr_schedules.WarmUpLR(learning_rate, warmup_steps)
        self.decay_lr = lr_schedules.PolynomialDecayLR(learning_rate, end_learning_rate, decay_steps, power)
        self.warmup_steps = Tensor(np.array([warmup_steps]).astype(np.float32))
        self.greater = P.Greater()
        self.one = Tensor(np.array([1.0]).astype(np.float32))
        self.cast = P.Cast()

    def construct(self, global_step):
        is_warmup = self.cast(self.greater(self.warmup_steps, global_step), mstype.float32)
        warmup_lr = self.warmup_lr(global_step)
        decay_lr = self.decay_lr(global_step)
        lr = (self.one - is_warmup) * decay_lr + is_warmup * warmup_lr
        return lr
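

# test_bert_train compiles (execute=False) one training step of BERT-large wrapped
# in BertTrainOneStepCell with AdamWeightDecay. The VERSION and BATCH_SIZE
# environment variables override the defaults.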
def test_bert_train():
    """
    Compile one BERT training step without loss scaling.
    """
    class ModelBert(nn.Cell):
        """
        Wrap the loss network and optimizer in a one-step training cell.
        """

        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    version = os.getenv('VERSION', 'large')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    inputs = load_test_data(batch_size)

    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=False)
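

# test_bert_withlossscale_train repeats the exercise with
# BertTrainOneStepWithLossScaleCell: an eighth input carries the loss-scaling
# sensitivity, and this time the compiled graph is executed (execute=True).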
def test_bert_withlossscale_train():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    scaling_sens = Tensor(np.ones([1]).astype(np.float32))
    inputs = load_test_data(batch_size) + (scaling_sens,)

    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)
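

# The two helpers below lack the test_ prefix, so pytest does not collect them; they
# exercise the loss-scale cell with an update cell driven by DynamicLossScaleManager.
# Note: the original code called an undefined LossScaleUpdateCell(manager); the
# manager's get_update_cell() method is used here instead, on the assumption that
# this is what was intended.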
def bert_withlossscale_manager_train():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            # get_update_cell() replaces the undefined LossScaleUpdateCell (assumed fix)
            update_cell = manager.get_update_cell()
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                   scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    inputs = load_test_data(batch_size)

    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)
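

# Same scenario as above, but the initial scaling sensitivity is also fed to the
# graph as an eighth input (hence the extra construct() argument).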
def bert_withlossscale_manager_train_feed():
    class ModelBert(nn.Cell):
        def __init__(self, network, optimizer=None):
            super(ModelBert, self).__init__()
            self.optimizer = optimizer
            manager = DynamicLossScaleManager()
            # get_update_cell() replaces the undefined LossScaleUpdateCell (assumed fix)
            update_cell = manager.get_update_cell()
            self.train_network = BertTrainOneStepWithLossScaleCell(network, self.optimizer,
                                                                   scale_update_cell=update_cell)
            self.train_network.set_train()

        def construct(self, arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7):
            return self.train_network(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7)

    version = os.getenv('VERSION', 'base')
    batch_size = int(os.getenv('BATCH_SIZE', '1'))
    scaling_sens = Tensor(np.ones([1]).astype(np.float32))
    inputs = load_test_data(batch_size) + (scaling_sens,)

    config = get_config(version=version, batch_size=batch_size)
    netwithloss = BertNetworkWithLoss(config, True)
    lr = BertLearningRate(10)
    optimizer = AdamWeightDecay(netwithloss.trainable_params(), lr)
    net = ModelBert(netwithloss, optimizer=optimizer)
    net.set_train()
    build_construct_graph(net, *inputs, execute=True)
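

# Example invocation (hypothetical; assumes a pytest setup that can resolve the
# package-relative imports above):
#   VERSION=base BATCH_SIZE=2 pytest test_bert_train.py -k withlossscale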