You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_pretrain.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. #################pre_train bert example on zh-wiki########################
  17. python run_pretrain.py
  18. """
  19. import os
  20. import argparse
  21. import mindspore.communication.management as D
  22. from mindspore.communication.management import get_rank
  23. import mindspore.common.dtype as mstype
  24. from mindspore import context
  25. from mindspore.train.model import Model
  26. from mindspore.context import ParallelMode
  27. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  28. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
  29. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  30. from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
  31. from mindspore import log as logger
  32. from mindspore.common import set_seed
  33. from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
  34. BertTrainAccumulateStepsWithLossScaleCell
  35. from src.dataset import create_bert_dataset
  36. from src.config import cfg, bert_net_cfg
  37. from src.utils import LossCallBack, BertLearningRate
  38. _current_dir = os.path.dirname(os.path.realpath(__file__))
  39. def _set_bert_all_reduce_split():
  40. """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
  41. if bert_net_cfg.num_hidden_layers == 12:
  42. if bert_net_cfg.use_relative_positions:
  43. context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
  44. else:
  45. context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
  46. elif bert_net_cfg.num_hidden_layers == 24:
  47. if bert_net_cfg.use_relative_positions:
  48. context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
  49. else:
  50. context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
  51. def _get_optimizer(args_opt, network):
  52. """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
  53. if cfg.optimizer == 'Lamb':
  54. lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
  55. end_learning_rate=cfg.Lamb.end_learning_rate,
  56. warmup_steps=cfg.Lamb.warmup_steps,
  57. decay_steps=args_opt.train_steps,
  58. power=cfg.Lamb.power)
  59. params = network.trainable_params()
  60. decay_params = list(filter(cfg.Lamb.decay_filter, params))
  61. other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
  62. group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
  63. {'params': other_params},
  64. {'order_params': params}]
  65. optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
  66. elif cfg.optimizer == 'Momentum':
  67. optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
  68. momentum=cfg.Momentum.momentum)
  69. elif cfg.optimizer == 'AdamWeightDecay':
  70. lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
  71. end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
  72. warmup_steps=cfg.AdamWeightDecay.warmup_steps,
  73. decay_steps=args_opt.train_steps,
  74. power=cfg.AdamWeightDecay.power)
  75. params = network.trainable_params()
  76. decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
  77. other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
  78. group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
  79. {'params': other_params, 'weight_decay': 0.0},
  80. {'order_params': params}]
  81. optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
  82. else:
  83. raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
  84. format(cfg.optimizer))
  85. return optimizer
  86. def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
  87. """Judge whether is suitable to enable graph kernel."""
  88. return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
  89. cfg.bert_network == 'base' and (cfg.batch_size == 32 or cfg.batch_size == 64) and \
  90. cfg.optimizer == 'AdamWeightDecay'
  91. def run_pretrain():
  92. """pre-train bert_clue"""
  93. parser = argparse.ArgumentParser(description='bert pre_training')
  94. parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
  95. help='device where the code will be implemented. (Default: Ascend)')
  96. parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
  97. help="Run distribute, default is false.")
  98. parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
  99. parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
  100. parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
  101. parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
  102. help="Enable save checkpoint, default is true.")
  103. parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
  104. help="Use lossscale or not, default is not.")
  105. parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
  106. help="Enable shuffle for dataset, default is true.")
  107. parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
  108. help="Enable data sink, default is true.")
  109. parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
  110. parser.add_argument("--accumulation_steps", type=int, default="1",
  111. help="Accumulating gradients N times before weight update, default is 1.")
  112. parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
  113. parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
  114. parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
  115. "default is 1000.")
  116. parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
  117. "meaning run all steps according to epoch number.")
  118. parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
  119. parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
  120. parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
  121. parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"],
  122. help="Accelerate by graph kernel, default is auto.")
  123. args_opt = parser.parse_args()
  124. context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
  125. context.set_context(reserve_class_name_in_scope=False)
  126. ckpt_save_dir = args_opt.save_checkpoint_path
  127. if args_opt.distribute == "true":
  128. if args_opt.device_target == 'Ascend':
  129. D.init()
  130. device_num = args_opt.device_num
  131. rank = args_opt.device_id % device_num
  132. else:
  133. D.init()
  134. device_num = D.get_group_size()
  135. rank = D.get_rank()
  136. ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
  137. context.reset_auto_parallel_context()
  138. context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
  139. device_num=device_num)
  140. if args_opt.device_target == 'Ascend':
  141. _set_bert_all_reduce_split()
  142. else:
  143. rank = 0
  144. device_num = 1
  145. is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
  146. if args_opt.enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
  147. context.set_context(enable_graph_kernel=True)
  148. if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
  149. not is_auto_enable_graph_kernel:
  150. logger.warning('Gpu only support fp32 temporarily, run with fp32.')
  151. bert_net_cfg.compute_type = mstype.float32
  152. if args_opt.accumulation_steps > 1:
  153. logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
  154. logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps))
  155. if args_opt.enable_data_sink == "true":
  156. args_opt.data_sink_steps *= args_opt.accumulation_steps
  157. logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
  158. if args_opt.enable_save_ckpt == "true":
  159. args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
  160. logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))
  161. ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
  162. net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
  163. new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
  164. if args_opt.train_steps > 0:
  165. train_steps = args_opt.train_steps * args_opt.accumulation_steps
  166. new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
  167. else:
  168. args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
  169. logger.info("train steps: {}".format(args_opt.train_steps))
  170. optimizer = _get_optimizer(args_opt, net_with_loss)
  171. callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
  172. if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
  173. config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
  174. keep_checkpoint_max=args_opt.save_checkpoint_num)
  175. ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
  176. directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
  177. callback.append(ckpoint_cb)
  178. if args_opt.load_checkpoint_path:
  179. param_dict = load_checkpoint(args_opt.load_checkpoint_path)
  180. load_param_into_net(net_with_loss, param_dict)
  181. if args_opt.enable_lossscale == "true":
  182. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
  183. scale_factor=cfg.scale_factor,
  184. scale_window=cfg.scale_window)
  185. if args_opt.accumulation_steps <= 1:
  186. net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
  187. scale_update_cell=update_cell)
  188. else:
  189. accumulation_steps = args_opt.accumulation_steps
  190. net_with_grads = BertTrainAccumulateStepsWithLossScaleCell(net_with_loss, optimizer=optimizer,
  191. scale_update_cell=update_cell,
  192. accumulation_steps=accumulation_steps,
  193. enable_global_norm=cfg.enable_global_norm)
  194. else:
  195. net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
  196. model = Model(net_with_grads)
  197. model.train(new_repeat_count, ds, callbacks=callback,
  198. dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
  199. if __name__ == '__main__':
  200. set_seed(0)
  201. run_pretrain()