| @@ -107,6 +107,32 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None, | |||||
| logger.info("repeat count: {}".format(ds.get_repeat_count())) | logger.info("repeat count: {}".format(ds.get_repeat_count())) | ||||
| return ds | return ds | ||||
def _set_bert_all_reduce_split():
    """Set the BERT all-reduce fusion split indices.

    The split indices are chosen from ``bert_net_cfg.num_hidden_layers`` and
    ``bert_net_cfg.use_relative_positions`` and are applied to both the
    "hccl_world_groupsum1" and "hccl_world_groupsum3" groups. Only 12 and 24
    hidden layers are handled; for any other layer count nothing is set.
    """
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    layer_count = bert_net_cfg.num_hidden_layers
    if layer_count == 12:
        if bert_net_cfg.use_relative_positions:
            split_points = (29, 58, 87, 116, 145, 174, 203, 217)
        else:
            split_points = (28, 55, 82, 109, 136, 163, 190, 205)
    elif layer_count == 24:
        if bert_net_cfg.use_relative_positions:
            split_points = (30, 90, 150, 210, 270, 330, 390, 421)
        else:
            split_points = (38, 77)
    else:
        # Unsupported layer count: leave the fusion split configuration untouched.
        return

    # The same indices are registered for both groups; a fresh list is passed
    # on every call so neither call can observe changes made by the other.
    for group_name in ("hccl_world_groupsum1", "hccl_world_groupsum3"):
        auto_parallel_context().set_all_reduce_fusion_split_indices(list(split_points), group_name)
| def train_process_bert_thor(q, device_id, epoch_size, device_num): | def train_process_bert_thor(q, device_id, epoch_size, device_num): | ||||
| os.system("mkdir " + str(device_id)) | os.system("mkdir " + str(device_id)) | ||||
| os.chdir(str(device_id)) | os.chdir(str(device_id)) | ||||
| @@ -120,10 +146,11 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num): | |||||
| D.init() | D.init() | ||||
| rank = device_id % device_num | rank = device_id % device_num | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| _set_bert_all_reduce_split() | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | ||||
| device_num=device_num) | device_num=device_num) | ||||
| bert_net_cfg.num_hidden_layers = 2 | |||||
| bert_net_cfg.num_hidden_layers = 4 | |||||
| ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH, schema_dir=None) | ds = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH, schema_dir=None) | ||||
| net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) | net_with_loss = BertNetworkWithLoss(bert_net_cfg, True) | ||||
| @@ -200,8 +227,8 @@ def test_bert_thor_mlperf_8p(): | |||||
| os.system("rm -rf " + str(i)) | os.system("rm -rf " + str(i)) | ||||
| print("End training...") | print("End training...") | ||||
| assert mean_cost < 51 | |||||
| assert mean_loss < 8.5 | |||||
| assert mean_cost < 64.2 | |||||
| assert mean_loss < 7.9 | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_bert_thor_mlperf_8p() | test_bert_thor_mlperf_8p() | ||||