Browse Source

!9577 support distributed predict

From: @gong_zi_yan
Reviewed-by: @caozhou_huawei,@yao_yf,@stsuteng,@zh_qh
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ec3983b77d
4 changed files with 180 additions and 3 deletions
  1. +69
    -1
      mindspore/parallel/_utils.py
  2. +37
    -1
      mindspore/train/model.py
  3. +1
    -1
      tests/ut/python/parallel/test_auto_parallel_resnet_predict.py
  4. +73
    -0
      tests/ut/python/parallel/test_distribute_predict.py

+ 69
- 1
mindspore/parallel/_utils.py View File

@@ -15,7 +15,8 @@
"""Utils of auto parallel""" """Utils of auto parallel"""


import numpy as np import numpy as np
from mindspore import log as logger
from mindspore import context, log as logger
from mindspore.context import ParallelMode
from mindspore._c_expression import reset_op_id from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype from mindspore.common.dtype import dtype_to_nptype
@@ -193,3 +194,70 @@ def _get_python_op(op_name, op_path, instance_name, arglist):
def _reset_op_id(): def _reset_op_id():
"""Reset op id.""" """Reset op id."""
reset_op_id() reset_op_id()


def _parallel_predict_check():
    """Validate the auto-parallel context before running model prediction.

    Nothing is checked unless the current parallel mode is semi auto parallel
    or auto parallel.

    Raises:
        RuntimeError: If "full_batch" is not enabled, or if
            "enable_parallel_optimizer" is enabled, in the auto-parallel context.
    """
    if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        return
    if not context.get_auto_parallel_context("full_batch"):
        raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
    if context.get_auto_parallel_context("enable_parallel_optimizer"):
        # The trailing space after "set" keeps the two adjacent literals from
        # fusing into 'Please set"enable_parallel_optimizer"'.
        raise RuntimeError('Model prediction does not support parallel optimizer. Please set '
                           '"enable_parallel_optimizer" with False.')


def _check_similar_layout(tensor_layout1, tensor_layout2):
"""check if two tensor layouts are same"""
if tensor_layout1[1] != tensor_layout2[1]:
return False
for i in tensor_layout1[1]:
if i == -1:
continue
if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
return False
return True


def _remove_repeated_slices(tensor_layout):
"""generate unrepeated tensor layout"""
import copy
new_tensor_layout = copy.deepcopy(tensor_layout)
dev_mat = tensor_layout[0][:]
tensor_map = tensor_layout[1]
for dim in range(len(dev_mat)):
if dim not in tensor_map:
dev_mat[-1-dim] = 1
new_tensor_layout[0] = dev_mat
return new_tensor_layout


def _infer_rank_list(train_map, predict_map=None):
    """Infer which checkpoint slices (training ranks) have to be loaded.

    For every parameter the training device matrix is reduced so that repeated
    slices are dropped, which yields the list of ranks holding distinct slices.

    Args:
        train_map (dict): Parameter name -> training tensor layout. Element 0 of
            a layout is used as the device matrix.
        predict_map (dict): Parameter name -> prediction tensor layout.
            Default: None, in which case every distinct training rank is
            returned for each parameter.

    Returns:
        dict, parameter name -> ranks whose checkpoint slices are needed.
        Parameters missing from a non-empty `predict_map` are skipped (a
        warning is logged).
    """
    ret = {}
    for param_name in train_map:
        train_layout = train_map[param_name]
        new_train_layout = _remove_repeated_slices(train_layout)
        train_dev_mat = train_layout[0]
        dev_num = np.array(train_dev_mat).prod()
        array = np.arange(dev_num).reshape(train_dev_mat)
        # Dimensions reduced to 1 hold repeated slices: keep only their first
        # entry and take every entry along the remaining dimensions.
        index = tuple(0 if size == 1 else slice(None) for size in new_train_layout[0])
        rank_list = array[index].flatten()
        if not predict_map:
            ret[param_name] = rank_list
            continue
        if param_name not in predict_map:
            logger.warning("predict_map does not contain %s", param_name)
            continue
        predict_layout = predict_map[param_name]
        # optimization pass
        if _check_similar_layout(train_layout, predict_layout):
            dev_rank = _get_global_rank()
            predict_dev_num = np.array(predict_layout[0]).prod()
            # Loading a single slice is only valid when there is exactly one
            # distinct slice per prediction device. After repeated slices were
            # removed, rank_list can be shorter than the device count, and
            # indexing it with the global rank would raise IndexError or pick
            # a slice that belongs to another device.
            if len(rank_list) == predict_dev_num and dev_rank < len(rank_list):
                ret[param_name] = [rank_list[dev_rank]]
                continue
        ret[param_name] = rank_list
    return ret

+ 37
- 1
mindspore/train/model.py View File

@@ -26,7 +26,7 @@ from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .. import context from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \ from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
from ..parallel._ps_context import _is_role_pserver, _is_role_sched from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss from ..nn.metrics import Loss
from .. import nn from .. import nn
@@ -736,10 +736,46 @@ class Model:
""" """
self._predict_network.set_train(False) self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor) check_input_data(*predict_data, data_class=Tensor)
_parallel_predict_check()
result = self._predict_network(*predict_data) result = self._predict_network(*predict_data)


