| @@ -84,9 +84,15 @@ def run_pretrain(): | |||||
| device_num=device_num) | device_num=device_num) | ||||
| from mindspore.parallel._auto_parallel_context import auto_parallel_context | from mindspore.parallel._auto_parallel_context import auto_parallel_context | ||||
| if bert_net_cfg.num_hidden_layers == 12: | if bert_net_cfg.num_hidden_layers == 12: | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) | |||||
| if bert_net_cfg.use_relative_positions: | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) | |||||
| else: | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) | |||||
| elif bert_net_cfg.num_hidden_layers == 24: | elif bert_net_cfg.num_hidden_layers == 24: | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) | |||||
| if bert_net_cfg.use_relative_positions: | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) | |||||
| else: | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) | |||||
| D.init() | D.init() | ||||
| rank = args_opt.device_id % device_num | rank = args_opt.device_id % device_num | ||||
| else: | else: | ||||