From: @gong_zi_yan Reviewed-by: @caozhou_huawei,@yao_yf,@stsuteng,@zh_qh Signed-off-by: @stsutengtags/v1.1.0
| @@ -15,7 +15,8 @@ | |||||
| """Utils of auto parallel""" | """Utils of auto parallel""" | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import log as logger | |||||
| from mindspore import context, log as logger | |||||
| from mindspore.context import ParallelMode | |||||
| from mindspore._c_expression import reset_op_id | from mindspore._c_expression import reset_op_id | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.dtype import dtype_to_nptype | from mindspore.common.dtype import dtype_to_nptype | ||||
| @@ -193,3 +194,70 @@ def _get_python_op(op_name, op_path, instance_name, arglist): | |||||
| def _reset_op_id(): | def _reset_op_id(): | ||||
| """Reset op id.""" | """Reset op id.""" | ||||
| reset_op_id() | reset_op_id() | ||||
| def _parallel_predict_check(): | |||||
| """validate parallel model prediction""" | |||||
| if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||||
| if not context.get_auto_parallel_context("full_batch"): | |||||
| raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.') | |||||
| if context.get_auto_parallel_context("enable_parallel_optimizer"): | |||||
| raise RuntimeError('Model prediction does not support parallel optimizer. Please set' | |||||
| '"enable_parallel_optimizer" with False.') | |||||
| def _check_similar_layout(tensor_layout1, tensor_layout2): | |||||
| """check if two tensor layouts are same""" | |||||
| if tensor_layout1[1] != tensor_layout2[1]: | |||||
| return False | |||||
| for i in tensor_layout1[1]: | |||||
| if i == -1: | |||||
| continue | |||||
| if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]: | |||||
| return False | |||||
| return True | |||||
| def _remove_repeated_slices(tensor_layout): | |||||
| """generate unrepeated tensor layout""" | |||||
| import copy | |||||
| new_tensor_layout = copy.deepcopy(tensor_layout) | |||||
| dev_mat = tensor_layout[0][:] | |||||
| tensor_map = tensor_layout[1] | |||||
| for dim in range(len(dev_mat)): | |||||
| if dim not in tensor_map: | |||||
| dev_mat[-1-dim] = 1 | |||||
| new_tensor_layout[0] = dev_mat | |||||
| return new_tensor_layout | |||||
| def _infer_rank_list(train_map, predict_map=None): | |||||
| """infer checkpoint slices to be loaded""" | |||||
| ret = {} | |||||
| for param_name in train_map: | |||||
| train_layout = train_map[param_name] | |||||
| new_train_layout = _remove_repeated_slices(train_layout) | |||||
| train_dev_mat = train_layout[0] | |||||
| dev_num = np.array(train_dev_mat).prod() | |||||
| array = np.arange(dev_num).reshape(train_dev_mat) | |||||
| index = () | |||||
| for i in new_train_layout[0]: | |||||
| if i == 1: | |||||
| index = index + (0,) | |||||
| else: | |||||
| index = index + (slice(None),) | |||||
| rank_list = array[index].flatten() | |||||
| if not predict_map: | |||||
| ret[param_name] = rank_list | |||||
| continue | |||||
| if param_name not in predict_map: | |||||
| logger.warning("predict_map does not contain %s", param_name) | |||||
| continue | |||||
| predict_layout = predict_map[param_name] | |||||
| # optimization pass | |||||
| if _check_similar_layout(train_layout, predict_layout): | |||||
| dev_rank = _get_global_rank() | |||||
| ret[param_name] = [rank_list[dev_rank]] | |||||
| else: | |||||
| ret[param_name] = rank_list | |||||
| return ret | |||||
| @@ -26,7 +26,7 @@ from .._checkparam import check_input_data, check_output_data, Validator | |||||
| from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback | from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback | ||||
| from .. import context | from .. import context | ||||
| from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ | ||||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check | |||||
| _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check | |||||
| from ..parallel._ps_context import _is_role_pserver, _is_role_sched | from ..parallel._ps_context import _is_role_pserver, _is_role_sched | ||||
| from ..nn.metrics import Loss | from ..nn.metrics import Loss | ||||
| from .. import nn | from .. import nn | ||||
| @@ -736,10 +736,46 @@ class Model: | |||||
| """ | """ | ||||
| self._predict_network.set_train(False) | self._predict_network.set_train(False) | ||||
| check_input_data(*predict_data, data_class=Tensor) | check_input_data(*predict_data, data_class=Tensor) | ||||
| _parallel_predict_check() | |||||
| result = self._predict_network(*predict_data) | result = self._predict_network(*predict_data) | ||||
| check_output_data(result) | check_output_data(result) | ||||
| return result | return result | ||||
| def infer_predict_layout(self, *predict_data): | |||||
| """ | |||||
| Generate parameter layout for the predict network in auto or semi auto parallel mode. | |||||
| Data could be a single tensor, a list of tensor, or a tuple of tensor. | |||||
| Note: | |||||
| Batch data should be put together in one tensor. | |||||
| Args: | |||||
| predict_data (Tensor): Tensor of predict data. can be array, list or tuple. | |||||
| Returns: | |||||
| parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint | |||||
| Examples: | |||||
| >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) | |||||
| >>> model = Model(Net()) | |||||
| >>> model.infer_predict_layout(input_data) | |||||
| """ | |||||
| if context.get_context("mode") != context.GRAPH_MODE: | |||||
| raise RuntimeError('infer predict layout only supports GRAPH MODE currently.') | |||||
| # remove this restriction after support inferring repeated strategy | |||||
| if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||||
| raise RuntimeError('infer predict layout only supports semi auto parallel and auto parallel mode.') | |||||
| _parallel_predict_check() | |||||
| check_input_data(*predict_data, data_class=Tensor) | |||||
| predict_net = self._predict_network | |||||
| # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set | |||||
| predict_net.set_auto_parallel() | |||||
| predict_net.set_train(False) | |||||
| predict_net.compile(*predict_data) | |||||
| return predict_net.parameter_layout_dict | |||||
| __all__ = ["Model"] | __all__ = ["Model"] | ||||
| @@ -26,7 +26,7 @@ init() | |||||
| def test_train_32k_8p(batch_size=32, num_classes=32768): | def test_train_32k_8p(batch_size=32, num_classes=32768): | ||||
| dev_num = 8 | dev_num = 8 | ||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) | |||||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num, full_batch=True) | |||||
| set_algo_parameters(elementwise_op_strategy_follow=True) | set_algo_parameters(elementwise_op_strategy_follow=True) | ||||
| np.random.seed(6) | np.random.seed(6) | ||||
| input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32)) | input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32)) | ||||
| @@ -0,0 +1,73 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test distribute predict """ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor, Model | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore import context | |||||
| class Net(nn.Cell): | |||||
| """Net definition""" | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.fc1 = nn.Dense(128, 768, activation='relu') | |||||
| self.fc2 = nn.Dense(128, 768, activation='relu') | |||||
| self.fc3 = nn.Dense(128, 768, activation='relu') | |||||
| self.fc4 = nn.Dense(768, 768, activation='relu') | |||||
| self.relu4 = nn.ReLU() | |||||
| self.relu5 = nn.ReLU() | |||||
| self.transpose = P.Transpose() | |||||
| self.matmul1 = P.MatMul() | |||||
| self.matmul2 = P.MatMul() | |||||
| def construct(self, x): | |||||
| q = self.fc1(x) | |||||
| k = self.fc2(x) | |||||
| v = self.fc3(x) | |||||
| k = self.transpose(k, (1, 0)) | |||||
| c = self.relu4(self.matmul1(q, k)) | |||||
| s = self.relu5(self.matmul2(c, v)) | |||||
| s = self.fc4(s) | |||||
| return s | |||||
| def test_distribute_predict(): | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True) | |||||
| inputs = Tensor(np.ones([32, 128]).astype(np.float32)) | |||||
| net = Net() | |||||
| model = Model(net) | |||||
| predict_map = model.infer_predict_layout(inputs) | |||||
| output = model.predict(inputs) | |||||
| context.reset_auto_parallel_context() | |||||
| return predict_map, output | |||||
| def test_edge_case(): | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| inputs = Tensor(np.ones([32, 48]).astype(np.float32)) | |||||
| net = Net() | |||||
| model = Model(net) | |||||
| with pytest.raises(RuntimeError): | |||||
| model.infer_predict_layout(inputs) | |||||
| context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") | |||||
| with pytest.raises(RuntimeError): | |||||
| model.infer_predict_layout(inputs) | |||||
| context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True) | |||||
| with pytest.raises(RuntimeError): | |||||
| model.predict(inputs) | |||||