check_output_data(result) check_output_data(result)
return result return result


def infer_predict_layout(self, *predict_data):
    """
    Compile the predict network and return its parameter layout.

    Only available in graph mode with semi auto parallel or auto parallel mode.
    Data could be a single tensor, a list of tensors, or a tuple of tensors.

    Note:
        Batch data should be put together in one tensor.

    Args:
        predict_data (Tensor): Tensor of predict data. Can be array, list or tuple.

    Returns:
        parameter_layout_dict (dict): Parameter layout dictionary used for loading a distributed checkpoint.

    Raises:
        RuntimeError: If the execution mode is not graph mode, or the parallel mode is neither
            semi auto parallel nor auto parallel. The parallel predict check may also raise.

    Examples:
        >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
        >>> model = Model(Net())
        >>> model.infer_predict_layout(input_data)
    """
    current_mode = context.get_context("mode")
    if current_mode != context.GRAPH_MODE:
        raise RuntimeError('infer predict layout only supports GRAPH MODE currently.')

    # remove this restriction after support inferring repeated strategy
    supported_parallel_modes = (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if _get_parallel_mode() not in supported_parallel_modes:
        raise RuntimeError('infer predict layout only supports semi auto parallel and auto parallel mode.')

    _parallel_predict_check()
    check_input_data(*predict_data, data_class=Tensor)

    # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
    network = self._predict_network
    network.set_auto_parallel()
    network.set_train(False)
    network.compile(*predict_data)
    return network.parameter_layout_dict



__all__ = ["Model"] __all__ = ["Model"]

+ 1
- 1
tests/ut/python/parallel/test_auto_parallel_resnet_predict.py View File

@@ -26,7 +26,7 @@ init()


def test_train_32k_8p(batch_size=32, num_classes=32768): def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8 dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num, full_batch=True)
set_algo_parameters(elementwise_op_strategy_follow=True) set_algo_parameters(elementwise_op_strategy_follow=True)
np.random.seed(6) np.random.seed(6)
input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32)) input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))


+ 73
- 0
tests/ut/python/parallel/test_distribute_predict.py View File

@@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test distribute predict """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore.ops import operations as P
from mindspore import context


class Net(nn.Cell):
    """Small network of dense layers and matmuls used by the distributed predict tests."""

    def __init__(self):
        super().__init__()
        # Three 128 -> 768 projections of the same input, plus a 768 -> 768 output layer.
        self.fc1 = nn.Dense(128, 768, activation='relu')
        self.fc2 = nn.Dense(128, 768, activation='relu')
        self.fc3 = nn.Dense(128, 768, activation='relu')
        self.fc4 = nn.Dense(768, 768, activation='relu')
        self.relu4 = nn.ReLU()
        self.relu5 = nn.ReLU()
        self.transpose = P.Transpose()
        self.matmul1 = P.MatMul()
        self.matmul2 = P.MatMul()

    def construct(self, x):
        """Project `x` three times, combine the projections with two matmuls, then apply fc4."""
        proj_a = self.fc1(x)
        proj_b = self.fc2(x)
        proj_c = self.fc3(x)
        proj_b_t = self.transpose(proj_b, (1, 0))
        product_ab = self.matmul1(proj_a, proj_b_t)
        activated_ab = self.relu4(product_ab)
        product_abc = self.matmul2(activated_ab, proj_c)
        activated_abc = self.relu5(product_abc)
        return self.fc4(activated_abc)


def test_distribute_predict():
    """Infer the predict layout and run prediction in semi auto parallel mode on 8 devices.

    Returns:
        tuple, the inferred parameter layout dictionary and the prediction output.
    """
    context.set_context(mode=context.GRAPH_MODE)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
    try:
        inputs = Tensor(np.ones([32, 128]).astype(np.float32))
        net = Net()
        model = Model(net)
        predict_map = model.infer_predict_layout(inputs)
        output = model.predict(inputs)
    finally:
        # Reset even on failure so the global auto-parallel settings do not
        # leak into tests that run afterwards.
        context.reset_auto_parallel_context()
    return predict_map, output


def test_edge_case():
    """Check that unsupported configurations are rejected with RuntimeError."""
    context.set_context(mode=context.GRAPH_MODE)
    # Start from a clean auto-parallel context so the first expectation does
    # not depend on what earlier tests configured.
    context.reset_auto_parallel_context()
    try:
        inputs = Tensor(np.ones([32, 48]).astype(np.float32))
        net = Net()
        model = Model(net)
        # Expected to fail: parallel mode has not been set to (semi) auto parallel.
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
        # Expected to fail: semi auto parallel without full_batch enabled.
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        with pytest.raises(RuntimeError):
            model.infer_predict_layout(inputs)
        # Expected to fail: the parallel optimizer is enabled during prediction.
        context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
        with pytest.raises(RuntimeError):
            model.predict(inputs)
    finally:
        # The settings above are global; restore the defaults so they do not
        # leak into tests that run afterwards.
        context.reset_auto_parallel_context()

Loading…
Cancel
Save