Browse Source

update check strategy value

tags/v1.0.0
yangzhenzhang 5 years ago
parent
commit
048b88c41c
24 changed files with 64 additions and 252 deletions
  1. +5
    -26
      mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
  2. +2
    -9
      mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
  3. +2
    -6
      mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
  4. +2
    -9
      mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc
  5. +2
    -9
      mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
  6. +2
    -11
      mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
  7. +2
    -8
      mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc
  8. +2
    -17
      mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc
  9. +1
    -7
      mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
  10. +1
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc
  11. +2
    -8
      mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc
  12. +2
    -7
      mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
  13. +1
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
  14. +2
    -22
      mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
  15. +19
    -17
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
  16. +1
    -0
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  17. +2
    -8
      mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
  18. +2
    -16
      mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
  19. +2
    -14
      mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
  20. +2
    -7
      mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
  21. +2
    -14
      mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
  22. +2
    -13
      mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
  23. +2
    -14
      mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
  24. +2
    -7
      mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc

+ 5
- 26
mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc View File

@@ -30,26 +30,12 @@


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}

Status Activation::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED;
}
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


return SUCCESS;
}
Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }


Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
} }


Status Softmax::CheckStrategy(const StrategyPtr &strategy) { Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -229,14 +215,7 @@ Status Softmax::GetAttrs() {
return SUCCESS; return SUCCESS;
} }


Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status Softmax::GenerateStrategies(int32_t stage_id) { Status Softmax::GenerateStrategies(int32_t stage_id) {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {


+ 2
- 9
mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc View File

@@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) {
} }


Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }


Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { Status ArithmeticBase::GenerateStrategies(int32_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1); Shape input0_split(inputs_shape_[0].size(), 1);


+ 2
- 6
mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
} }


Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) {


+ 2
- 9
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }


Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { Status BiasAddInfo::GenerateStrategies(int32_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1); Shape input0_split(inputs_shape_[0].size(), 1);


+ 2
- 9
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc View File

@@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() {


Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) { Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy); MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy"; MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED; return FAILED;
} }
@@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() {
} }
} }


Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status ConcatInfo::GenerateStrategies(int32_t stage_id) { Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
if (InferAttrs() != SUCCESS) { if (InferAttrs() != SUCCESS) {


+ 2
- 11
mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc View File

@@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) {


// only check the input[0] // only check the input[0]
Shapes input_shape = {inputs_shape_[0]}; Shapes input_shape = {inputs_shape_[0]};
if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
return SUCCESS;
return CheckStrategyValue(strategy, input_shape);
} }


Status DropoutDoMaskInfo::InferDevMatrixShape() { Status DropoutDoMaskInfo::InferDevMatrixShape() {
@@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() {
} }


Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed";
return FAILED;
}

return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) {


+ 2
- 8
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_info.cc View File

@@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }
// Only strategy of the first input should be set. // Only strategy of the first input should be set.
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy."; MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
return SUCCESS; return SUCCESS;
} }


Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
}
Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() { std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() {
if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) {


+ 2
- 17
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc View File

@@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) {
} }


Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy.";
} else {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
}
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }


@@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }


Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed.";
} else {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
}
return FAILED;
}
return SUCCESS;
}
Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) {
if (GetAttrs() != SUCCESS) { if (GetAttrs() != SUCCESS) {


+ 1
- 7
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc View File

@@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }


Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
}
Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status GetNextInfo::GenerateStrategies(int32_t stage_id) { Status GetNextInfo::GenerateStrategies(int32_t stage_id) {
Strategys stra; Strategys stra;


+ 1
- 2
mindspore/ccsrc/frontend/parallel/ops_info/l2_normalize_info.cc View File

@@ -27,8 +27,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(INFO) << name_ << " : Init success.";
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
return FAILED; return FAILED;
} }




+ 2
- 8
mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc View File

@@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED; return FAILED;
} }


