@@ -14,6 +14,7 @@
 # ============================================================================
 """Communication management API"""
 import os
+from mindspore import context
 from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from ._comm_helper import Backend, _get_rank_helper, _get_size_helper, \
     _get_world_rank_from_group_rank_helper, _get_group_rank_from_world_rank_helper, \
@@ -45,7 +46,7 @@ class GlobalComm:
     WORLD_COMM_GROUP = DEFAULT_WORLD_COMM_GROUP

-def init(backend_name="hccl"):
+def init(backend_name=None):
     """
     Init distributed backend, e.g., hccl/nccl, it is required before communication service can be used.
@@ -57,11 +58,20 @@ def init(backend_name="hccl"):
         backend_name (str): Backend.

     Raises:
-        TypeError: If backend name is not a string.
+        TypeError: If backend_name is not a string.
+        RuntimeError: If device target is invalid.
         RuntimeError: If backend is invalid or distributed init fails.
     """
     if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
         return
+    if backend_name is None:
+        device_target = context.get_context("device_target")
+        if device_target == "Ascend":
+            backend_name = "hccl"
+        elif device_target == "GPU":
+            backend_name = "nccl"
+        else:
+            raise RuntimeError("Device target {} is not supported.".format(device_target))
     if not isinstance(backend_name, str):
         raise TypeError("Backend name must be a string, but got {}".format(type(backend_name)))
@@ -73,7 +73,7 @@ class AllReduce(PrimitiveWithInfer):
         >>> import mindspore.nn as nn
         >>> import mindspore.ops.operations as P
         >>>
-        >>> init('nccl')
+        >>> init()
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
@@ -136,7 +136,7 @@ class AllGather(PrimitiveWithInfer):
         >>> from mindspore.communication import init
         >>> from mindspore import Tensor
         >>>
-        >>> init('nccl')
+        >>> init()
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
@@ -246,7 +246,7 @@ class ReduceScatter(PrimitiveWithInfer):
         >>> import mindspore.nn as nn
         >>> import mindspore.ops.operations as P
         >>>
-        >>> init('nccl')
+        >>> init()
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
@@ -360,7 +360,7 @@ class Broadcast(PrimitiveWithInfer):
         >>> import mindspore.nn as nn
         >>> import mindspore.ops.operations as P
         >>>
-        >>> init('nccl')
+        >>> init()
         >>> class Net(nn.Cell):
         >>>     def __init__(self):
         >>>         super(Net, self).__init__()
@@ -81,7 +81,7 @@ if __name__ == '__main__':
                                           mirror_mean=True)
         init()
     elif device_target == "GPU":
-        init("nccl")
+        init()
     if device_num > 1:
         context.reset_auto_parallel_context()
@@ -57,10 +57,7 @@ if __name__ == '__main__':
     cfg = config_ascend if args_opt.platform == 'Ascend' else config_gpu
     # init distributed
     if args_opt.is_distributed:
-        if args_opt.platform == "Ascend":
-            init()
-        else:
-            init("nccl")
+        init()
         cfg.rank = get_rank()
         cfg.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
@@ -64,7 +64,7 @@ elif args_opt.device_target == "GPU":
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="GPU",
                         save_graphs=False)
-    init("nccl")
+    init()
     context.set_auto_parallel_context(device_num=get_group_size(),
                                       parallel_mode=ParallelMode.DATA_PARALLEL,
                                       mirror_mean=True)
@@ -57,7 +57,7 @@ if args_opt.device_target == "Ascend":
                         device_target="Ascend",
                         device_id=device_id, save_graphs=False)
 elif args_opt.device_target == "GPU":
-    init("nccl")
+    init()
     context.set_auto_parallel_context(device_num=get_group_size(),
                                       parallel_mode=ParallelMode.DATA_PARALLEL,
                                       mirror_mean=True)
@@ -54,7 +54,7 @@ if args_opt.device_target == "GPU":
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="GPU",
                         save_graphs=False)
-    init("nccl")
+    init()
     context.set_auto_parallel_context(device_num=get_group_size(),
                                       parallel_mode=ParallelMode.DATA_PARALLEL,
                                       mirror_mean=True)
@@ -38,7 +38,7 @@ def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
     else:
-        init("nccl")
+        init()
         rank_id = get_rank()
         device_num = get_group_size()
@@ -93,7 +93,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
     else:
-        init("nccl")
+        init()
         rank_id = get_rank()
         device_num = get_group_size()
@@ -85,7 +85,7 @@ if __name__ == '__main__':
             init()
         # GPU target
         else:
-            init("nccl")
+            init()
             context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                               mirror_mean=True)
             if args_opt.net == "resnet50":
@@ -46,7 +46,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
         device_num = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
     else:
-        init("nccl")
+        init()
         rank_id = get_rank()
         device_num = get_group_size()
@@ -114,7 +114,7 @@ def create_dataset_py(dataset_path, do_train, repeat_num=1, batch_size=32, targe
         device_num = int(os.getenv("RANK_SIZE"))
         rank_id = int(os.getenv("RANK_ID"))
     else:
-        init("nccl")
+        init()
         rank_id = get_rank()
         device_num = get_group_size()
@@ -40,7 +40,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
     if target == "Ascend":
         device_num, rank_id = _get_rank_info()
     else:
-        init("nccl")
+        init()
         rank_id = get_rank()
         device_num = get_group_size()
@@ -106,7 +106,7 @@ if __name__ == '__main__':
             init()
         # GPU target
         else:
-            init("nccl")
+            init()
             context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL,
                                               mirror_mean=True)
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
@@ -112,10 +112,7 @@ def test(cloud_args=None):
     # init distributed
     if args.is_distributed:
-        if args.platform == "Ascend":
-            init()
-        elif args.platform == "GPU":
-            init("nccl")
+        init()
         args.rank = get_rank()
         args.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
@@ -172,10 +172,7 @@ def train(cloud_args=None):
     # init distributed
     if args.is_distributed:
-        if args.platform == "Ascend":
-            init()
-        else:
-            init("nccl")
+        init()
         args.rank = get_rank()
         args.group_size = get_group_size()
         parallel_mode = ParallelMode.DATA_PARALLEL
@@ -135,7 +135,7 @@ if __name__ == '__main__':
         init()
         context.set_context(device_id=args.device_id)
     elif args.device_target == "GPU":
-        init("nccl")
+        init()
     args.rank = get_rank()
     args.group_size = get_group_size()
@@ -60,7 +60,7 @@ if __name__ == '__main__':
         device_num = int(os.environ.get("RANK_SIZE"))
         rank = int(os.environ.get("RANK_ID"))
     else:
-        init('nccl')
+        init()
         lr_scale = 0.5
         device_num = get_group_size()
         rank = get_rank()
@@ -70,11 +70,11 @@ def run_pretrain():
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
-            D.init('hccl')
+            D.init()
             device_num = args_opt.device_num
             rank = args_opt.device_id % device_num
         else:
-            D.init('nccl')
+            D.init()
             device_num = D.get_group_size()
             rank = D.get_rank()
         ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
@@ -73,11 +73,11 @@ def run_pretrain():
     ckpt_save_dir = args_opt.save_checkpoint_path
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
-            D.init('hccl')
+            D.init()
             device_num = args_opt.device_num
             rank = args_opt.device_id % device_num
         else:
-            D.init('nccl')
+            D.init()
             device_num = D.get_group_size()
             rank = D.get_rank()
         ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
@@ -227,10 +227,7 @@ def _build_training_pipeline(config: TransformerConfig,
 def _setup_parallel_env(platform):
     context.reset_auto_parallel_context()
-    if platform == "GPU":
-        MultiAscend.init("nccl")
-    else:
-        MultiAscend.init()
+    MultiAscend.init()
     context.set_auto_parallel_context(
         parallel_mode=ParallelMode.DATA_PARALLEL,
         device_num=MultiAscend.get_group_size(),
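
The remaining hunks repeat the same substitution across model_zoo scripts and tests, so after this patch they all converge on one setup shape. A consolidated sketch of that shared pattern follows; the helper name setup_distributed is illustrative and does not come from any file in this diff, and "data_parallel" is used as the string value of ParallelMode.DATA_PARALLEL so the snippet stays self-contained:

    # Illustrative consolidation of the per-script setup after this patch.
    from mindspore import context
    from mindspore.communication.management import init, get_rank, get_group_size

    def setup_distributed(device_target):
        """Set graph mode, init the auto-detected backend, and enable data parallelism."""
        context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
        init()  # hccl on Ascend, nccl on GPU, chosen from device_target
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=get_group_size(),
                                          parallel_mode="data_parallel",
                                          mirror_mean=True)
        return get_rank(), get_group_size()
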
@@ -67,11 +67,11 @@ def run_general_distill():
     if args_opt.distribute == "true":
         if args_opt.device_target == 'Ascend':
-            D.init('hccl')
+            D.init()
             device_num = args_opt.device_num
             rank = args_opt.device_id % device_num
         else:
-            D.init('nccl')
+            D.init()
             device_num = D.get_group_size()
             rank = D.get_rank()
         save_ckpt_dir = save_ckpt_dir + '_ckpt_' + str(rank)
@@ -59,7 +59,7 @@ if __name__ == '__main__':
         init()
         rank_id = int(os.environ.get('RANK_ID'))
     elif args_opt.device_target == "GPU":
-        init("nccl")
+        init()
         context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
         context.reset_auto_parallel_context()
         context.set_auto_parallel_context(device_num=get_group_size(),
@@ -128,10 +128,7 @@ if __name__ == "__main__":
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
     set_multi_subgraphs()
-    if wide_deep_config.device_target == "Ascend":
-        init("hccl")
-    elif wide_deep_config.device_target == "GPU":
-        init("nccl")
+    init()
     if wide_deep_config.host_device_mix == 1:
         context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
     else:
@@ -122,10 +122,7 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
-    if wide_deep_config.device_target == "Ascend":
-        init("hccl")
-    elif wide_deep_config.device_target == "GPU":
-        init("nccl")
+    init()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                       device_num=get_group_size())
@@ -119,10 +119,7 @@ if __name__ == "__main__":
     wide_deep_config.argparse_init()
     context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target)
-    if wide_deep_config.device_target == "Ascend":
-        init("hccl")
-    elif wide_deep_config.device_target == "GPU":
-        init("nccl")
+    init()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
                                       device_num=get_group_size())
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-init('nccl')
+init()
 rank = get_rank()
 size = get_group_size()
 x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-init('nccl')
+init()
 rank = get_rank()
 size = get_group_size()
 x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-init('nccl')
+init()
 rank = get_rank()
 size = get_group_size()
 x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
@@ -25,7 +25,7 @@ from mindspore.nn.optim import Momentum
 from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-init('nccl')
+init()
 epoch = 5
 total = 5000
@@ -24,7 +24,7 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-init('nccl')
+init()
 rank = get_rank()
 size = get_group_size()
 x = np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1)
@@ -30,7 +30,7 @@ args, _ = parser.parse_known_args()
 device_target = args.device_target
 context.set_context(mode=context.GRAPH_MODE, device_target=device_target)
 if device_target == "GPU":
-    init('nccl')
+    init()

 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
@@ -75,7 +75,7 @@ def test_dataset_iter_normal():
 @pytest.mark.skipif('not context.get_context("enable_ge")')
 def test_dataset_iter_ge():
-    init()
+    init("hccl")
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
     count = 0
@@ -87,7 +87,7 @@ def test_dataset_iter_ge():
 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms_loop_sink():
-    init()
+    init("hccl")
     context.set_context(enable_loop_sink=True)
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
@@ -101,7 +101,7 @@ def test_dataset_iter_ms_loop_sink():
 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms():
-    init()
+    init("hccl")
     context.set_context(enable_loop_sink=False)
     dataset = get_dataset(32)
     DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)