|
|
|
@@ -84,9 +84,15 @@ def run_pretrain(): |
|
|
|
device_num=device_num) |
|
|
|
from mindspore.parallel._auto_parallel_context import auto_parallel_context |
|
|
|
if bert_net_cfg.num_hidden_layers == 12: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) |
|
|
|
if bert_net_cfg.use_relative_positions: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) |
|
|
|
else: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) |
|
|
|
elif bert_net_cfg.num_hidden_layers == 24: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) |
|
|
|
if bert_net_cfg.use_relative_positions: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) |
|
|
|
else: |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) |
|
|
|
D.init() |
|
|
|
rank = args_opt.device_id % device_num |
|
|
|
else: |
|
|
|
|