| @@ -140,7 +140,7 @@ if __name__ == '__main__': | |||||
| device_num = args.group_size | device_num = args.group_size | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True) | |||||
| gradients_mean=True, all_reduce_fusion_config=[3, 10, 12, 15]) | |||||
| else: | else: | ||||
| if args.device_target == "Ascend": | if args.device_target == "Ascend": | ||||
| context.set_context(device_id=args.device_id) | context.set_context(device_id=args.device_id) | ||||
| @@ -57,7 +57,9 @@ if __name__ == '__main__': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id) | ||||
| context.reset_auto_parallel_context() | context.reset_auto_parallel_context() | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True, | |||||
| all_reduce_fusion_config=[9, 11]) | |||||
| init() | init() | ||||
| rank_id = int(os.environ.get('RANK_ID')) | rank_id = int(os.environ.get('RANK_ID')) | ||||
| elif args_opt.device_target == "GPU": | elif args_opt.device_target == "GPU": | ||||
| @@ -125,6 +125,6 @@ if __name__ == "__main__": | |||||
| init() | init() | ||||
| context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) | context.set_context(save_graphs_path='./graphs_of_device_id_'+str(get_rank())) | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, | ||||
| device_num=get_group_size()) | |||||
| device_num=get_group_size(), all_reduce_fusion_config=[6, 12]) | |||||
| train_and_eval(wide_deep_config) | train_and_eval(wide_deep_config) | ||||