
parallel envs variable check

pull/15973/head
yao_yf, 5 years ago
commit e967f1939b
11 changed files with 59 additions and 10 deletions
  1. mindspore/communication/_comm_helper.py (+1, -0)
  2. mindspore/communication/_hccl_management.py (+4, -4)
  3. mindspore/communication/management.py (+23, -0)
  4. tests/ut/python/communication/test_comm.py (+3, -0)
  5. tests/ut/python/parallel/test_auto_parallel_resnet.py (+3, -1)
  6. tests/ut/python/parallel/test_auto_parallel_resnet_predict.py (+4, -0)
  7. tests/ut/python/parallel/test_broadcast_dict.py (+5, -1)
  8. tests/ut/python/parallel/test_gather_v2_primitive.py (+3, -0)
  9. tests/ut/python/parallel/test_optimizer.py (+3, -1)
  10. tests/ut/python/train/test_amp.py (+3, -2)
  11. tests/ut/python/train/test_dataset_helper.py (+7, -1)

mindspore/communication/_comm_helper.py (+1, -0)

@@ -84,6 +84,7 @@ class GlobalComm:
     BACKEND = DEFAULT_BACKEND
     WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
     INITED = False
+    CHECK_ENVS = True

 def is_hccl_available():
     """

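The new CHECK_ENVS flag is a switch that lets callers bypass the environment check that init() performs from this commit on (see management.py below). A minimal sketch of the toggle pattern the updated tests rely on; the try/finally wrapper is an assumption here, not something this commit adds:

    from mindspore.communication._comm_helper import GlobalComm
    from mindspore.communication.management import init

    # Skip _check_parallel_envs() for this one init() call, then restore the flag.
    GlobalComm.CHECK_ENVS = False
    try:
        init("hccl")
    finally:
        GlobalComm.CHECK_ENVS = True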

mindspore/communication/_hccl_management.py (+4, -4)

@@ -24,7 +24,7 @@ HCCL_LIB_CTYPES = ""
 def check_group(group):
     """
-    A function that check if a collection communication group is leagal.
+    A function that check if a collection communication group is legal.

     Returns:
         None
@@ -39,7 +39,7 @@ def check_group(group):
 def check_rank_num(rank_num):
     """
-    A function that check if a collection communication rank number is leagal.If not raise error.
+    A function that check if a collection communication rank number is legal.If not raise error.

     Returns:
         None
@@ -53,7 +53,7 @@ def check_rank_num(rank_num):
 def check_rank_id(rank_id):
     """
-    A function that check if a collection communication rank id is leagal.If not raise error.
+    A function that check if a collection communication rank id is legal.If not raise error.

     Returns:
         None
@@ -112,7 +112,7 @@ def create_group(group, rank_num, rank_ids):
         c_group = c_str(group)
         ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
         if ret != 0:
-            raise RuntimeError('Create group error.')
+            raise RuntimeError('Create group error, the error code is ', ret)
     else:
         raise TypeError('Rank ids must be a python list.')


mindspore/communication/management.py (+23, -0)

@@ -36,6 +36,28 @@ def _get_group(group):
     return group


+def _check_parallel_envs():
+    """
+    Check whether parallel environment variables have been exported or not.
+
+    Raises:
+        RuntimeError: If parallel environment variables have not been exported or have been exported to wrong values.
+    """
+    if not GlobalComm.CHECK_ENVS:
+        return
+    import os
+    rank_id_str = os.getenv("RANK_ID")
+    if not rank_id_str:
+        raise RuntimeError("Environment variables RANK_ID has not been exported")
+    try:
+        int(rank_id_str)
+    except ValueError:
+        print("RANK_ID should be number")
+    rank_table_file_str = os.getenv("MINDSPORE_HCCL_CONFIG_PATH")
+    rank_table_file_str_old = os.getenv("RANK_TABLE_FILE")
+    if not rank_table_file_str and not rank_table_file_str_old:
+        raise RuntimeError("Get hccl rank_table_file failed, "
+                           "please export MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE.")
+
 def init(backend_name=None):
     """
@@ -68,6 +90,7 @@ def init(backend_name=None):
     if backend_name == "hccl":
         if device_target != "Ascend":
             raise RuntimeError("Device target should be 'Ascend' to init hccl, but got {}".format(device_target))
+        _check_parallel_envs()
         init_hccl()
         GlobalComm.BACKEND = Backend("hccl")
         GlobalComm.WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP

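With _check_parallel_envs() wired into init(), hccl initialization now fails fast unless RANK_ID and a rank table path are exported first. A hedged sketch of the expected setup; the rank table path is hypothetical, and MINDSPORE_HCCL_CONFIG_PATH is accepted in place of RANK_TABLE_FILE:

    import os

    # Assumed setup, not shown in this commit: export the variables
    # _check_parallel_envs() looks for before calling init().
    os.environ["RANK_ID"] = "0"  # must parse as an integer
    os.environ["RANK_TABLE_FILE"] = "/path/to/rank_table.json"  # hypothetical path

    from mindspore import context
    from mindspore.communication.management import init

    context.set_context(device_target="Ascend")
    init("hccl")  # the check now finds RANK_ID and the rank table and passes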

tests/ut/python/communication/test_comm.py (+3, -0)

@@ -30,13 +30,16 @@ from mindspore.ops.operations.comm_ops import Broadcast, AllSwap
 from mindspore.ops.operations.array_ops import Gather
 import mindspore
+from mindspore.communication._comm_helper import GlobalComm

 # pylint: disable=W0212
 # W0212: protected-access

 tag = 0

 context.set_context(device_target="Ascend")
+GlobalComm.CHECK_ENVS = False
 init("hccl")
+GlobalComm.CHECK_ENVS = True


 class AllReduceNet(nn.Cell):

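Every test file below repeats the same disable/restore pair around init(). A hypothetical context manager, not part of this commit, could factor the pattern out:

    from contextlib import contextmanager
    from mindspore.communication._comm_helper import GlobalComm

    @contextmanager
    def no_env_check():
        # Hypothetical helper: suppress the env check only inside the block.
        GlobalComm.CHECK_ENVS = False
        try:
            yield
        finally:
            GlobalComm.CHECK_ENVS = True

A test would then write "with no_env_check(): init()" instead of toggling the flag by hand.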

tests/ut/python/parallel/test_auto_parallel_resnet.py (+3, -1)

@@ -31,11 +31,13 @@ from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._utils import _reset_op_id as resset_op_id
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True

 def weight_variable():
     return TruncatedNormal(0.02)


tests/ut/python/parallel/test_auto_parallel_resnet_predict.py (+4, -0)

@@ -18,11 +18,15 @@ from mindspore.communication.management import init
 from mindspore.parallel import set_algo_parameters
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm
 from .test_auto_parallel_resnet import resnet50

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
+GlobalComm.CHECK_ENVS = False
 init()
+GlobalComm.CHECK_ENVS = True

 def test_train_32k_8p(batch_size=32, num_classes=32768):
     dev_num = 8


tests/ut/python/parallel/test_broadcast_dict.py (+5, -1)

@@ -19,7 +19,7 @@ import mindspore.nn as nn
 from mindspore import Tensor, Parameter
 from mindspore.communication.management import init
 from mindspore.ops import operations as P
+from mindspore.communication._comm_helper import GlobalComm

 class DataParallelNet(nn.Cell):
     def __init__(self):
@@ -49,7 +49,9 @@ def test_param_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=True)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = DataParallelNet()
     network.set_train()
@@ -62,7 +64,9 @@ def test_param_not_broadcast():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode="data_parallel", parameter_broadcast=False)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = ModelParallelNet()
     network.set_train()



tests/ut/python/parallel/test_gather_v2_primitive.py (+3, -0)

@@ -29,6 +29,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 from mindspore.train import Model
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 context.set_context(mode=context.GRAPH_MODE)
 device_number = 32
@@ -124,7 +125,9 @@ class TrainOneStepCell(Cell):


 def net_trains(criterion, rank):
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     lr = 0.1
     momentum = 0.9
     max_epoch = 20


tests/ut/python/parallel/test_optimizer.py (+3, -1)

@@ -24,7 +24,7 @@ from mindspore.nn import Momentum
 from mindspore.nn import TrainOneStepCell, WithLossCell
 from mindspore.ops import operations as P
 from mindspore.context import ParallelMode
+from mindspore.communication._comm_helper import GlobalComm

 class Net(nn.Cell):
     def __init__(self, input_channel, out_channel):
@@ -47,7 +47,9 @@ def test_dense_gen_graph():
     context.set_context(mode=context.GRAPH_MODE)
     context.reset_auto_parallel_context()
     context.set_auto_parallel_context(parallel_mode=ParallelMode.HYBRID_PARALLEL, gradients_mean=True, device_num=8)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     network = Net(512, 128)

     loss_fn = nn.SoftmaxCrossEntropyWithLogits()


tests/ut/python/train/test_amp.py (+3, -2)

@@ -21,6 +21,7 @@ from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
 from mindspore.communication.management import init
+from mindspore.communication._comm_helper import GlobalComm
 from mindspore.context import ParallelMode
 from mindspore.train import Model
 from ....dataset_mock import MindData
@@ -156,8 +157,8 @@ def test_compile_model_train_O2_parallel():
     net = NetNoLoss(16, 16)
     loss = nn.MSELoss()
     optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
+    GlobalComm.CHECK_ENVS = False
     init()
+    GlobalComm.CHECK_ENVS = True
     model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
     model.train(2, dataset, dataset_sink_mode=False)

tests/ut/python/train/test_dataset_helper.py (+7, -1)

@@ -19,9 +19,9 @@ import numpy as np
 import mindspore.context as context
 from mindspore.communication.management import init
 from mindspore.train.dataset_helper import DatasetHelper
+from mindspore.communication._comm_helper import GlobalComm
 from ....dataset_mock import MindData


 def get_dataset(batch_size=1):
     dataset_types = (np.int32, np.int32, np.int32, np.int32, np.int32, np.int32, np.int32)
     dataset_shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
@@ -75,7 +75,9 @@ def test_dataset_iter_normal():

 @pytest.mark.skipif('not context.get_context("enable_ge")')
 def test_dataset_iter_ge():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
     count = 0
@@ -87,7 +89,9 @@ def test_dataset_iter_ge():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms_loop_sink():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=True)
     dataset = get_dataset(32)
     dataset_helper = DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
@@ -101,7 +105,9 @@ def test_dataset_iter_ms_loop_sink():

 @pytest.mark.skipif('context.get_context("enable_ge")')
 def test_dataset_iter_ms():
+    GlobalComm.CHECK_ENVS = False
     init("hccl")
+    GlobalComm.CHECK_ENVS = True
     context.set_context(enable_loop_sink=False)
     dataset = get_dataset(32)
     DatasetHelper(dataset, dataset_sink_mode=True, sink_size=10)
