|
|
|
@@ -153,8 +153,8 @@ def run_pretrain(): |
|
|
|
device_num = D.get_group_size() |
|
|
|
rank = D.get_rank() |
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/' |
|
|
|
_set_bert_all_reduce_split() |
|
|
|
context.reset_auto_parallel_context() |
|
|
|
_set_bert_all_reduce_split() |
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, |
|
|
|
device_num=device_num) |
|
|
|
|
|
|
|
|