You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

run_pretrain.py 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
Lines 1–286
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. #################pre_train bert example on zh-wiki########################
  17. python run_pretrain.py
  18. """
  19. import os
  20. import argparse
  21. import mindspore.communication.management as D
  22. from mindspore.communication.management import get_rank
  23. import mindspore.common.dtype as mstype
  24. from mindspore import context
  25. from mindspore.train.model import Model
  26. from mindspore.context import ParallelMode
  27. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  28. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
  29. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  30. from mindspore.train.train_thor import ConvertModelUtils
  31. from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecay, THOR
  32. from mindspore import log as logger
  33. from mindspore.common import set_seed
  34. from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
  35. BertTrainAccumulationAllReduceEachWithLossScaleCell, \
  36. BertTrainAccumulationAllReducePostWithLossScaleCell, \
  37. BertTrainOneStepWithLossScaleCellForAdam, \
  38. AdamWeightDecayForBert
  39. from src.dataset import create_bert_dataset
  40. from src.config import cfg, bert_net_cfg
  41. from src.utils import LossCallBack, BertLearningRate
  42. _current_dir = os.path.dirname(os.path.realpath(__file__))
  43. def _set_bert_all_reduce_split():
  44. """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
  45. device_target = context.get_context('device_target')
  46. enable_graph_kernel = context.get_context('enable_graph_kernel')
  47. device_num = context.get_auto_parallel_context('device_num')
  48. if bert_net_cfg.num_hidden_layers == 12:
  49. if bert_net_cfg.use_relative_positions:
  50. context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
  51. else:
  52. context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
  53. if device_target == 'GPU' and enable_graph_kernel and device_num == 8:
  54. context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205])
  55. elif device_target == 'GPU' and enable_graph_kernel and device_num == 16:
  56. context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205])
  57. elif bert_net_cfg.num_hidden_layers == 24:
  58. if bert_net_cfg.use_relative_positions:
  59. context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
  60. else:
  61. context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397])
  62. def _get_optimizer(args_opt, network):
  63. """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
  64. if cfg.optimizer == 'Lamb':
  65. lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
  66. end_learning_rate=cfg.Lamb.end_learning_rate,
  67. warmup_steps=cfg.Lamb.warmup_steps,
  68. decay_steps=args_opt.train_steps,
  69. power=cfg.Lamb.power)
  70. params = network.trainable_params()
  71. decay_params = list(filter(cfg.Lamb.decay_filter, params))
  72. other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
  73. group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
  74. {'params': other_params},
  75. {'order_params': params}]
  76. optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
  77. elif cfg.optimizer == 'Momentum':
  78. optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
  79. momentum=cfg.Momentum.momentum)
  80. elif cfg.optimizer == 'AdamWeightDecay':
  81. lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
  82. end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
  83. warmup_steps=cfg.AdamWeightDecay.warmup_steps,
  84. decay_steps=args_opt.train_steps,
  85. power=cfg.AdamWeightDecay.power)
  86. params = network.trainable_params()
  87. decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
  88. other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
  89. group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
  90. {'params': other_params, 'weight_decay': 0.0},
  91. {'order_params': params}]
  92. if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
  93. optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
  94. else:
  95. optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
  96. elif cfg.optimizer == "Thor":
  97. from src.utils import get_bert_thor_lr, get_bert_thor_damping
  98. lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
  99. damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
  100. cfg.Thor.damping_total_steps)
  101. split_indices = None
  102. if bert_net_cfg.num_hidden_layers == 12:
  103. if bert_net_cfg.use_relative_positions:
  104. split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
  105. else:
  106. split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
  107. elif bert_net_cfg.num_hidden_layers == 24:
  108. if bert_net_cfg.use_relative_positions:
  109. split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
  110. else:
  111. split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
  112. optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
  113. cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
  114. decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
  115. split_indices=split_indices)
  116. else:
  117. raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
  118. format(cfg.optimizer))
  119. return optimizer
  120. def _auto_enable_graph_kernel(device_target, graph_kernel_mode):
  121. """Judge whether is suitable to enable graph kernel."""
  122. return graph_kernel_mode in ("auto", "true") and device_target == 'GPU' and \
  123. cfg.bert_network == 'base' and cfg.optimizer == 'AdamWeightDecay'
  124. def _set_graph_kernel_context(device_target, enable_graph_kernel, is_auto_enable_graph_kernel):
  125. if enable_graph_kernel == "true" or is_auto_enable_graph_kernel:
  126. if device_target == 'GPU':
  127. context.set_context(enable_graph_kernel=True)
  128. else:
  129. logger.warning('Graph kernel only supports GPU back-end now, run with graph kernel off.')
  130. def _check_compute_type(args_opt, is_auto_enable_graph_kernel):
  131. if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32 and \
  132. not is_auto_enable_graph_kernel:
  133. warning_message = 'Gpu only support fp32 temporarily, run with fp32.'
  134. bert_net_cfg.compute_type = mstype.float32
  135. if args_opt.enable_lossscale == "true":
  136. args_opt.enable_lossscale = "false"
  137. warning_message = 'Gpu only support fp32 temporarily, run with fp32 and disable lossscale.'
  138. logger.warning(warning_message)
  139. def argparse_init():
  140. """Argparse init."""
  141. parser = argparse.ArgumentParser(description='bert pre_training')
  142. parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
  143. help='device where the code will be implemented. (Default: Ascend)')
  144. parser.add_argument("--distribute", type=str, default="false", choices=["true", "false"],
  145. help="Run distribute, default is false.")
  146. parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
  147. parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
  148. parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
  149. parser.add_argument("--enable_save_ckpt", type=str, default="true", choices=["true", "false"],
  150. help="Enable save checkpoint, default is true.")
  151. parser.add_argument("--enable_lossscale", type=str, default="true", choices=["true", "false"],
  152. help="Use lossscale or not, default is not.")
  153. parser.add_argument("--do_shuffle", type=str, default="true", choices=["true", "false"],
  154. help="Enable shuffle for dataset, default is true.")
  155. parser.add_argument("--enable_data_sink", type=str, default="true", choices=["true", "false"],
  156. help="Enable data sink, default is true.")
  157. parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
  158. parser.add_argument("--accumulation_steps", type=int, default="1",
  159. help="Accumulating gradients N times before weight update, default is 1.")
  160. parser.add_argument("--allreduce_post_accumulation", type=str, default="true", choices=["true", "false"],
  161. help="Whether to allreduce after accumulation of N steps or after each step, default is true.")
  162. parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
  163. parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
  164. parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
  165. "default is 1000.")
  166. parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
  167. "meaning run all steps according to epoch number.")
  168. parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
  169. parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
  170. parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
  171. parser.add_argument("--enable_graph_kernel", type=str, default="auto", choices=["auto", "true", "false"],
  172. help="Accelerate by graph kernel, default is auto.")
  173. return parser
  174. def run_pretrain():
  175. """pre-train bert_clue"""
  176. parser = argparse_init()
  177. args_opt = parser.parse_args()
  178. context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
  179. context.set_context(reserve_class_name_in_scope=False)
  180. is_auto_enable_graph_kernel = _auto_enable_graph_kernel(args_opt.device_target, args_opt.enable_graph_kernel)
  181. _set_graph_kernel_context(args_opt.device_target, args_opt.enable_graph_kernel, is_auto_enable_graph_kernel)
  182. ckpt_save_dir = args_opt.save_checkpoint_path
  183. if args_opt.distribute == "true":
  184. if args_opt.device_target == 'Ascend':
  185. D.init()
  186. device_num = args_opt.device_num
  187. rank = args_opt.device_id % device_num
  188. else:
  189. D.init()
  190. device_num = D.get_group_size()
  191. rank = D.get_rank()
  192. ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(get_rank()) + '/'
  193. context.reset_auto_parallel_context()
  194. context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
  195. device_num=device_num)
  196. _set_bert_all_reduce_split()
  197. else:
  198. rank = 0
  199. device_num = 1
  200. _check_compute_type(args_opt, is_auto_enable_graph_kernel)
  201. if args_opt.accumulation_steps > 1:
  202. logger.info("accumulation steps: {}".format(args_opt.accumulation_steps))
  203. logger.info("global batch size: {}".format(cfg.batch_size * args_opt.accumulation_steps))
  204. if args_opt.enable_data_sink == "true":
  205. args_opt.data_sink_steps *= args_opt.accumulation_steps
  206. logger.info("data sink steps: {}".format(args_opt.data_sink_steps))
  207. if args_opt.enable_save_ckpt == "true":
  208. args_opt.save_checkpoint_steps *= args_opt.accumulation_steps
  209. logger.info("save checkpoint steps: {}".format(args_opt.save_checkpoint_steps))
  210. ds = create_bert_dataset(device_num, rank, args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
  211. net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
  212. new_repeat_count = args_opt.epoch_size * ds.get_dataset_size() // args_opt.data_sink_steps
  213. if args_opt.train_steps > 0:
  214. train_steps = args_opt.train_steps * args_opt.accumulation_steps
  215. new_repeat_count = min(new_repeat_count, train_steps // args_opt.data_sink_steps)
  216. else:
  217. args_opt.train_steps = args_opt.epoch_size * ds.get_dataset_size() // args_opt.accumulation_steps
  218. logger.info("train steps: {}".format(args_opt.train_steps))
  219. optimizer = _get_optimizer(args_opt, net_with_loss)
  220. callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack(ds.get_dataset_size())]
  221. if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
  222. config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
  223. keep_checkpoint_max=args_opt.save_checkpoint_num)
  224. ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert',
  225. directory=None if ckpt_save_dir == "" else ckpt_save_dir, config=config_ck)
  226. callback.append(ckpoint_cb)
  227. if args_opt.load_checkpoint_path:
  228. param_dict = load_checkpoint(args_opt.load_checkpoint_path)
  229. load_param_into_net(net_with_loss, param_dict)
  230. if args_opt.enable_lossscale == "true":
  231. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=cfg.loss_scale_value,
  232. scale_factor=cfg.scale_factor,
  233. scale_window=cfg.scale_window)
  234. accumulation_steps = args_opt.accumulation_steps
  235. enable_global_norm = cfg.enable_global_norm
  236. if accumulation_steps <= 1:
  237. if cfg.optimizer == 'AdamWeightDecay' and args_opt.device_target == 'GPU':
  238. net_with_grads = BertTrainOneStepWithLossScaleCellForAdam(net_with_loss, optimizer=optimizer,
  239. scale_update_cell=update_cell)
  240. else:
  241. net_with_grads = BertTrainOneStepWithLossScaleCell(net_with_loss, optimizer=optimizer,
  242. scale_update_cell=update_cell)
  243. else:
  244. allreduce_post = args_opt.distribute == "false" or args_opt.allreduce_post_accumulation == "true"
  245. net_with_accumulation = (BertTrainAccumulationAllReducePostWithLossScaleCell if allreduce_post else
  246. BertTrainAccumulationAllReduceEachWithLossScaleCell)
  247. net_with_grads = net_with_accumulation(net_with_loss, optimizer=optimizer,
  248. scale_update_cell=update_cell,
  249. accumulation_steps=accumulation_steps,
  250. enable_global_norm=enable_global_norm)
  251. else:
  252. net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
  253. model = Model(net_with_grads)
  254. model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
  255. frequency=cfg.Thor.frequency)
  256. model.train(new_repeat_count, ds, callbacks=callback,
  257. dataset_sink_mode=(args_opt.enable_data_sink == "true"), sink_size=args_opt.data_sink_steps)
  258. if __name__ == '__main__':
  259. set_seed(0)
  260. run_pretrain()