@@ -23,7 +23,7 @@ from mindspore import log as logger
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
-from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_tensor
+from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor
 from ..parallel._ps_context import _is_role_pserver
 # store ms_function class compiled pipeline cache
 ms_compile_cache = {}
@@ -384,6 +384,7 @@ class _Executor:
             Bool, if the graph has been compiled before, return False, else return True.
         """
         obj.check_names()
+        _check_full_batch()
         args_names, args_list = _generate_pip_args(obj, *args)
         dic = dict(zip(args_names, args_list))
         key = generate_key(phase, dic)
@@ -32,6 +32,18 @@ def _get_full_batch():
     """Get whether to use full_batch."""
     return auto_parallel_context().get_full_batch()
 
 
+def _check_full_batch():
+    """
+    full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
+
+    Raises:
+        RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel.
+    """
+    parallel_mode = _get_parallel_mode()
+    full_batch = _get_full_batch()
+    if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch):
+        raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
+
+
 def _need_to_full():
     """Check whether to convert input to full shape or tensor."""
@@ -13,6 +13,7 @@
 # limitations under the License.
 import numpy as np
+import pytest
 import mindspore as ms
 import mindspore.nn as nn
@@ -88,3 +89,22 @@ def test_all_to_all():
     strategy1 = ((8, 1),)
     _reset_op_id()
     all_to_all_common(strategy1)
+
+
+def test_data_parallel_mode():
+    _reset_op_id()
+    learning_rate = 0.1
+    momentum = 0.9
+    epoch_size = 2
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
+    context.reset_auto_parallel_context()
+    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, full_batch=True)
+    predict = Tensor(np.ones([256, 128]), dtype=ms.float32)
+    label = Tensor(np.ones([256]), dtype=ms.int32)
+    dataset = Dataset(predict, label, 2)
+    net = all_to_all_net(None)
+    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    opt = Momentum(net.trainable_params(), learning_rate, momentum)
+    model = Model(net, loss, opt)
+    with pytest.raises(RuntimeError):
+        model.train(epoch_size, dataset, dataset_sink_mode=False)
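As a usage note, the combination rejected in `test_data_parallel_mode` (DATA_PARALLEL plus `full_batch=True`) is expected to pass the new check once the context uses an auto-parallel mode; a hypothetical counterpart of the same setup, reusing the test module's existing imports and helpers, might configure the context as follows:

```python
# Hypothetical passing counterpart (sketch, not part of the change): keep
# full_batch=True but pair it with semi_auto_parallel so _check_full_batch()
# does not raise when model.train() triggers graph compilation.
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
                                  device_num=8, global_rank=0, full_batch=True)
```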