From fc4ed975c404207be051cecb7f27856d0b37c6f7 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <285824651@qq.com>
Date: Fri, 16 Oct 2020 15:57:56 +0800
Subject: [PATCH] handle repeated calculation

---
 .../parallel/ops_info/gather_v2_p_info.cc     |   3 +-
 .../frontend/parallel/ops_info/onehot_info.cc |   7 +-
 .../frontend/parallel/ops_info/onehot_info.h  |   1 +
 .../parallel/ops_info/operator_info.cc        |  38 ++++++-
 .../parallel/ops_info/operator_info.h         |   1 +
 .../cpp/parallel/ops_info/onehot_info_test.cc |   2 +-
 .../ops_info/onehot_info_test_axis_0.cc       |   2 +-
 tests/ut/cpp/parallel/ops_info/prelu_test.cc  |  12 +-
 .../ut/cpp/parallel/ops_info/reshape_test.cc  |   6 +-
 .../cpp/parallel/ops_info/transpose_test.cc   |   2 +-
 .../ut/python/parallel/test_repeated_calc.py  | 107 ++++++++++++++++++
 11 files changed, 161 insertions(+), 20 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_repeated_calc.py

diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
index 313616fb30..9ad640189b 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
@@ -350,7 +350,8 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
   if (param_product * index_product < SizeToInt(dev_num)) {
-    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product)));
+    // add the repeated calculation num to the last dimension of dev matrix
+    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
   }

   return SUCCESS;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
index a9c48d6351..ccef6274d9 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
@@ -73,6 +73,7 @@ Status OneHotInfo::InferDevMatrixShape() {
     dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
     dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
   }
+  old_dev_matrix_back_ = dev_matrix_shape_.back();

   return SUCCESS;
 }
@@ -134,7 +135,7 @@ Status OneHotInfo::InferTensorInfo() {
 Status OneHotInfo::ExtractInputInfo() {
   CheckGlobalDeviceManager();
   rank_ = g_device_manager->global_rank();
-  mod_rank_ = rank_ % dev_matrix_shape_.back();
+  mod_rank_ = rank_ % old_dev_matrix_back_;
   if (!cnode_) {
     MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
     return FAILED;
@@ -162,13 +163,13 @@ Status OneHotInfo::ExtractInputInfo() {
     MS_LOG(ERROR) << "OneHot Primitive depth type must be int";
     return FAILED;
   }
-  classes_each_device_ = total_class_number_ / dev_matrix_shape_.back();
+  classes_each_device_ = total_class_number_ / old_dev_matrix_back_;

   return SUCCESS;
 }

 Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
-  if (dev_matrix_shape_.back() == 1) {
+  if (old_dev_matrix_back_ == 1) {
     replace_graph_ = nullptr;
     return SUCCESS;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
index 362c5a57a3..5d65ff0f63 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
@@ -60,6 +60,7 @@ class OneHotInfo : public OperatorInfo {
   int32_t rank_ = 0;
   int32_t total_class_number_ = 1;
   int32_t classes_each_device_ = 1;
+  int32_t old_dev_matrix_back_ = 1;
   ValuePtr axis_value_ptr_;
   int32_t mod_rank_ = 0;
 };
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
index a0ccac6867..e65275e469 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
@@ -164,14 +164,42 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
   return SUCCESS;
 }

-// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix,
-// only use for infer tensor layout
+// If there is repeated calculation, set the repeated_calc_num as the last dimension of the dev-matrix;
+// it is only used for inferring the tensor layout. If the previous shard is (a, b) and the next shard is
+// (a, 1), appending the repeated_calc_num to the last dimension of the dev-matrix avoids redistribution.
 void OperatorInfo::SetRepeatedCalcDevMatrix() {
   if (repeated_calc_num_ <= 1) {
     return;
   }

-  (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
+  (void)dev_matrix_shape_.push_back(repeated_calc_num_);
+}
+
+// If there is repeated calculation, the repeated_calc_num is appended as the last dimension of the dev-matrix,
+// so every valid index value in the tensor maps needs to be increased by 1.
+void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
+  if (repeated_calc_num_ <= 1) {
+    return;
+  }
+
+  MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
+  for (auto &tensor_map : inputs_tensor_map_) {
+    for (auto &element : tensor_map) {
+      if (element == MAP_NONE) {
+        continue;
+      }
+      element += 1;
+    }
+  }
+
+  for (auto &tensor_map : outputs_tensor_map_) {
+    for (auto &element : tensor_map) {
+      if (element == MAP_NONE) {
+        continue;
+      }
+      element += 1;
+    }
+  }
 }

 // use for loss repeated calculation
@@ -454,7 +482,7 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
     return FAILED;
   }

-  // if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix for layout
+  // if there is repeated calculation, set the repeated_calc_num as the last dimension of the dev-matrix for layout
   SetRepeatedCalcDevMatrix();

   if (InferTensorMap() != SUCCESS) {
@@ -462,6 +490,8 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
     return FAILED;
   }

+  ResetTensorMapIfRepeatedCalc();
+
   if (InferTensorInfo() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
     return FAILED;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
index e69b405217..86a9b31cfd 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
@@ -184,6 +184,7 @@ class OperatorInfo {
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetDeviceListByStrategy();
   void SetRepeatedCalcDevMatrix();
+  void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
   Status InferAttrs();
   void ResetQueueMember();
diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc
index 6efac9598b..e7527c134e 100644
--- a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc
@@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
   ASSERT_EQ(status, SUCCESS);
   Shape dev_matrix_shape = onehot_info->dev_matrix_shape();

-  Shape expect = {2, 4, 1};
+  Shape expect = {4, 1, 2};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc
index 239a7299cd..7cad3175d5 100644
--- a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc
+++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc
@@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
   ASSERT_EQ(status, SUCCESS);
   Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();

-  Shape expect = {2, 4, 1};
+  Shape expect = {4, 1, 2};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

diff --git a/tests/ut/cpp/parallel/ops_info/prelu_test.cc b/tests/ut/cpp/parallel/ops_info/prelu_test.cc
index b92392234e..65410f45ff 100644
--- a/tests/ut/cpp/parallel/ops_info/prelu_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/prelu_test.cc
@@ -70,7 +70,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
   prelu->Init(strategy);
   Shape dev_matrix_shape = prelu->dev_matrix_shape();

-  Shape expect = {4, 2, 1, 8, 16};
+  Shape expect = {2, 1, 8, 16, 4};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

@@ -105,9 +105,9 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) {
   std::vector<TensorInfo> inputs = prelu->inputs_tensor_info();
   std::vector<TensorInfo> outputs = prelu->outputs_tensor_info();

-  TensorMap input_expect = {3, 2, 1, 0};
+  TensorMap input_expect = {4, 3, 2, 1};
   TensorMap param_expect = {2};
-  TensorMap output_expect = {3, 2, 1, 0};
+  TensorMap output_expect = {4, 3, 2, 1};

   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo param_tensor_info = inputs.at(1);
@@ -175,7 +175,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
   prelu_2d->Init(strategy);
   Shape dev_matrix_shape = prelu_2d->dev_matrix_shape();

-  Shape expect = {8, 128, 1};
+  Shape expect = {128, 1, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

@@ -210,9 +210,9 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
   std::vector<TensorInfo> inputs = prelu_2d->inputs_tensor_info();
   std::vector<TensorInfo> outputs = prelu_2d->outputs_tensor_info();

-  TensorMap input_expect = {1, 0};
+  TensorMap input_expect = {2, 1};
   TensorMap param_expect = {0};
-  TensorMap output_expect = {1, 0};
+  TensorMap output_expect = {2, 1};

   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo param_tensor_info = inputs.at(1);
diff --git a/tests/ut/cpp/parallel/ops_info/reshape_test.cc b/tests/ut/cpp/parallel/ops_info/reshape_test.cc
index 71c793cf56..2818e6c403 100644
--- a/tests/ut/cpp/parallel/ops_info/reshape_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/reshape_test.cc
@@ -74,7 +74,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
   reshape->Init(strategy);
   Shape dev_matrix_shape = reshape->dev_matrix_shape();

-  Shape expect = {8, 4};
+  Shape expect = {4, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

@@ -139,8 +139,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
   std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
   std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();

-  TensorMap input_expect = {0, -1, -1, -1};
-  TensorMap output_expect = {0, -1};
+  TensorMap input_expect = {1, -1, -1, -1};
+  TensorMap output_expect = {1, -1};

   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo output_tensor_info = outputs.at(0);
diff --git a/tests/ut/cpp/parallel/ops_info/transpose_test.cc b/tests/ut/cpp/parallel/ops_info/transpose_test.cc
index 149e49e854..306f04c92a 100644
--- a/tests/ut/cpp/parallel/ops_info/transpose_test.cc
+++ b/tests/ut/cpp/parallel/ops_info/transpose_test.cc
@@ -85,7 +85,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
   transpose->Init(strategy);
   Shape dev_matrix_shape = transpose->dev_matrix_shape();

-  Shape expect = {8, 4, 1};
+  Shape expect = {4, 1, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }

diff --git a/tests/ut/python/parallel/test_repeated_calc.py b/tests/ut/python/parallel/test_repeated_calc.py
new file mode 100644
index 0000000000..5ce133a759
--- /dev/null
+++ b/tests/ut/python/parallel/test_repeated_calc.py
@@ -0,0 +1,107 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, b):
+        predict = self.network(x, y, b)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, b):
+        return grad_all(self.network)(x, y, b)
+
+
+def compile_net(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
+# there is no redistribution in this case
+def test_tensoradd_reshape_matmul():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.add = P.TensorAdd().shard(strategy1)
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul().shard(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.add(x, y)
+            out = self.reshape(out, (256, 16))
+            out = self.matmul(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0, gradients_mean=True)
+    strategy1 = ((8, 1, 1), (8, 1, 1))
+    strategy2 = ((8, 1), (1, 8))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_context(save_graphs=True)
+
+    x = Tensor(np.ones([32, 8, 16]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 8, 16]), dtype=ms.float32)
+    b = Tensor(np.ones([16, 16]), dtype=ms.float32)
+
+    compile_net(net, x, y, b)
+
+
+def test_two_matmul():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(strategy1)
+            self.matmul2 = P.MatMul().shard(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul1(x, y)
+            out = self.matmul2(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0, gradients_mean=True)
+    strategy1 = ((8, 8), (8, 1))
+    strategy2 = ((8, 1), (1, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_context(save_graphs=True)
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
+
+    compile_net(net, x, y, b)
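
The gist of the patch: the repeated-calculation factor is now appended as the last dimension of the device matrix instead of being prepended, and every valid tensor-map index is then shifted up by one so it still points at the same strategy dimension; per the comment on SetRepeatedCalcDevMatrix, this keeps a repeated operator layout-compatible with a following non-repeated operator, so no redistribution is inserted. The Python sketch below only illustrates that bookkeeping; MAP_NONE and the two helper functions are illustrative names rather than MindSpore APIs, and the numbers come from the updated unit-test expectations (transpose: dev matrix {4, 1} with 8-fold repetition becomes {4, 1, 8}; prelu 2d: tensor map {1, 0} becomes {2, 1}).

MAP_NONE = -1  # marker for a tensor dimension that is not split across devices


def set_repeated_calc_dev_matrix(dev_matrix, repeated_calc_num):
    # repeated_calc_num <= 1 means no device computes a duplicate slice
    if repeated_calc_num <= 1:
        return dev_matrix
    # append the repeated factor as the last (lowest-index) dev-matrix dimension
    return dev_matrix + [repeated_calc_num]


def reset_tensor_map_if_repeated_calc(tensor_map, repeated_calc_num):
    # tensor-map values index dev-matrix dimensions from the right, so the
    # appended dimension takes index 0 and every mapped value moves up by one
    if repeated_calc_num <= 1:
        return tensor_map
    return [v if v == MAP_NONE else v + 1 for v in tensor_map]


# 32 devices with strategy (4, 1): repeated_calc_num = 32 / (4 * 1) = 8
assert set_repeated_calc_dev_matrix([4, 1], 8) == [4, 1, 8]
assert reset_tensor_map_if_repeated_calc([1, 0], 8) == [2, 1]
assert reset_tensor_map_if_repeated_calc([0, MAP_NONE], 8) == [1, MAP_NONE]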