diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 16551615d8..9777424678 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -40,15 +40,20 @@ from src.utils import LossCallBack, BertLearningRate _current_dir = os.path.dirname(os.path.realpath(__file__)) -def _set_bert_all_reduce_split(device_target='Ascend', enable_graph_kernel=False): +def _set_bert_all_reduce_split(): """set bert all_reduce fusion split, support num_hidden_layers is 12 and 24.""" + device_target = context.get_context('device_target') + enable_graph_kernel = context.get_context('enable_graph_kernel') + device_num = context.get_auto_parallel_context('device_num') if bert_net_cfg.num_hidden_layers == 12: if bert_net_cfg.use_relative_positions: context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217]) else: context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205]) - if device_target == 'GPU' and enable_graph_kernel: + if device_target == 'GPU' and enable_graph_kernel and device_num == 8: context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205]) + elif device_target == 'GPU' and enable_graph_kernel and device_num == 16: + context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205]) elif bert_net_cfg.num_hidden_layers == 24: if bert_net_cfg.use_relative_positions: context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421]) @@ -156,7 +161,7 @@ def run_pretrain(): context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) - _set_bert_all_reduce_split(args_opt.device_target, context.get_context('enable_graph_kernel')) + _set_bert_all_reduce_split() else: rank = 0 device_num = 1