if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy value"; MS_LOG(ERROR) << name_ << ": Invalid strategy value";
return FAILED; return FAILED;
} }
@@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() {
return SUCCESS; return SUCCESS;
} }


Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) { Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) {
if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) {


+ 2
- 7
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc View File

@@ -28,7 +28,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) {
} }


Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
PrintStrategy(strategy);
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

+ 1
- 1
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc View File

@@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions
} }


Status MatMul::CheckStrategy(const StrategyPtr &strategy) { Status MatMul::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Invalid strategy."; MS_LOG(ERROR) << name_ << " : Invalid strategy.";
return FAILED; return FAILED;
} }


+ 2
- 22
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc View File

@@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() {
} }


Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) {
if (inputs_shape_.size() != 3) {
MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size();
return FAILED;
}
if (outputs_shape_.size() != 1) {
MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size();
return FAILED;
}
if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)},
is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}

return SUCCESS;
return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)});
} }


Status OneHotInfo::InferDevMatrixShape() { Status OneHotInfo::InferDevMatrixShape() {
@@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) {
return SUCCESS; return SUCCESS;
} }


Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
}
Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() { std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() {
CheckGlobalDeviceManager(); CheckGlobalDeviceManager();


+ 19
- 17
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc View File

@@ -33,19 +33,21 @@


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) {
Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) {
if (strategy == nullptr) { if (strategy == nullptr) {
MS_LOG(ERROR) << "The strategy is null.";
MS_LOG(ERROR) << name_ << ": The strategy is null.";
return FAILED; return FAILED;
} }


size_t strategy_size = strategy->GetInputNumber(); size_t strategy_size = strategy->GetInputNumber();
size_t inputs_shape_size = inputs_shape.size(); size_t inputs_shape_size = inputs_shape.size();
if (strategy_size != inputs_shape_size) { if (strategy_size != inputs_shape_size) {
if (is_auto_parallel) {
MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Strategy size: " << strategy_size
<< " is not equal to inputs size: " << inputs_shape_size;
} else { } else {
MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size;
MS_LOG(ERROR) << name_ << ": Strategy size: " << strategy_size
<< " is not equal to inputs size: " << inputs_shape_size;
} }
return FAILED; return FAILED;
} }
@@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
size_t strategy_len = sub_strategy.size(); size_t strategy_len = sub_strategy.size();
size_t inputs_len = sub_input_shape.size(); size_t inputs_len = sub_input_shape.size();
if (strategy_len != inputs_len) { if (strategy_len != inputs_len) {
if (is_auto_parallel) {
MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
<< ", index: " << i; << ", index: " << i;
} else { } else {
MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
MS_LOG(ERROR) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len
<< ", index: " << i; << ", index: " << i;
} }
return FAILED; return FAILED;
@@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap
for (size_t j = 0; j < strategy_len; ++j) { for (size_t j = 0; j < strategy_len; ++j) {
int64_t strategy_value = sub_strategy.at(j); int64_t strategy_value = sub_strategy.at(j);
if (strategy_value < MIN_SLICE_NUM) { if (strategy_value < MIN_SLICE_NUM) {
if (is_auto_parallel) {
MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value;
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Invalid strategy value: " << strategy_value;
} else { } else {
MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value;
MS_LOG(ERROR) << name_ << ": Invalid strategy value: " << strategy_value;
} }
return FAILED; return FAILED;
} }


if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) {
if (is_auto_parallel) {
MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value;
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value;
} else { } else {
MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value;
MS_LOG(ERROR) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value;
} }
return FAILED; return FAILED;
} }


int64_t shape_value = sub_input_shape.at(j); int64_t shape_value = sub_input_shape.at(j);
if ((shape_value % strategy_value) != 0) { if ((shape_value % strategy_value) != 0) {
if (is_auto_parallel) {
MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
} else { } else {
MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
MS_LOG(ERROR) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value;
} }
return FAILED; return FAILED;
} }


