You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_pretrain.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. #################pre_train bert example on zh-wiki########################
  17. python run_pretrain.py
  18. """
  19. import os
  20. import argparse
  21. import mindspore.communication.management as D
  22. from mindspore.communication.management import get_rank
  23. import mindspore.common.dtype as mstype
  24. from mindspore import context
  25. from mindspore.train.model import Model
  26. from mindspore.context import ParallelMode
  27. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  28. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
  29. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  30. from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay
  31. from mindspore import log as logger
  32. from mindspore.common import set_seed
  33. from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
  34. BertTrainAccumulationAllReduceEachWithLossScaleCell, \
  35. BertTrainAccumulationAllReducePostWithLossScaleCell, \
  36. BertTrainOneStepWithLossScaleCellForAdam, \
  37. AdamWeightDecayForBert
  38. from src.dataset import create_bert_dataset
  39. from src.config import cfg, bert_net_cfg
  40. from src.utils import LossCallBack, BertLearningRate
  41. _current_dir = os.path.dirname(os.path.realpath(__file__))
  42. def _set_bert_all_reduce_split():
  43. """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
  44. if bert_net_cfg.num_hidden_layers == 12:
  45. if bert_net_cfg.use_relative_positions:
  46. context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
  47. else:
  48. context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
  49. elif bert_net_cfg.num_hidden_layers == 24:
  50. if bert_net_cfg.use_relative_positions:
  51. context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
  52. else:
  53. context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
  54. def _get_optimizer(args_opt, network):
  55. """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
  56. if cfg.optimizer == 'Lamb':
  57. lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
  58. end_learning_rate=cfg.Lamb.end_learning_rate,
  59. warmup_steps=cfg.Lamb.warmup_steps,
  60. decay_steps=args_opt.train_steps,
  61. power=cfg.Lamb.power)
  62. params = network.trainable_params()
  63. decay_params = list(filter(cfg.Lamb.decay_filter, params))
  64. other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
  65. group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
  66. {'params': other_params},
  67. {'order_params': params}]
  68. optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
  69. elif cfg.optimizer == 'Momentum':
  70. optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
  71. momentum=cfg.Momentum.momentum)
  72. elif cfg.optimizer == 'AdamWeightDecay':
  73. lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
  74. end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
  75. warmup_steps=cfg.AdamWeightDecay.warmup_steps,
  76. decay_steps=args_opt.train_steps,
  77. power=cfg.AdamWeightDecay.power)
  78. params = network.trainable_params()
  79. decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
  80. other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
  81. group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
  82. {'params': other_params, 'weight_decay': 0.0},
  83. {'order_params': params}]
  84. if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
  85. optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
  86. else:
  87. optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
  88. else:
  89. raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
  90. format(cfg.optimizer))
  91. return optimizer
  92. def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
  93. """Judge whether is suitable to enable graph kernel."""
  94. return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
  95. cfg.bert_network == 'base' and (cfg.batch_size == 32 or cfg.batch_size == 64) and \
  96. cfg.optimizer == 'AdamWeightDecay'
  97. def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
  98. if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
  99. if device_target == 'GPU':
  100. context.set_context(enable_graph_kernel=True)
  101. else:
  102. logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
  103. def _check_compute_type(device_target, is_auto_enable_graph_kernel):
  104. if device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and not is_auto_enable_graph_kernel:
  105. logger.warning('Gpu only support fp32 temporarily, run with fp32.')
  106. bert_net_cfg.compute_type = mstype.float32
  107. def argparse_init():
  108. """Argparse init."""
  109. parser = argparse.ArgumentParser(description='bert pre_training')
  110. parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
  111. help='device where the code will be implemented. (Default: Ascend)')
  112. parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
  113. help="Run distribute, default is false.")
  114. parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
  115. parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
  116. parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
  117. parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
  118. help="Enable save checkpoint, default is true.")
  119. parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
  120. help="Use lossscale or not, default is not.")
  121. parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
  122. help="Enable shuffle for dataset, default is true.")
  123. parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
  124. help="Enable data sink, default is true.")
  125. parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
  126. parser.add_argument("--accumulation_steps", type=int, default="1",
  127. help="Accumulating gradients N times before weight update, default is 1.")
  128. parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
  129. help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
  130. parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
  131. parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
  132. parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
  133. "default is 1000.")
  134. parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
  135. "meaning run all steps according to epoch number.")
  136. parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
  137. parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
  138. parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
  139. parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"],
  140. help="Accelerate by graph kernel, default is auto.")
  141. return parser
  142. def run_pretrain():
  143. """pre-train bert_clue"""
  144. parser = argparse_init()
  145. args_opt = parser.parse_args()
  146. context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
  147. context.set_context(reserve_class_name_in_scope=False)
  148. ckpt_save_dir = args_opt.save_checkpoint_path
  149. if args_opt.distribute == "true":
  150. if args_opt.device_target == 'Ascend':
  151. D.init()
  152. device_num = args_opt.device_num
  153. rank = args_opt.device_id % device_num
  154. else:
  155. D.init()
  156. device_num = D.get_group_size()
  157. rank = D.get_rank()
  158. ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
  159. context.reset_auto_parallel_context()
  160. context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
  161. device_num=device_num)
  162. if args_opt.device_target == 'Ascend':
  163. _set_bert_all_reduce_split()
  164. else:
  165. rank = 0
  166. device_num = 1
  167. is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
  168. _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
  169. _check_compute_type(args_opt.device_target, is_auto_enable_graph_kernel)
  170. if args_opt.accumulation_steps > 1:
  171. logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
  172. logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps))
  173. if args_opt.enable_data_sink == "true":
  174. args_opt.data_sink_steps *= args_opt.accumulation_steps
  175. logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
  176. if args_opt.enable_save_ckpt == "true":
  177. args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
  178. logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))
  179. ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
  180. net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
  181. new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
  182. if args_opt.train_steps > 0:
  183. train_steps = args_opt.train_steps * args_opt.accumulation_steps
  184. new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
  185. else:
  186. args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
  187. logger.info("train steps: {}".format(args_opt.train_steps))
  188. optimizer = _get_optimizer(args_opt, net_with_loss)
  189. callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
  190. if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
  191. config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
  192. keep_checkpoint_max=args_opt.save_checkpoint_num)
  193. ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
  194. directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
  195. callback.append(ckpoint_cb)
  196. if args_opt.load_checkpoint_path:
  197. param_dict = load_checkpoint(args_opt.load_checkpoint_path)
  198. load_param_into_net(net_with_loss, param_dict)
  199. if args_opt.enable_lossscale == "true":
  200. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
  201. scale_factor=cfg.scale_factor,
  202. scale_window=cfg.scale_window)
  203. accumulation_steps = args_opt.accumulation_steps
  204. enable_global_norm = cfg.enable_global_norm
  205. if accumulation_steps <= 1:
  206. if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
  207. net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
  208. scale_update_cell=update_cell)
  209. else:
  210. net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
  211. scale_update_cell=update_cell)
  212. else:
  213. allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
  214. net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
  215. BertTrainAccumulationAllReduceEachWithLossScaleCell)
  216. net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
  217. scale_update_cell=update_cell,
  218. accumulation_steps=accumulation_steps,
  219. enable_global_norm=enable_global_norm)
  220. else:
  221. net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
  222. model = Model(net_with_grads)
  223. model.train(new_repeat_count, ds, callbacks=callback,
  224. dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
  225. if __name__ == '__main__':
  226. set_seed(0)
  227. run_pretrain()