| @@ -30,26 +30,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Activation::CheckStrategy(const StrategyPtr &strategy) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Activation::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } | |||||
| Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { | Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -153,7 +139,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) { | |||||
| } | } | ||||
| Status Softmax::CheckStrategy(const StrategyPtr &strategy) { | Status Softmax::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -229,14 +215,7 @@ Status Softmax::GetAttrs() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status Softmax::GenerateStrategies(int32_t stage_id) { | Status Softmax::GenerateStrategies(int32_t stage_id) { | ||||
| if (GetAttrs() != SUCCESS) { | if (GetAttrs() != SUCCESS) { | ||||
| @@ -73,7 +73,7 @@ Strategys ExpendStrategy(const StrategyPtr &strategy) { | |||||
| } | } | ||||
| Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { | Status ArithmeticBase::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -290,14 +290,7 @@ Status ArithmeticBase::InferTensorInfo() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ArithmeticBase::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { | Status ArithmeticBase::GenerateStrategies(int32_t stage_id) { | ||||
| Shape input0_split(inputs_shape_[0].size(), 1); | Shape input0_split(inputs_shape_[0].size(), 1); | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { | Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -172,11 +172,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| } | } | ||||
| Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { | Status BatchParallelInfo::GenerateStrategies(int32_t stage_id) { | ||||
| @@ -27,7 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { | Status BiasAddInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -176,14 +176,7 @@ Status BiasAddInfo::InferTensorInfo() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status BiasAddInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { | Status BiasAddInfo::GenerateStrategies(int32_t stage_id) { | ||||
| Shape input0_split(inputs_shape_[0].size(), 1); | Shape input0_split(inputs_shape_[0].size(), 1); | ||||
| @@ -60,7 +60,7 @@ Status ConcatInfo::GetAttrs() { | |||||
| Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) { | Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| MS_EXCEPTION_IF_NULL(strategy); | MS_EXCEPTION_IF_NULL(strategy); | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy"; | MS_LOG(ERROR) << name_ << ": Invalid strategy"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -197,14 +197,7 @@ void ConcatInfo::ReComputeBatchSplitFlagList() { | |||||
| } | } | ||||
| } | } | ||||
| Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status ConcatInfo::GenerateStrategies(int32_t stage_id) { | Status ConcatInfo::GenerateStrategies(int32_t stage_id) { | ||||
| if (InferAttrs() != SUCCESS) { | if (InferAttrs() != SUCCESS) { | ||||
| @@ -50,11 +50,7 @@ Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { | |||||
| // only check the input[0] | // only check the input[0] | ||||
| Shapes input_shape = {inputs_shape_[0]}; | Shapes input_shape = {inputs_shape_[0]}; | ||||
| if (CheckStrategyValue(strategy, input_shape, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy"; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return CheckStrategyValue(strategy, input_shape); | |||||
| } | } | ||||
| Status DropoutDoMaskInfo::InferDevMatrixShape() { | Status DropoutDoMaskInfo::InferDevMatrixShape() { | ||||
| @@ -125,12 +121,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() { | |||||
| } | } | ||||
| Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed"; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { | Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { | ||||
| @@ -82,7 +82,7 @@ Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // Only strategy of the first input should be set. | // Only strategy of the first input should be set. | ||||
| if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, {inputs_shape_.at(0)}) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | MS_LOG(ERROR) << name_ << ": Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -301,13 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() { | std::shared_ptr<Strategys> GatherV2Info::GenerateBatchStrategies() { | ||||
| if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { | ||||
| @@ -213,12 +213,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { | |||||
| } | } | ||||
| Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { | Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Invalid strategy."; | |||||
| } else { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| } | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -716,17 +711,7 @@ Status GatherV2PInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; | |||||
| } else { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| } | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GatherV2PInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { | Status GatherV2PInfo::GenerateStrategies(int32_t stage_id) { | ||||
| if (GetAttrs() != SUCCESS) { | if (GetAttrs() != SUCCESS) { | ||||
| @@ -240,13 +240,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status GetNextInfo::GenerateStrategies(int32_t stage_id) { | Status GetNextInfo::GenerateStrategies(int32_t stage_id) { | ||||
| Strategys stra; | Strategys stra; | ||||
| @@ -27,8 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { | Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(INFO) << name_ << " : Init success."; | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -55,7 +55,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy value"; | MS_LOG(ERROR) << name_ << ": Invalid strategy value"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -207,13 +207,7 @@ Status LayerNormInfo::InferAsLossDivisor() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost failed"; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) { | Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) { | ||||
| if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { | if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) { | ||||
| @@ -28,7 +28,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { | Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -200,12 +200,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { | |||||
| } | } | ||||
| Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | ||||
| PrintStrategy(strategy); | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -150,7 +150,7 @@ Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions | |||||
| } | } | ||||
| Status MatMul::CheckStrategy(const StrategyPtr &strategy) { | Status MatMul::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << " : Invalid strategy."; | MS_LOG(ERROR) << name_ << " : Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -55,21 +55,7 @@ Status OneHotInfo::GetAttrs() { | |||||
| } | } | ||||
| Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { | Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (inputs_shape_.size() != 3) { | |||||
| MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); | |||||
| return FAILED; | |||||
| } | |||||
| if (outputs_shape_.size() != 1) { | |||||
| MS_LOG(ERROR) << name_ << ": outputs_shape_ size must be 1, but is " << outputs_shape_.size(); | |||||
| return FAILED; | |||||
| } | |||||
| if (CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}, | |||||
| is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return CheckStrategyValue(strategy, {outputs_shape_.at(0), inputs_shape_.at(1), inputs_shape_.at(2)}); | |||||
| } | } | ||||
| Status OneHotInfo::InferDevMatrixShape() { | Status OneHotInfo::InferDevMatrixShape() { | ||||
| @@ -278,13 +264,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() { | std::shared_ptr<Strategys> OneHotInfo::GenerateBatchStrategies() { | ||||
| CheckGlobalDeviceManager(); | CheckGlobalDeviceManager(); | ||||
| @@ -33,19 +33,21 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { | |||||
| Status OperatorInfo::CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape) { | |||||
| if (strategy == nullptr) { | if (strategy == nullptr) { | ||||
| MS_LOG(ERROR) << "The strategy is null."; | |||||
| MS_LOG(ERROR) << name_ << ": The strategy is null."; | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| size_t strategy_size = strategy->GetInputNumber(); | size_t strategy_size = strategy->GetInputNumber(); | ||||
| size_t inputs_shape_size = inputs_shape.size(); | size_t inputs_shape_size = inputs_shape.size(); | ||||
| if (strategy_size != inputs_shape_size) { | if (strategy_size != inputs_shape_size) { | ||||
| if (is_auto_parallel) { | |||||
| MS_LOG(DEBUG) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Strategy size: " << strategy_size | |||||
| << " is not equal to inputs size: " << inputs_shape_size; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Strategy size: " << strategy_size << " is not equal to inputs size: " << inputs_shape_size; | |||||
| MS_LOG(ERROR) << name_ << ": Strategy size: " << strategy_size | |||||
| << " is not equal to inputs size: " << inputs_shape_size; | |||||
| } | } | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -57,11 +59,11 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap | |||||
| size_t strategy_len = sub_strategy.size(); | size_t strategy_len = sub_strategy.size(); | ||||
| size_t inputs_len = sub_input_shape.size(); | size_t inputs_len = sub_input_shape.size(); | ||||
| if (strategy_len != inputs_len) { | if (strategy_len != inputs_len) { | ||||
| if (is_auto_parallel) { | |||||
| MS_LOG(DEBUG) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len | |||||
| << ", index: " << i; | << ", index: " << i; | ||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len | |||||
| MS_LOG(ERROR) << name_ << ": Strategy len: " << strategy_len << " is not equal to inputs len: " << inputs_len | |||||
| << ", index: " << i; | << ", index: " << i; | ||||
| } | } | ||||
| return FAILED; | return FAILED; | ||||
| @@ -70,29 +72,29 @@ Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shap | |||||
| for (size_t j = 0; j < strategy_len; ++j) { | for (size_t j = 0; j < strategy_len; ++j) { | ||||
| int64_t strategy_value = sub_strategy.at(j); | int64_t strategy_value = sub_strategy.at(j); | ||||
| if (strategy_value < MIN_SLICE_NUM) { | if (strategy_value < MIN_SLICE_NUM) { | ||||
| if (is_auto_parallel) { | |||||
| MS_LOG(DEBUG) << "Invalid strategy value: " << strategy_value; | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Invalid strategy value: " << strategy_value; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Invalid strategy value: " << strategy_value; | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy value: " << strategy_value; | |||||
| } | } | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { | if ((IntToUint(strategy_value) & IntToUint(strategy_value - 1)) != 0) { | ||||
| if (is_auto_parallel) { | |||||
| MS_LOG(DEBUG) << "Invalid Strategy value it is not the power of 2, " << strategy_value; | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Invalid Strategy value it is not the power of 2, " << strategy_value; | |||||
| MS_LOG(ERROR) << name_ << ": Invalid Strategy value it is not the power of 2, " << strategy_value; | |||||
| } | } | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| int64_t shape_value = sub_input_shape.at(j); | int64_t shape_value = sub_input_shape.at(j); | ||||
| if ((shape_value % strategy_value) != 0) { | if ((shape_value % strategy_value) != 0) { | ||||
| if (is_auto_parallel) { | |||||
| MS_LOG(DEBUG) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; | |||||
| if (is_auto_parallel_) { | |||||
| MS_LOG(DEBUG) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; | |||||
| MS_LOG(ERROR) << name_ << ": Shape " << shape_value << " cannot be divisible by strategy " << strategy_value; | |||||
| } | } | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -176,6 +176,7 @@ class OperatorInfo { | |||||
| virtual Status GetAttrs() = 0; | virtual Status GetAttrs() = 0; | ||||
| virtual Status InferTensorInfo() = 0; | virtual Status InferTensorInfo() = 0; | ||||
| virtual Status InferDevMatrixShape() = 0; | virtual Status InferDevMatrixShape() = 0; | ||||
| Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); | |||||
| void SetDeviceListByStrategy(); | void SetDeviceListByStrategy(); | ||||
| void SetRepeatedCalcDevMatrix(); | void SetRepeatedCalcDevMatrix(); | ||||
| Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group); | Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector<Group> *group); | ||||
| @@ -34,7 +34,7 @@ namespace parallel { | |||||
| * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 | * the strategy of w should equal to the channel dimension of strategy of A, or equal to 1 | ||||
| */ | */ | ||||
| Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { | Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | MS_LOG(ERROR) << name_ << ": Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -220,12 +220,6 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| } // namespace parallel | } // namespace parallel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,14 +29,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ReduceMethod::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } | |||||
| Status ReduceMethod::InferDevMatrixShape() { | Status ReduceMethod::InferDevMatrixShape() { | ||||
| Strategys stra = strategy_->GetInputDim(); | Strategys stra = strategy_->GetInputDim(); | ||||
| @@ -354,14 +347,7 @@ Status ReduceMethod::InferTensorInfo() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ReduceMethod::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status ReduceMethod::GenerateStrategies(int32_t stage_id) { | Status ReduceMethod::GenerateStrategies(int32_t stage_id) { | ||||
| if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { | if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { | ||||
| @@ -29,14 +29,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } | |||||
| /* | /* | ||||
| * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of | * support parallel degree smaller than device number, set the duplicate device dimension to the first dimension of | ||||
| @@ -394,12 +387,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| } | } | ||||
| Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| void ReshapeInfo::SetCostForReshapeWithParameter() { | void ReshapeInfo::SetCostForReshapeWithParameter() { | ||||
| @@ -98,7 +98,7 @@ Status StridedSliceInfo::GetAttrs() { | |||||
| Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) { | Status StridedSliceInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| MS_EXCEPTION_IF_NULL(strategy); | MS_EXCEPTION_IF_NULL(strategy); | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy"; | MS_LOG(ERROR) << name_ << ": Invalid strategy"; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -232,12 +232,7 @@ std::shared_ptr<Strategys> StridedSliceInfo::GenerateBatchStrategies() { | |||||
| } | } | ||||
| Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | Status StridedSliceInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) { | Status StridedSliceInfo::GenerateStrategies(int32_t stage_id) { | ||||
| @@ -67,12 +67,7 @@ Status TileInfo::GetAttrs() { | |||||
| Status TileInfo::CheckStrategy(const StrategyPtr &strategy) { | Status TileInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| Shapes multiples = {full_multiples_}; | Shapes multiples = {full_multiples_}; | ||||
| if (CheckStrategyValue(strategy, multiples, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return CheckStrategyValue(strategy, multiples); | |||||
| } | } | ||||
| Status TileInfo::InferDevMatrixShape() { | Status TileInfo::InferDevMatrixShape() { | ||||
| @@ -197,14 +192,7 @@ std::shared_ptr<Strategys> TileInfo::GenerateBatchStrategies() { | |||||
| return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_); | return GenerateBatchStrategiesBySplitFlag(multiples_shape, split_flag_list_); | ||||
| } | } | ||||
| Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TileInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status TileInfo::GenerateStrategies(int32_t stage_id) { | Status TileInfo::GenerateStrategies(int32_t stage_id) { | ||||
| if (InferAttrs() != SUCCESS) { | if (InferAttrs() != SUCCESS) { | ||||
| @@ -25,11 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { | Status TmpIdentityInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return CheckStrategyValue(strategy, inputs_shape_); | |||||
| } | } | ||||
| Status TmpIdentityInfo::InferDevMatrixShape() { | Status TmpIdentityInfo::InferDevMatrixShape() { | ||||
| @@ -98,14 +94,7 @@ Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||||
| Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { | Status TmpIdentityInfo::GenerateStrategies(int32_t stage_id) { | ||||
| if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { | if ((inputs_shape_.size() != 1) || (outputs_shape_.size() != 1)) { | ||||
| @@ -27,14 +27,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { return CheckStrategyValue(strategy, inputs_shape_); } | |||||
| Status TransposeInfo::InferDevMatrixShape() { | Status TransposeInfo::InferDevMatrixShape() { | ||||
| Strategys stra = strategy_->GetInputDim(); | Strategys stra = strategy_->GetInputDim(); | ||||
| @@ -195,12 +188,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { | |||||
| } | } | ||||
| Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| Status TransposeInfo::GenerateStrategies(int32_t stage_id) { | Status TransposeInfo::GenerateStrategies(int32_t stage_id) { | ||||
| @@ -29,7 +29,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { | Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { | ||||
| if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { | |||||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Invalid strategy."; | MS_LOG(ERROR) << name_ << ": Invalid strategy."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -143,12 +143,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { | |||||
| } | } | ||||
| Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | ||||
| if (SetCostUnderStrategyBase(strategy) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| return SetCostUnderStrategyBase(strategy); | |||||
| } | } | ||||
| Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { | Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { | ||||