+ 1
- 0
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -176,6 +176,7 @@ class OperatorInfo {
virtual Status GetAttrs() = 0; virtual Status GetAttrs() = 0;
virtual Status InferTensorInfo() = 0; virtual Status InferTensorInfo() = 0;
virtual Status InferDevMatrixShape() = 0; virtual Status InferDevMatrixShape() = 0;
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetDeviceListByStrategy(); void SetDeviceListByStrategy();
void SetRepeatedCalcDevMatrix(); void SetRepeatedCalcDevMatrix();
Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group); Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group);


+ 2
- 8
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc View File

@@ -34,7 +34,7 @@ namespace parallel {
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
*/ */
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy."; MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) {
return SUCCESS; return SUCCESS;
} }


Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
}
Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore

+ 2
- 16
mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc View File

@@ -29,14 +29,7 @@


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}

return SUCCESS;
}
Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }


Status ReduceMethod::InferDevMatrixShape() { Status ReduceMethod::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
@@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() {
return SUCCESS; return SUCCESS;
} }


Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status ReduceMethod::GenerateStrategies(int32_t stage_id) { Status ReduceMethod::GenerateStrategies(int32_t stage_id) {
if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {


+ 2
- 14
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc View File

@@ -29,14 +29,7 @@


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}

return SUCCESS;
}
Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }


/* /*
* support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of
@@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
} }


Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


void ReshapeInfo::SetCostForReshapeWithParameter() { void ReshapeInfo::SetCostForReshapeWithParameter() {


+ 2
- 7
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc View File

@@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() {


Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) { Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy); MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy"; MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED; return FAILED;
} }
@@ -232,12 +232,7 @@ std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() {
} }


Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) { Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) {


+ 2
- 14
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc View File

@@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() {


Status TileInfo::CheckStrategy(const StrategyPtr &strategy) { Status TileInfo::CheckStrategy(const StrategyPtr &strategy) {
Shapes multiples = {full_multiples_}; Shapes multiples = {full_multiples_};
if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}

return SUCCESS;
return CheckStrategyValue(strategy, multiples);
} }


Status TileInfo::InferDevMatrixShape() { Status TileInfo::InferDevMatrixShape() {
@@ -197,14 +192,7 @@ std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() {
return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_); return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_);
} }


Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status TileInfo::GenerateStrategies(int32_t stage_id) { Status TileInfo::GenerateStrategies(int32_t stage_id) {
if (InferAttrs() != SUCCESS) { if (InferAttrs() != SUCCESS) {


+ 2
- 13
mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc View File

@@ -25,11 +25,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": invalid strategy.";
return FAILED;
}
return SUCCESS;
return CheckStrategyValue(strategy, inputs_shape_);
} }


Status TmpIdentityInfo::InferDevMatrixShape() { Status TmpIdentityInfo::InferDevMatrixShape() {
@@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
return SUCCESS; return SUCCESS;
} }


Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
}
Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }


Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) {
if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) {


+ 2
- 14
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc View File

@@ -27,14 +27,7 @@


namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED;
}

return SUCCESS;
}
Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); }


Status TransposeInfo::InferDevMatrixShape() { Status TransposeInfo::InferDevMatrixShape() {
Strategys stra = strategy_->GetInputDim(); Strategys stra = strategy_->GetInputDim();
@@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
} }


Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


Status TransposeInfo::GenerateStrategies(int32_t stage_id) { Status TransposeInfo::GenerateStrategies(int32_t stage_id) {


+ 2
- 7
mindspore/ccsrc/frontend/parallel/ops_info/virtual_dataset_info.cc View File

@@ -29,7 +29,7 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy."; MS_LOG(ERROR) << name_ << ": Invalid strategy.";
return FAILED; return FAILED;
} }
@@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() {
} }


Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}

return SUCCESS;
return SetCostUnderStrategyBase(strategy);
} }


Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) {


Loading…
Cancel
Save