Browse Source

[Bert][Gpu]set all reduce split

pull/11782/head
hanhuifeng2020 5 years ago
parent
commit
87676bd4de
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      model_zoo/official/nlp/bert/run_pretrain.py

+ 8
- 3
model_zoo/official/nlp/bert/run_pretrain.py View File

@@ -40,15 +40,20 @@ from src.utils import LossCallBack, BertLearningRate
_current_dir = os.path.dirname(os.path.realpath(__file__))


def _set_bert_all_reduce_split(device_target='Ascend', enable_graph_kernel=False):
def _set_bert_all_reduce_split():
"""set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
device_target = context.get_context('device_target')
enable_graph_kernel = context.get_context('enable_graph_kernel')
device_num = context.get_auto_parallel_context('device_num')
if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217])
else:
context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205])
if device_target == 'GPU' and enable_graph_kernel:
if device_target == 'GPU' and enable_graph_kernel and device_num == 8:
context.set_auto_parallel_context(all_reduce_fusion_config=[180, 205])
elif device_target == 'GPU' and enable_graph_kernel and device_num == 16:
context.set_auto_parallel_context(all_reduce_fusion_config=[120, 205])
elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421])
@@ -156,7 +161,7 @@ def run_pretrain():
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)
_set_bert_all_reduce_split(args_opt.device_target, context.get_context('enable_graph_kernel'))
_set_bert_all_reduce_split()
else:
rank = 0
device_num = 1


Loading…
Cancel
Save