Merge pull request !7397 from yangzhenzhang/handle_repeated_calc
@@ -350,7 +350,8 @@ Status GatherV2PInfo::InferDevMatrixShape() {
   auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies<int>());
   auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies<int>());
   if (param_product * index_product < SizeToInt(dev_num)) {
-    out_dev_matrix_shape_.insert(out_dev_matrix_shape_.begin(), SizeToInt(dev_num / (param_product * index_product)));
+    // add the repeated calculation num to the last dimension of dev matrix
+    out_dev_matrix_shape_.push_back(SizeToInt(dev_num / (param_product * index_product)));
   }
   return SUCCESS;
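As a quick illustration of the hunk above, here is a self-contained sketch that recomputes the appended factor. The device count, strategies, and initial out_dev_matrix_shape are invented for illustration and are not taken from this PR:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const int64_t dev_num = 16;  // assumed device count, not from the PR
  std::vector<int64_t> param_strategy = {2, 1};
  std::vector<int64_t> index_strategy = {4};
  std::vector<int64_t> out_dev_matrix_shape = {2, 1, 4};

  auto param_product = std::accumulate(param_strategy.begin(), param_strategy.end(), int64_t{1},
                                       std::multiplies<int64_t>());
  auto index_product = std::accumulate(index_strategy.begin(), index_strategy.end(), int64_t{1},
                                       std::multiplies<int64_t>());
  if (param_product * index_product < dev_num) {
    // 16 / (2 * 1 * 4) = 2 repeated copies, appended rather than prepended
    out_dev_matrix_shape.push_back(dev_num / (param_product * index_product));
  }
  for (auto d : out_dev_matrix_shape) std::cout << d << ' ';  // prints: 2 1 4 2
  std::cout << '\n';
  return 0;
}
```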
@@ -73,6 +73,7 @@ Status OneHotInfo::InferDevMatrixShape() {
     dev_matrix_shape_.push_back(input_strategy[0]);  // the features is splittable
     dev_matrix_shape_.push_back(input_strategy[1]);  // the depth is un-splittable
   }
+  old_dev_matrix_back_ = dev_matrix_shape_.back();
   return SUCCESS;
 }
@@ -134,7 +135,7 @@ Status OneHotInfo::InferTensorInfo() {
 Status OneHotInfo::ExtractInputInfo() {
   CheckGlobalDeviceManager();
   rank_ = g_device_manager->global_rank();
-  mod_rank_ = rank_ % dev_matrix_shape_.back();
+  mod_rank_ = rank_ % old_dev_matrix_back_;
   if (!cnode_) {
     MS_LOG(ERROR) << "Failure:OneHot cnode_ is nullptr";
     return FAILED;
@@ -162,13 +163,13 @@ Status OneHotInfo::ExtractInputInfo() {
     MS_LOG(ERROR) << "OneHot Primitive depth type must be int";
     return FAILED;
   }
-  classes_each_device_ = total_class_number_ / dev_matrix_shape_.back();
+  classes_each_device_ = total_class_number_ / old_dev_matrix_back_;
   return SUCCESS;
 }
 
 Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
-  if (dev_matrix_shape_.back() == 1) {
+  if (old_dev_matrix_back_ == 1) {
     replace_graph_ = nullptr;
     return SUCCESS;
   }
@@ -60,6 +60,7 @@ class OneHotInfo : public OperatorInfo {
   int32_t rank_ = 0;
   int32_t total_class_number_ = 1;
   int32_t classes_each_device_ = 1;
+  int32_t old_dev_matrix_back_ = 1;
   ValuePtr axis_value_ptr_;
   int32_t mod_rank_ = 0;
 };
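The OneHot hunks above all follow from one point: the operator divides rank_ and total_class_number_ by the last dev-matrix dimension produced from its strategy, and once the repeated-calculation factor is appended at the back, dev_matrix_shape_.back() would return that factor instead. Hence the pre-append value is cached in old_dev_matrix_back_. A minimal standalone sketch with invented numbers (not the MindSpore source):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Dev-matrix built from the OneHot strategy alone, e.g. strategy (4, 1) (assumed).
  std::vector<int64_t> dev_matrix = {4, 1};
  int64_t old_dev_matrix_back = dev_matrix.back();  // cached at the end of InferDevMatrixShape: 1

  // The framework later appends the repeated-calculation factor, e.g. 8 devices / (4 * 1) = 2.
  dev_matrix.push_back(2);  // {4, 1, 2}

  // back() now returns the repeat factor, not the strategy's last dimension, so
  // rank_ % ... and total_class_number_ / ... must use the cached value instead.
  std::cout << old_dev_matrix_back << ' ' << dev_matrix.back() << '\n';  // prints: 1 2
  return 0;
}
```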
@@ -164,14 +164,42 @@ Status OperatorInfo::InferRepeatedCalcInfo() {
   return SUCCESS;
 }
 
-// if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix,
-// only use for infer tensor layout
+// If there is repeated calculation, set repeated_calc_num as the last dimension of the dev-matrix; only used
+// for inferring the tensor layout. If the previous operator's shard is (a, b) and the next shard is (a, 1),
+// appending repeated_calc_num to the last dimension of the dev-matrix means no redistribution is needed.
 void OperatorInfo::SetRepeatedCalcDevMatrix() {
   if (repeated_calc_num_ <= 1) {
     return;
   }
-  (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_);
+  (void)dev_matrix_shape_.push_back(repeated_calc_num_);
+}
+
+// If there is repeated calculation, repeated_calc_num has been appended as the last dimension of the
+// dev-matrix, so every valid index in the tensor maps must be increased by 1.
+void OperatorInfo::ResetTensorMapIfRepeatedCalc() {
+  if (repeated_calc_num_ <= 1) {
+    return;
+  }
+  MS_LOG(DEBUG) << name_ << ": the repeated calc num is " << repeated_calc_num_ << ", and reset the tensor maps";
+  for (auto &tensor_map : inputs_tensor_map_) {
+    for (auto &element : tensor_map) {
+      if (element == MAP_NONE) {
+        continue;
+      }
+      element += 1;
+    }
+  }
+  for (auto &tensor_map : outputs_tensor_map_) {
+    for (auto &element : tensor_map) {
+      if (element == MAP_NONE) {
+        continue;
+      }
+      element += 1;
+    }
+  }
 }
 
 // use for loss repeated calculation
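The two functions above split the work: SetRepeatedCalcDevMatrix appends the repeat factor, and ResetTensorMapIfRepeatedCalc shifts every valid tensor-map index by one so it still names the same dev-matrix axis (indices are counted from the rightmost dimension). A self-contained sketch with invented values, not the MindSpore source:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t MAP_NONE = -1;  // same meaning as in the parallel code: dimension not sharded

int main() {
  // Strategy (4, 2) on 16 devices (assumed) -> repeated_calc_num = 16 / (4 * 2) = 2.
  std::vector<int64_t> dev_matrix = {4, 2};
  std::vector<int64_t> tensor_map = {1, 0};  // tensor dim i -> dev-matrix axis, counted from the right
  const int64_t repeated_calc_num = 2;

  // Append (new behaviour): {4, 2, 2}; the rightmost axis is now the repeat axis,
  // so index 0 no longer points at the strategy's last dimension.
  dev_matrix.push_back(repeated_calc_num);

  // Reset: every valid index grows by one to keep naming the same axis.
  for (auto &element : tensor_map) {
    if (element == MAP_NONE) {
      continue;
    }
    element += 1;
  }
  for (auto element : tensor_map) std::cout << element << ' ';  // prints: 2 1
  std::cout << '\n';
  return 0;
}
```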
@@ -454,7 +482,7 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
     return FAILED;
   }
 
-  // if repeated calculation, need to set the repeated_calc_num as the first dimension of dev-matrix for layout
+  // if there is repeated calculation, set repeated_calc_num as the last dimension of the dev-matrix for layout
   SetRepeatedCalcDevMatrix();
 
   if (InferTensorMap() != SUCCESS) {
@@ -462,6 +490,8 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
     return FAILED;
   }
 
+  ResetTensorMapIfRepeatedCalc();
+
   if (InferTensorInfo() != SUCCESS) {
     MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
     return FAILED;
@@ -184,6 +184,7 @@ class OperatorInfo {
   Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
   void SetDeviceListByStrategy();
   void SetRepeatedCalcDevMatrix();
+  void ResetTensorMapIfRepeatedCalc();
   Status CreateGroupByDim(size_t axis, std::vector<Group> *group);
   Status InferAttrs();
   void ResetQueueMember();
@@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
   ASSERT_EQ(status, SUCCESS);
   Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
 
-  Shape expect = {2, 4, 1};
+  Shape expect = {4, 1, 2};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
@@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
   ASSERT_EQ(status, SUCCESS);
   Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
 
-  Shape expect = {2, 4, 1};
+  Shape expect = {4, 1, 2};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
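The expectation change in both OneHot tests follows directly from moving the repeat factor: with a strategy whose product is 4 and an assumed device count of 8 (the test setup is not shown in this excerpt), the factor 2 used to be prepended ({2, 4, 1}) and is now appended ({4, 1, 2}). A minimal sketch:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> strategy = {4, 1};                // assumed OneHot strategy
  const int64_t dev_num = 8;                             // assumed, consistent with 4 * 1 * 2
  const int64_t repeated_calc_num = dev_num / (4 * 1);   // 2

  std::vector<int64_t> old_dev_matrix = strategy;
  old_dev_matrix.insert(old_dev_matrix.begin(), repeated_calc_num);  // old: {2, 4, 1}
  std::vector<int64_t> new_dev_matrix = strategy;
  new_dev_matrix.push_back(repeated_calc_num);                       // new: {4, 1, 2}

  for (auto d : old_dev_matrix) std::cout << d << ' ';
  std::cout << "-> ";
  for (auto d : new_dev_matrix) std::cout << d << ' ';
  std::cout << '\n';  // prints: 2 4 1 -> 4 1 2
  return 0;
}
```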
@@ -70,7 +70,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape1) {
   prelu->Init(strategy);
   Shape dev_matrix_shape = prelu->dev_matrix_shape();
 
-  Shape expect = {4, 2, 1, 8, 16};
+  Shape expect = {2, 1, 8, 16, 4};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
@@ -105,9 +105,9 @@ TEST_F(TestPReLUInfo, GetTensorLayout1) {
   std::vector<TensorInfo> inputs = prelu->inputs_tensor_info();
   std::vector<TensorInfo> outputs = prelu->outputs_tensor_info();
 
-  TensorMap input_expect = {3, 2, 1, 0};
+  TensorMap input_expect = {4, 3, 2, 1};
   TensorMap param_expect = {2};
-  TensorMap output_expect = {3, 2, 1, 0};
+  TensorMap output_expect = {4, 3, 2, 1};
 
   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo param_tensor_info = inputs.at(1);
@@ -175,7 +175,7 @@ TEST_F(TestPReLUInfo, InferDevMatrixShape_2d1) {
   prelu_2d->Init(strategy);
   Shape dev_matrix_shape = prelu_2d->dev_matrix_shape();
 
-  Shape expect = {8, 128, 1};
+  Shape expect = {128, 1, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
@@ -210,9 +210,9 @@ TEST_F(TestPReLUInfo, GetTensorLayout_2d1) {
   std::vector<TensorInfo> inputs = prelu_2d->inputs_tensor_info();
   std::vector<TensorInfo> outputs = prelu_2d->outputs_tensor_info();
 
-  TensorMap input_expect = {1, 0};
+  TensorMap input_expect = {2, 1};
   TensorMap param_expect = {0};
-  TensorMap output_expect = {1, 0};
+  TensorMap output_expect = {2, 1};
 
   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo param_tensor_info = inputs.at(1);
@@ -74,7 +74,7 @@ TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
   reshape->Init(strategy);
   Shape dev_matrix_shape = reshape->dev_matrix_shape();
 
-  Shape expect = {8, 4};
+  Shape expect = {4, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
@@ -139,8 +139,8 @@ TEST_F(TestReshapeInfo, GetTensorLayout1) {
   std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
   std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
 
-  TensorMap input_expect = {0, -1, -1, -1};
-  TensorMap output_expect = {0, -1};
+  TensorMap input_expect = {1, -1, -1, -1};
+  TensorMap output_expect = {1, -1};
 
   TensorInfo input_tensor_info = inputs.at(0);
   TensorInfo output_tensor_info = outputs.at(0);
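For the Reshape expectations above, note that only valid map entries shift; MAP_NONE (-1, meaning the tensor dimension is not sharded) stays put. A small sketch using the test's map values:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t MAP_NONE = -1;  // the tensor dimension is not sharded

int main() {
  std::vector<int64_t> input_map = {0, -1, -1, -1};  // old expectation from the test above
  std::vector<int64_t> output_map = {0, -1};
  for (auto *map : {&input_map, &output_map}) {
    for (auto &element : *map) {
      if (element != MAP_NONE) {
        element += 1;  // shift past the repeat axis appended at the back of the dev-matrix
      }
    }
  }
  for (auto element : input_map) std::cout << element << ' ';   // prints: 1 -1 -1 -1
  std::cout << "| ";
  for (auto element : output_map) std::cout << element << ' ';  // prints: 1 -1
  std::cout << '\n';
  return 0;
}
```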
@@ -85,7 +85,7 @@ TEST_F(TestTransposeInfo, InferDevMatrixShape2) {
   transpose->Init(strategy);
   Shape dev_matrix_shape = transpose->dev_matrix_shape();
 
-  Shape expect = {8, 4, 1};
+  Shape expect = {4, 1, 8};
   ASSERT_EQ(dev_matrix_shape, expect);
 }
@@ -0,0 +1,107 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore import context
+from mindspore.common.api import _executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, b):
+        predict = self.network(x, y, b)
+        return self.loss(predict)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, b):
+        return grad_all(self.network)(x, y, b)
+
+
+def compile_net(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
+
+# this case does not require redistribution
+def test_tensoradd_reshape_matmul():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.add = P.TensorAdd().shard(strategy1)
+            self.reshape = P.Reshape()
+            self.matmul = P.MatMul().shard(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.add(x, y)
+            out = self.reshape(out, (256, 16))
+            out = self.matmul(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0, gradients_mean=True)
+    strategy1 = ((8, 1, 1), (8, 1, 1))
+    strategy2 = ((8, 1), (1, 8))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_context(save_graphs=True)
+
+    x = Tensor(np.ones([32, 8, 16]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 8, 16]), dtype=ms.float32)
+    b = Tensor(np.ones([16, 16]), dtype=ms.float32)
+    compile_net(net, x, y, b)
+
+
+def test_two_matmul():
+    class Net(nn.Cell):
+        def __init__(self, strategy1, strategy2):
+            super().__init__()
+            self.matmul1 = P.MatMul().shard(strategy1)
+            self.matmul2 = P.MatMul().shard(strategy2)
+
+        def construct(self, x, y, b):
+            out = self.matmul1(x, y)
+            out = self.matmul2(out, b)
+            return out
+
+    context.set_auto_parallel_context(device_num=64, global_rank=0, gradients_mean=True)
+    strategy1 = ((8, 8), (8, 1))
+    strategy2 = ((8, 1), (1, 1))
+    net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    context.set_context(save_graphs=True)
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
+    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
+    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
+    compile_net(net, x, y, b)