@@ -84,6 +84,7 @@ class GlobalComm:
     BACKEND = DEFAULT_BACKEND
     WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
     INITED = False
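+    # New escape hatch: unit tests flip this to False so init() can run without real parallel env vars.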
+    CHECK_ENVS = True


 def is_hccl_available():
     """
@@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = ""

 def check_group(group):
     """
-    A function that check if a collection communication group is leagal.
+    A function that checks whether a collective communication group is legal.

     Returns:
         None
@@ -39,7 +39,7 @@ def check_group(group):

 def check_rank_num(rank_num):
     """
-    A function that check if a collection communication rank number is leagal.If not raise error.
+    A function that checks whether a collective communication rank number is legal; raises an error if not.

     Returns:
         None
@@ -53,7 +53,7 @@ def check_rank_num(rank_num):

 def check_rank_id(rank_id):
     """
-    A function that check if a collection communication rank id is leagal.If not raise error.
+    A function that checks whether a collective communication rank id is legal; raises an error if not.

     Returns:
         None
@@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids):
         c_group = c_str(group)
         ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
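+        # HcomCreateGroup returns 0 on success; any other value is treated as an HCCL error code.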
         if ret != 0:
-            raise RuntimeError('Create group error.')
+            raise RuntimeError('Create group error, the error code is {}.'.format(ret))
     else:
         raise TypeError('Rank ids must be a python list.')
@@ -36,6 +36,28 @@ def _get_group(group):
     return group


+def _check_parallel_envs():
+    """
+    Check whether parallel environment variables have been exported or not.
+
+    Raises:
+        RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
+    """
+    if not GlobalComm.CHECK_ENVS:
+        return
+    import os
+    rank_id_str = os.getenv("RANK_ID")
+    if not rank_id_str:
| raise RuntimeError("Environment variables RANK_ID has not been exported") | |||
| try: | |||
| int(rank_id_str) | |||
| except ValueError: | |||
| print("RANK_ID should be number") | |||
+    rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
+    rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
+    if not rank_table_file_str and not rank_table_file_str_old:
+        raise RuntimeError("Get hccl rank_table_file failed, "
+                           "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")


 def init(backend_name=None):
     """
@@ -68,6 +90,7 @@ def init(backend_name=None):
     if backend_name == "hccl":
         if device_target != "Ascend":
             raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
+        _check_parallel_envs()
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
@@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
 from mindspore.ops.operations.array_ops import Gather
+from mindspore.communication._comm_helper import GlobalComm
 import mindspore

 # pylint: disable=W0212
 # W0212: protected-access

 tag = 0

 context.set_context(device_target="Ascend")
+GlobalComm.CHECK_ENVS = False
 init("hccl")
+GlobalComm.CHECK_ENVS = True


 class AllReduceNet(nn.Cell):
@@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._utils import _reset_op_id as resset_op_id
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True


 def weight_variable():
     return TruncatedNormal(0.02)
@@ -18,11 +18,15 @@ from mindspore.communication.management import init
 from mindspore.parallel import set_algo_parameters
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm
 from .test_auto_parallel_resnet import resnet50

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True


 def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8
@@ -19,7 +19,7 @@ import mindspore.nn as nn
 from mindspore import Tensor, Parameter
 from mindspore.communication.management import init
 from mindspore.ops import operations as P
+from mindspore.communication._comm_helper import GlobalComm


 class DataParallelNet(nn.Cell):
     def __init__(self):
@@ -49,7 +49,9 @@ def test_param_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True

     network = DataParallelNet()
     network.set_train()
@@ -62,7 +64,9 @@ def test_param_not_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True

     network = ModelParallelNet()
     network.set_train()
@@ -29,6 +29,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.train import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE)
 device_number = 32
@@ -124,7 +125,9 @@ class TrainOneStepCell(Cell):

 def net_trains(criterion, rank):
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     lr = 0.1
     momentum = 0.9
     max_epoch = 20
@@ -24,7 +24,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import operations as P
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm


 class Net(nn.Cell):
     def __init__(self, input_channel, out_channel):
@@ -47,7 +47,9 @@ def test_dense_gen_graph():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = Net(512, 128)
     loss_fn = nn.SoftmaxCrossEntropyWithLogits()
@@ -21,6 +21,7 @@ from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
 from mindspore.communication.management import init
+from mindspore.communication._comm_helper import GlobalComm
 from mindspore.context import ParallelMode
 from mindspore.train import Model
 from ....dataset_mock import MindData
@@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel():
     net = NetNoLoss(16, 16)
     loss = nn.MSELoss()
     optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
     model.train(2, dataset, dataset_sink_mode=False)
@@ -19,9 +19,9 @@ import numpy as np
 import mindspore.context as context
 from mindspore.communication.management import init
 from mindspore.train.dataset_helper import DatasetHelper
+from mindspore.communication._comm_helper import GlobalComm
 from ....dataset_mock import MindData


 def get_dataset(batch_size=1):
     dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
     dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
@@ -75,7 +75,9 @@ def test_dataset_iter_normal():

 @pytest.mark.skipif('not context.get_context("enable_ge")')
 def test_dataset_iter_ge():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
     count = 0
@@ -87,7 +89,9 @@ def test_dataset_iter_ge():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms_loop_sink():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=True)
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
@@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=False)
     dataset = get_dataset(32)
     DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)