From eb6f4e3ce8dd632e5bb35d9b195bd57b114386f2 Mon Sep 17 00:00:00 2001 From: yangzhenzhang <285824651@qq.com> Date: Wed, 21 Oct 2020 14:46:05 +0800 Subject: [PATCH] update repeated calculation --- .../frontend/parallel/ops_info/matmul_info.cc | 8 +++-- .../frontend/parallel/ops_info/matmul_info.h | 1 + .../frontend/parallel/ops_info/onehot_info.cc | 2 +- .../parallel/ops_info/operator_info.cc | 17 +++++---- .../parallel/ops_info/operator_info.h | 4 ++- .../parallel/ops_info/reduce_method_info.cc | 36 +++++++++++++++++-- .../cpp/parallel/ops_info/onehot_info_test.cc | 2 +- .../ops_info/onehot_info_test_axis_0.cc | 2 +- 8 files changed, 56 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc index dd54b0ddd8..4b9d9c3445 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc @@ -205,6 +205,7 @@ Status MatMulBase::InferDevMatrixShape() { Dimensions mat_b_strategy = stra.at(1); SetDevMatrixShape(mat_a_strategy, mat_b_strategy, transpose_b_, &dev_matrix_shape_); + origin_dev_matrix_shape_ = dev_matrix_shape_; return SUCCESS; } @@ -236,10 +237,11 @@ Status MatMulBase::InferMirrorOps() { Status MatMulBase::InferForwardCommunication() { forward_op_.clear(); - size_t dimension = dev_matrix_shape_.size(); + size_t dimension = origin_dev_matrix_shape_.size(); size_t relevant_dimension_index = SECOND_FROM_END(dimension); - // Relevant dimension is not split and all reduce is not required - if (dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { + // Relevant dimension is not split and all reduce is not required, + // need to use origin_dev_matrix_shape_ here, since the dev_matrix_shape_ will be changed if repeated calculation. 
+ if (origin_dev_matrix_shape_.at(relevant_dimension_index) == MIN_SLICE_NUM) { MS_LOG(INFO) << name_ << " : Forward all reduce is not required."; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h index ad94014410..e83ae3493d 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h @@ -65,6 +65,7 @@ class MatMulBase : public OperatorInfo { int32_t field_size_ = 0; size_t mat_a_dimension_ = 0; size_t mat_b_dimension_ = 0; + Shape origin_dev_matrix_shape_; }; class MatMul : public MatMulBase { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc index ccef6274d9..617f34b059 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc @@ -74,7 +74,7 @@ Status OneHotInfo::InferDevMatrixShape() { dev_matrix_shape_.push_back(input_strategy[1]); // the depth is un-splittable } old_dev_matrix_back_ = dev_matrix_shape_.back(); - + repeated_num_in_dev_matrix_right_ = false; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index e65275e469..62283892c8 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -164,21 +164,24 @@ Status OperatorInfo::InferRepeatedCalcInfo() { return SUCCESS; } -// If repeated calculation, need to set the repeated_calc_num as the last dimension of dev-matrix, -// only use for infer tensor layout. Because if the previous shard is (a, b), and the next shard is -// (a, 1), adding the repeated_calc_num to the last dimension of dev-matrix, there is no need to redistribution. 
+// If repeated calculation, set the repeated_calc_num as the last dimension of dev-matrix by default, +// because if the previous shard is (a, b), and the next shard is (a, 1), adding the repeated_calc_num +// to the last dimension of dev-matrix, there is no need for redistribution. void OperatorInfo::SetRepeatedCalcDevMatrix() { if (repeated_calc_num_ <= 1) { return; } - - (void)dev_matrix_shape_.push_back(repeated_calc_num_); + if (repeated_num_in_dev_matrix_right_) { + dev_matrix_shape_.push_back(repeated_calc_num_); + } else { + (void)dev_matrix_shape_.insert(dev_matrix_shape_.begin(), repeated_calc_num_); + } } -// If repeated calculation, since the repeated_calc_num is added to the last dimension of the dev-matrix, +// If repeated calculation, and the repeated_calc_num is inserted into the last dimension of the dev-matrix, // the index value of tensor map needs to be increased by 1. void OperatorInfo::ResetTensorMapIfRepeatedCalc() { - if (repeated_calc_num_ <= 1) { + if ((repeated_calc_num_ <= 1) || !repeated_num_in_dev_matrix_right_) { return; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index 86a9b31cfd..0eff3e15f9 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -214,7 +214,7 @@ class OperatorInfo { StrategyPtr strategy_; std::vector inputs_tensor_info_; std::vector outputs_tensor_info_; - Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num as the first dimension + Shape dev_matrix_shape_; // if repeated calculation, it contains the repeated_calc_num_ int32_t repeated_calc_num_ = 1; int32_t as_loss_divisor_ = 1; TensorMaps inputs_tensor_map_; @@ -263,6 +263,8 @@ class OperatorInfo { std::string refkey_parameter_name_; CNodePtr cnode_; int32_t used_devices_ = -1; + // the repeated_calc_num_ will be inserted into the last dimension of dev matrix by 
default + bool repeated_num_in_dev_matrix_right_ = true; private: OperatorCostPtr operator_cost_; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index 6f6e75b53f..236536baec 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -158,7 +158,10 @@ Status ReduceMethod::InferForwardCommunication() { size_t size = stra.size(); // judge if the reduce dim is partitioned. Shape group_creat_map; - if (dev_matrix_shape_.size() > size) { + + // if repeated calculation and the repeated_calc_num_ is inserted into the first dimension of dev matrix, + // it needs to handle the first dimension of the map. + if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); } for (size_t index = 0; index < size; ++index) { @@ -169,6 +172,18 @@ } group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); } + + // if repeated calculation and the repeated_calc_num_ is inserted into the last dimension of dev matrix, + // it needs to handle the group_creat_map and insert the 0 into the last dimension of the group_creat_map. + if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) { + for (auto &ele : group_creat_map) { + if (ele == MAP_NONE) { + continue; + } + ele += 1; + } + group_creat_map.push_back(0); + } std::vector forward_group; if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; @@ -220,9 +235,13 @@ Status ReduceMeanInfo::InferForwardCommunication() { size_t size = stra.size(); // judge if the reduce dim is partitioned. 
Shape group_creat_map; - if (dev_matrix_shape_.size() > size) { + + // if repeated calculation and the repeated_calc_num_ is inserted into the first dimension of dev matrix, + // it needs to handle the first dimension of the map. + if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); } + for (size_t index = 0; index < size; ++index) { auto pos = std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); @@ -231,6 +250,19 @@ } group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); } + + // if repeated calculation and the repeated_calc_num_ is inserted into the last dimension of dev matrix, + // it needs to handle the group_creat_map and insert the 0 into the last dimension of the group_creat_map. + if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) { + for (auto &ele : group_creat_map) { + if (ele == MAP_NONE) { + continue; + } + ele += 1; + } + group_creat_map.push_back(0); + } + std::vector forward_group; if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc index e7527c134e..6efac9598b 100644 --- a/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test.cc @@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) { ASSERT_EQ(status, SUCCESS); Shape dev_matrix_shape = onehot_info->dev_matrix_shape(); - Shape expect = {4, 1, 2}; + Shape expect = {2, 4, 1}; ASSERT_EQ(dev_matrix_shape, expect); } diff --git a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc index 7cad3175d5..239a7299cd 100644 --- 
a/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc +++ b/tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc @@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) { ASSERT_EQ(status, SUCCESS); Shape dev_matrix_shape = onehot_info2->dev_matrix_shape(); - Shape expect = {4, 1, 2}; + Shape expect = {2, 4, 1}; ASSERT_EQ(dev_matrix_shape, expect); }