Merge pull request !25793 from yangzhenzhang/add-output-strategytags/v1.6.0
| @@ -50,7 +50,7 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std | |||
| // Set user-defined strategy | |||
| auto attrs = op->attrs(); | |||
| if (StrategyFound(attrs)) { | |||
| StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[STRATEGY]); | |||
| StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[IN_STRATEGY]); | |||
| op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost()); | |||
| } | |||
| // Set back to raw strategy for special node in predict/eval | |||
| @@ -105,7 +105,7 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) { | |||
| elements.push_back(MakeValue(input_strategy)); | |||
| } | |||
| ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); | |||
| cnode->AddPrimalAttr(STRATEGY, strategy); | |||
| cnode->AddPrimalAttr(IN_STRATEGY, strategy); | |||
| } | |||
| void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager, | |||
| @@ -305,9 +305,9 @@ Status DropoutInfo::InferAsLossDivisor() { | |||
| return SUCCESS; | |||
| } | |||
| Status DropoutInfo::InferReplaceOps() { | |||
| void DropoutInfo::InferReplaceOps() { | |||
| if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) { | |||
| return SUCCESS; | |||
| return; | |||
| } | |||
| int64_t seed = get_seed(); | |||
| ValuePtr new_seed0 = MakeValue(seed); | |||
| @@ -319,38 +319,6 @@ Status DropoutInfo::InferReplaceOps() { | |||
| OperatorParams params; | |||
| OperatorArgs args = std::make_pair(attrs, params); | |||
| replace_op_ = {std::make_pair(DROPOUT, args)}; | |||
| return SUCCESS; | |||
| } | |||
| Status DropoutInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed"; | |||
| return FAILED; | |||
| } | |||
| (void)InferReplaceOps(); | |||
| MS_LOG(INFO) << name_ << " : Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status ActivationBase::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status CastInfo::InferMirrorOps() { | |||
| @@ -441,28 +409,6 @@ Status ExpandDimsInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status ExpandDimsInfo::InferTensorStrategy() { | |||
| if (strategy_ == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": The strategy is null"; | |||
| return FAILED; | |||
| } | |||
| inputs_strategy_ = strategy_->GetInputDim(); | |||
| if (inputs_strategy_.empty()) { | |||
| MS_LOG(ERROR) << name_ << ": The strategy is empty"; | |||
| return FAILED; | |||
| } | |||
| Shape output_strategy = inputs_strategy_[0]; | |||
| if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(output_strategy.size()))) { | |||
| MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_; | |||
| return FAILED; | |||
| } | |||
| (void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY); | |||
| outputs_strategy_ = {output_strategy}; | |||
| return SUCCESS; | |||
| } | |||
| Status ExpandDimsInfo::InferMirrorOps() { | |||
| mirror_ops_.clear(); | |||
| @@ -538,13 +484,12 @@ Status SqueezeInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status SqueezeInfo::InferReplaceOps() { | |||
| void SqueezeInfo::InferReplaceOps() { | |||
| Attr attr = std::make_pair(AXIS, axis_); | |||
| OperatorAttrs attrs = {attr}; | |||
| OperatorParams params; | |||
| OperatorArgs args = std::make_pair(attrs, params); | |||
| replace_op_ = {std::make_pair(SQUEEZE, args)}; | |||
| return SUCCESS; | |||
| } | |||
| Status SqueezeInfo::InferTensorMap() { | |||
| @@ -572,16 +517,5 @@ Status SqueezeInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status SqueezeInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| } | |||
| (void)InferReplaceOps(); | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,9 +36,6 @@ class ActivationBase : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~ActivationBase() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| protected: | |||
| Status InferMirrorOps() override; | |||
| Status InferForwardCommunication() override; | |||
| @@ -222,7 +219,6 @@ class ExpandDimsInfo : public ActivationOther { | |||
| Status GetAttrs() override; | |||
| Status InferTensorMap() override; | |||
| Status InferMirrorOps() override; | |||
| Status InferTensorStrategy(); | |||
| private: | |||
| int64_t positive_axis_ = -1; | |||
| @@ -240,9 +236,8 @@ class SqueezeInfo : public ActivationOther { | |||
| protected: | |||
| Status InferAxis(const ValueTuplePtr &value_tuple); | |||
| Status GetAttrs() override; | |||
| Status InferReplaceOps(); | |||
| void InferReplaceOps() override; | |||
| Status InferTensorMap() override; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| private: | |||
| ValueTuplePtr axis_; | |||
| @@ -271,12 +266,11 @@ class DropoutInfo : public ActivationOther { | |||
| : ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {} | |||
| ~DropoutInfo() override = default; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| protected: | |||
| Status GetAttrs() override; | |||
| Status InferTensorMap() override; | |||
| Status InferReplaceOps(); | |||
| void InferReplaceOps() override; | |||
| Status InferAsLossDivisor() override; | |||
| private: | |||
| @@ -191,24 +191,5 @@ std::vector<StrategyPtr> ArithmeticBase::GenerateOpStrategies(int64_t stage_id) | |||
| return sp_vector; | |||
| } | |||
| Status ArithmeticBase::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -35,8 +35,6 @@ class ArithmeticBase : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~ArithmeticBase() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -126,25 +126,6 @@ Status BatchParallelInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status BatchParallelInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { | |||
| return SetCostUnderStrategyBase(strategy); | |||
| } | |||
| @@ -37,8 +37,6 @@ class BatchParallelInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {} | |||
| ~BatchParallelInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| void ReplaceNodeInputOrAttrs() override; | |||
| @@ -196,17 +196,17 @@ Status BatchNormInfo::InferForwardCommunication() { | |||
| return SUCCESS; | |||
| } | |||
| Status BatchNormInfo::InferReplaceOps() { | |||
| void BatchNormInfo::InferReplaceOps() { | |||
| replace_op_.clear(); | |||
| if (!is_training_) { | |||
| MS_LOG(INFO) << name_ << ": It is not training, no need to replace op"; | |||
| return SUCCESS; | |||
| return; | |||
| } | |||
| if (forward_allreduce_group_.empty()) { | |||
| MS_LOG(INFO) << name_ << ": The forward allreduce group is empty, no need to replace op"; | |||
| return SUCCESS; | |||
| return; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| @@ -215,7 +215,7 @@ Status BatchNormInfo::InferReplaceOps() { | |||
| if (backend != kAscendDevice && backend != kDavinciDevice) { | |||
| MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator"; | |||
| return SUCCESS; | |||
| return; | |||
| } | |||
| ValuePtr epsilon = MakeValue(epsilon_); | |||
| @@ -232,7 +232,6 @@ Status BatchNormInfo::InferReplaceOps() { | |||
| OperatorParams params; | |||
| OperatorArgs args = std::make_pair(attrs, params); | |||
| replace_op_ = {std::make_pair(SYNC_BATCH_NORM, args)}; | |||
| return SUCCESS; | |||
| } | |||
| Status BatchNormInfo::InferAsLossDivisor() { | |||
| @@ -261,26 +260,5 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| sp_vector.push_back(sp); | |||
| return sp_vector; | |||
| } | |||
| Status BatchNormInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| (void)InferReplaceOps(); | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status BatchNormInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class BatchNormInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {} | |||
| ~BatchNormInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -47,7 +45,7 @@ class BatchNormInfo : public OperatorInfo { | |||
| Status InferForwardCommunication() override; | |||
| Status InferDevMatrixShape() override; | |||
| Status InferTensorMap() override; | |||
| Status InferReplaceOps(); | |||
| void InferReplaceOps() override; | |||
| Status InferAsLossDivisor() override; | |||
| private: | |||
| @@ -101,24 +101,5 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| } | |||
| return sp_vector; | |||
| } | |||
| Status BiasAddInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -37,8 +37,6 @@ class BiasAddInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {} | |||
| ~BiasAddInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -176,24 +176,5 @@ ReplaceGraphPtr BroadcastToInfo::replace_graph(const CNodePtr &cnode) { | |||
| } | |||
| return replace_graph_; | |||
| } | |||
| Status BroadcastToInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status BroadcastToInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -39,8 +39,6 @@ class BroadcastToInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {} | |||
| ~BroadcastToInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| @@ -182,24 +182,5 @@ std::vector<StrategyPtr> ConcatInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status ConcatInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class ConcatInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {} | |||
| ~ConcatInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -837,25 +837,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status Conv2DInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status Conv2DBackpropInputInfo::GetOutShape() { | |||
| if (input_value_.size() != 3) { | |||
| MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size(); | |||
| @@ -37,8 +37,6 @@ class Conv2DInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {} | |||
| ~Conv2DInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -117,26 +117,6 @@ std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() { | |||
| return std::make_shared<Strategys>(strategy_v); | |||
| } | |||
| Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success"; | |||
| return SUCCESS; | |||
| } | |||
| size_t GetNonMonadInputSize(const CNodePtr &cnode) { | |||
| size_t cnode_non_monad_size = cnode->size(); | |||
| for (auto &input : cnode->inputs()) { | |||
| @@ -36,10 +36,8 @@ class DropoutDoMaskInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {} | |||
| ~DropoutDoMaskInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| std::vector<Operator> GetDropoutGenMaskReplaceOp(); | |||
| void ReplaceNodeInputOrAttrs() override; | |||
| @@ -157,25 +157,6 @@ Status DSDMatmulInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status DSDMatmulInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status DSDMatmulInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> DSDMatmulInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| // to generate the first input's strategy | |||
| Shape input0_split = {1, 1, 0, 0, 0, 0, 0}; | |||
| @@ -37,8 +37,6 @@ class DSDMatmulInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DSDMatmulCost>()) {} | |||
| ~DSDMatmulInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -215,24 +215,5 @@ std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| } | |||
| return sp_vector; | |||
| } | |||
| Status GatherDInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherDInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class GatherDInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherDCost>()) {} | |||
| ~GatherDInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -149,24 +149,5 @@ std::vector<StrategyPtr> GatherNdInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status GatherNdInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GatherNdInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class GatherNdInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherNdCost>()) {} | |||
| ~GatherNdInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -103,21 +103,6 @@ Status GetNextInfo::InferDevMatrixShape() { | |||
| return SUCCESS; | |||
| } | |||
| Status GetNextInfo::Init(const StrategyPtr &strategy) { | |||
| repeated_num_in_dev_matrix_right_ = false; | |||
| if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) { | |||
| repeated_num_in_dev_matrix_right_ = true; | |||
| } | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed"; | |||
| return FAILED; | |||
| } | |||
| InferReplaceOps(strategy); | |||
| MS_LOG(INFO) << name_ << " : Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| Strategys stras = strategy->GetInputDim(); | |||
| for (Dimensions stra : stras) { | |||
| @@ -200,6 +185,11 @@ Status GetNextInfo::GetAttrOutPutNum() { | |||
| } | |||
| Status GetNextInfo::GetAttrs() { | |||
| repeated_num_in_dev_matrix_right_ = false; | |||
| if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) { | |||
| repeated_num_in_dev_matrix_right_ = true; | |||
| } | |||
| if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) { | |||
| return FAILED; | |||
| } | |||
| @@ -210,7 +200,7 @@ Status GetNextInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| void GetNextInfo::InferReplaceOps(const StrategyPtr &) { | |||
| void GetNextInfo::InferReplaceOps() { | |||
| Shapes out_shapes; | |||
| (void)std::transform(outputs_tensor_info_.begin(), outputs_tensor_info_.end(), std::back_inserter(out_shapes), | |||
| [](auto tensor_info) { return tensor_info.slice_shape(); }); | |||
| @@ -225,19 +215,6 @@ void GetNextInfo::InferReplaceOps(const StrategyPtr &) { | |||
| replace_op_ = {std::make_pair(GET_NEXT, args)}; | |||
| } | |||
| Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| repeated_num_in_dev_matrix_right_ = false; | |||
| if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) { | |||
| repeated_num_in_dev_matrix_right_ = true; | |||
| } | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| std::vector<StrategyPtr> GetNextInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| @@ -35,9 +35,7 @@ class GetNextInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {} | |||
| ~GetNextInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| protected: | |||
| @@ -49,7 +47,7 @@ class GetNextInfo : public OperatorInfo { | |||
| Status InferDevMatrixShape() override; | |||
| Status InferMirrorOps() override { return SUCCESS; } | |||
| Status InferForwardCommunication() override { return SUCCESS; } | |||
| void InferReplaceOps(const StrategyPtr &strategy); | |||
| void InferReplaceOps() override; | |||
| Status GetAttrTypes(); | |||
| Status GetAttrShapes(); | |||
| Status GetAttrOutPutNum(); | |||
| @@ -30,6 +30,11 @@ namespace parallel { | |||
| // if the begin-norm-axis is 3, the shape of second output is: [A, B, C, 1] | |||
| // the shape of third output is the same as the shape of second output | |||
| Status LayerNormInfo::GetAttrs() { | |||
| if (InitShapes() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init Shape failed"; | |||
| return FAILED; | |||
| } | |||
| auto iter = attrs_.find(BEGIN_NORM_AXIS); | |||
| if (iter == attrs_.end()) { | |||
| MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis"; | |||
| @@ -239,24 +244,5 @@ Status LayerNormInfo::InitShapes() { | |||
| beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX]; | |||
| return SUCCESS; | |||
| } | |||
| Status LayerNormInfo::Init(const StrategyPtr &strategy) { | |||
| if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success"; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -44,8 +44,6 @@ class LayerNormInfo : public OperatorInfo { | |||
| begin_norm_axis_(0) {} | |||
| ~LayerNormInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -105,26 +105,6 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { | |||
| return SUCCESS; | |||
| } | |||
| Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() { | |||
| for (size_t i = 0; i < inputs_shape_.size(); ++i) { | |||
| split_flag_list_[i] = true; | |||
| @@ -38,8 +38,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {} | |||
| ~SoftmaxCrossEntropyWithLogitsInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -265,25 +265,6 @@ ReplaceGraphPtr MatmulDDSInfo::replace_graph(const CNodePtr &cnode) { | |||
| return replace_graph_; | |||
| } | |||
| Status MatmulDDSInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status MatmulDDSInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| // to generate the first input's strategy | |||
| Shape input0_split = {1, 1, 0, 0}; | |||
| @@ -37,8 +37,6 @@ class MatmulDDSInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatmulDDSCost>()) {} | |||
| ~MatmulDDSInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -375,16 +375,6 @@ Status MatMulBase::Init(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { | |||
| if (input->size() < 2) { | |||
| MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; | |||
| @@ -38,7 +38,6 @@ class MatMulBase : public OperatorInfo { | |||
| ~MatMulBase() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| // Generate all strategies and the corresponding cost for this MatMul operator | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| @@ -190,24 +190,5 @@ std::vector<StrategyPtr> MaxPoolInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| } | |||
| return sp_vector; | |||
| } | |||
| Status MaxPoolInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status MaxPoolInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class MaxPoolInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {} | |||
| ~MaxPoolInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -207,29 +207,6 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { | |||
| return replace_graph_; | |||
| } | |||
| Status OneHotInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| Status status = ComputeReplaceGraph(cnode_); | |||
| if (status != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; | |||
| return status; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| Shapes splittable_inputs = {{1, 1}, {}, {}}; | |||
| std::vector<StrategyPtr> sp_vector; | |||
| @@ -35,8 +35,6 @@ class OneHotInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {} | |||
| ~OneHotInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -710,6 +710,26 @@ Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Str | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 | |||
| Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { | |||
| if (strategy == nullptr) { | |||
| @@ -770,54 +790,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat | |||
| return SUCCESS; | |||
| } | |||
| // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape | |||
| Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { | |||
| if (strategy == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": The strategy is null."; | |||
| return FAILED; | |||
| } | |||
| if (InferAttrs() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferAttrs failed."; | |||
| return FAILED; | |||
| } | |||
| // must be after InferAttrs() | |||
| if (CheckStrategy(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": CheckStrategy failed."; | |||
| return FAILED; | |||
| } | |||
| // need to clear queues before Init(), | |||
| // because Init() may be called multiple times by cost model | |||
| ResetQueueMember(); | |||
| strategy_ = strategy; | |||
| if (InferDevMatrixShape() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed."; | |||
| return FAILED; | |||
| } | |||
| // must be after InferDevMatrixShape | |||
| if (InferRepeatedCalcInfo() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed."; | |||
| return FAILED; | |||
| } | |||
| if (InferTensorMap() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferTensorMap failed."; | |||
| return FAILED; | |||
| } | |||
| if (InferTensorInfo() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferTensorInfo failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { | |||
| if (strategy == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": The strategy is null."; | |||
| @@ -843,34 +815,7 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { | |||
| if (strategy == nullptr) { | |||
| MS_LOG(ERROR) << name_ << ": The strategy is null."; | |||
| return FAILED; | |||
| } | |||
| if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| if (InferForwardCommunication() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed."; | |||
| return FAILED; | |||
| } | |||
| if (InferMirrorOps() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferMirrorOps failed."; | |||
| return FAILED; | |||
| } | |||
| if (InferVirtualDivOps() != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed."; | |||
| return FAILED; | |||
| } | |||
| InferReplaceOps(); | |||
| return SUCCESS; | |||
| } | |||
| @@ -81,8 +81,8 @@ class OperatorInfo { | |||
| // If output is tuple, outputs_type.size() is greater than 1. | |||
| Status set_outputs_type(const std::vector<TypePtr> &outputs_type); | |||
| const std::vector<TypePtr> &outputs_type() const { return outputs_type_; } | |||
| virtual Status Init(const StrategyPtr &strategy) = 0; | |||
| virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts | |||
| virtual Status Init(const StrategyPtr &strategy); | |||
| virtual Status InitForCostModel(const StrategyPtr &strategy); // only init the necessary parts | |||
| // Given the stage_id (which indicates the number of devices), | |||
| // generate all strategies for this operator | |||
| @@ -198,6 +198,7 @@ class OperatorInfo { | |||
| virtual Status InferDevMatrixShape() = 0; | |||
| virtual Status InferMirrorOps(); | |||
| virtual Status InferTensorInfo(); | |||
| virtual void InferReplaceOps() {} | |||
| Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape); | |||
| void SetRepeatedCalcDevMatrix(); | |||
| void ResetTensorMapIfRepeatedCalc(); | |||
| @@ -205,9 +206,7 @@ class OperatorInfo { | |||
| Status InferAttrs(); | |||
| void ResetQueueMember(); | |||
| Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); | |||
| Status InitWithManualRepeatCalc(const StrategyPtr &strategy); | |||
| Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); | |||
| Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); | |||
| Status InferRepeatedCalcInfo(); | |||
| Status InferVirtualDivOps(); | |||
| @@ -88,7 +88,7 @@ constexpr double COST_FACTOR = 2.0; | |||
| constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only"; | |||
| constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only"; | |||
| constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only"; | |||
| constexpr char STRATEGY[] = "strategy"; | |||
| constexpr char IN_STRATEGY[] = "in_strategy"; | |||
| constexpr char STAGE_ATTR[] = "stage"; | |||
| constexpr char GEN_STRATEGY[] = "gen_strategy"; | |||
| constexpr char REDUCE_OP_SUM[] = "sum"; | |||
| @@ -160,24 +160,5 @@ std::vector<StrategyPtr> StackInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status StackInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status StackInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class StackInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {} | |||
| ~StackInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -92,25 +92,6 @@ Status PReLUInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status PReLUInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> PReLUInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| Shape input0_split; | |||
| input0_split.emplace_back(1); | |||
| @@ -37,8 +37,6 @@ class PReLUInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {} | |||
| ~PReLUInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -83,26 +83,6 @@ Status RangeInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status RangeInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success"; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<StrategyPtr> RangeInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| Shape input0_split(inputs_shape_[0].size(), 1); | |||
| Shapes splittable_inputs = {input0_split}; | |||
| @@ -42,9 +42,6 @@ class RangeInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {} | |||
| ~RangeInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -390,25 +390,6 @@ std::vector<StrategyPtr> ReduceMethod::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status ReduceMethod::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success"; | |||
| return SUCCESS; | |||
| } | |||
| std::vector<int64_t> ArgMaxWithValueInfo::reduce_dim() { | |||
| std::vector<int64_t> dim_list; | |||
| auto iter = attrs_.find(AXIS); | |||
| @@ -37,9 +37,6 @@ class ReduceMethod : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~ReduceMethod() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -120,25 +120,5 @@ Status ReLUV2Info::InferAsLossDivisor() { | |||
| << as_loss_divisor_; | |||
| return SUCCESS; | |||
| } | |||
| Status ReLUV2Info::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ReLUV2Info::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -40,8 +40,6 @@ class ReLUV2Info : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {} | |||
| ~ReLUV2Info() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -365,16 +365,6 @@ Status ReshapeInfo::Init(const StrategyPtr &strategy) { | |||
| return SUCCESS; | |||
| } | |||
| Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | |||
| return SetCostUnderStrategyBase(strategy); | |||
| } | |||
| @@ -61,7 +61,6 @@ class ReshapeInfo : public OperatorInfo { | |||
| Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs, | |||
| const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index, | |||
| int64_t in_index, bool is_prev_param, bool is_next_reshape); | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| Status GenerateStrategies(int64_t stage_id) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -118,25 +118,6 @@ std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_ | |||
| return sp_vector; | |||
| } | |||
| Status ResizeBilinearInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0)); | |||
| prim->set_attr(SIZE, MakeValue(slice_size_)); | |||
| @@ -36,8 +36,6 @@ class ResizeBilinearInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ResizeBilinearCost>()) {} | |||
| ~ResizeBilinearInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -176,24 +176,5 @@ std::vector<StrategyPtr> ScatterUpdateInfo::GenerateOpStrategies(int64_t stage_i | |||
| return sp_vector; | |||
| } | |||
| Status ScatterUpdateInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status ScatterUpdateInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class ScatterUpdateInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterUpdateCost>()) {} | |||
| ~ScatterUpdateInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -125,24 +125,5 @@ std::vector<StrategyPtr> SelectInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status SelectInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status SelectInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -36,8 +36,6 @@ class SelectInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SelectCost>()) {} | |||
| ~SelectInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void ReComputeBatchSplitFlagList() override; | |||
| @@ -174,25 +174,6 @@ std::vector<StrategyPtr> SliceInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status SliceInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status SliceInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) { | |||
| auto input_strategy = strategy_->GetInputDim().at(0); | |||
| if (std::any_of(input_strategy.begin(), input_strategy.end(), [](const int64_t &shard) { return shard > 1; })) { | |||
| @@ -38,8 +38,6 @@ class SliceInfo : public OperatorInfo { | |||
| slice_axis_(-1) {} | |||
| ~SliceInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| @@ -196,24 +196,5 @@ Status SplitInfo::InferAsLossDivisor() { | |||
| << as_loss_divisor_; | |||
| return SUCCESS; | |||
| } | |||
| Status SplitInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status SplitInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -35,8 +35,6 @@ class SplitInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {} | |||
| ~SplitInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -234,24 +234,5 @@ std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id | |||
| return sp_vector; | |||
| } | |||
| Status StridedSliceInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status StridedSliceInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -37,8 +37,6 @@ class StridedSliceInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {} | |||
| ~StridedSliceInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| @@ -306,26 +306,6 @@ Status TensorDotInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status TensorDotInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status TensorDotInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success"; | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() { | |||
| if (GetAttrs() != SUCCESS) { | |||
| MS_LOG(EXCEPTION) << name_ << ": Get attr failed"; | |||
| @@ -44,9 +44,6 @@ class TensorDotInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {} | |||
| ~TensorDotInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, | |||
| @@ -208,24 +208,5 @@ std::vector<StrategyPtr> TileInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status TileInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TileInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -37,8 +37,6 @@ class TileInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {} | |||
| ~TileInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| @@ -48,26 +48,6 @@ Status TmpIdentityInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status TmpIdentityInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| std::vector<StrategyPtr> TmpIdentityInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| @@ -37,9 +37,6 @@ class TmpIdentityInfo : public OperatorInfo { | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {} | |||
| ~TmpIdentityInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -164,24 +164,5 @@ std::vector<StrategyPtr> TopKInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| return sp_vector; | |||
| } | |||
| Status TopKInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TopKInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -37,8 +37,6 @@ class TopKInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TopKCost>()) {} | |||
| ~TopKInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -111,25 +111,6 @@ Status TransposeInfo::InferTensorMap() { | |||
| // compute axis_v_ during this method | |||
| Status TransposeInfo::GetAttrs() { return ComputeAxis(); } | |||
| Status TransposeInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { | |||
| return SetCostUnderStrategyBase(strategy); | |||
| } | |||
| @@ -37,8 +37,6 @@ class TransposeInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {} | |||
| ~TransposeInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -181,25 +181,6 @@ std::shared_ptr<Strategys> UniformCandidateSamplerInfo::GenerateBatchStrategies( | |||
| return std::make_shared<Strategys>(strategy_v); | |||
| } | |||
| Status UniformCandidateSamplerInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) { | |||
| auto input_strategy = strategy_->GetInputDim().at(0); | |||
| // Only when the axis-1 is sharded, we need to modify the attribute | |||
| @@ -43,8 +43,6 @@ class UniformCandidateSamplerInfo : public OperatorInfo { | |||
| remove_accidental_hits_(false) {} | |||
| ~UniformCandidateSamplerInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| std::shared_ptr<Strategys> GenerateBatchStrategies() override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| @@ -102,25 +102,6 @@ std::vector<StrategyPtr> UniformRealInfo::GenerateOpStrategies(int64_t stage_id) | |||
| return sp_vector; | |||
| } | |||
| Status UniformRealInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status UniformRealInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| void UniformRealInfo::UpdateShape(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto input_node = cnode->input(1)->cast<ValueNodePtr>(); | |||
| @@ -36,8 +36,6 @@ class UniformRealInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {} | |||
| ~UniformRealInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &) override; | |||
| void UpdateShape(const CNodePtr &cnode); | |||
| @@ -58,15 +58,6 @@ Status UniqueInfo::InferDevMatrixShape() { | |||
| return SUCCESS; | |||
| } | |||
| Status UniqueInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init failed"; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init success"; | |||
| return SUCCESS; | |||
| } | |||
| Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) { | |||
| Strategys stras = strategy->GetInputDim(); | |||
| if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) { | |||
| @@ -96,15 +87,6 @@ Status UniqueInfo::GetAttrs() { | |||
| return SUCCESS; | |||
| } | |||
| Status UniqueInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << " : Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << " : Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| Status UniqueInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); } | |||
| std::vector<StrategyPtr> UniqueInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| @@ -35,9 +35,7 @@ class UniqueInfo : public OperatorInfo { | |||
| : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {} | |||
| ~UniqueInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||
| @@ -163,24 +163,6 @@ Status UnsortedSegmentOpInfo::InferTensorMap() { | |||
| return SUCCESS; | |||
| } | |||
| Status UnsortedSegmentOpInfo::Init(const StrategyPtr &strategy) { | |||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init success."; | |||
| return SUCCESS; | |||
| } | |||
| Status UnsortedSegmentOpInfo::InitForCostModel(const StrategyPtr &strategy) { | |||
| if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { | |||
| MS_LOG(ERROR) << name_ << ": Init for cost model failed."; | |||
| return FAILED; | |||
| } | |||
| MS_LOG(INFO) << name_ << ": Init for cost model success."; | |||
| return SUCCESS; | |||
| } | |||
| // Set the default strategy | |||
| std::vector<StrategyPtr> UnsortedSegmentOpInfo::GenerateOpStrategies(int64_t stage_id) { | |||
| Shape input0_split(inputs_shape_[0].size(), 1); | |||
| @@ -44,8 +44,6 @@ class UnsortedSegmentOpInfo : public OperatorInfo { | |||
| const PrimitiveAttrs &attrs, OperatorCostPtr cost) | |||
| : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {} | |||
| ~UnsortedSegmentOpInfo() override = default; | |||
| Status Init(const StrategyPtr &strategy) override; | |||
| Status InitForCostModel(const StrategyPtr &strategy) override; | |||
| std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override; | |||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | |||
| @@ -335,7 +335,7 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tup | |||
| if (!StrategyFound(attrs)) { | |||
| strategy = GenerateBatchParallelStrategy(op_info, prim); | |||
| } else { | |||
| strategy = ExtractStrategy(attrs[STRATEGY]); | |||
| strategy = ExtractStrategy(attrs[IN_STRATEGY]); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(strategy); | |||
| if (op_info->Init(strategy) == FAILED) { | |||
| @@ -253,7 +253,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive | |||
| // In this case, the configured strategy should be extracted to help setting cost | |||
| StrategyPtr strategyPtr; | |||
| if (StrategyFound(attrs)) { | |||
| strategyPtr = parallel::ExtractStrategy(attrs[STRATEGY]); | |||
| strategyPtr = parallel::ExtractStrategy(attrs[IN_STRATEGY]); | |||
| } else { | |||
| strategyPtr = (*stra_map)[strategy_key_name]; | |||
| } | |||
| @@ -484,7 +484,7 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera | |||
| } | |||
| bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs) { | |||
| auto iter = attrs.find(STRATEGY); | |||
| auto iter = attrs.find(IN_STRATEGY); | |||
| return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); | |||
| } | |||
| @@ -1760,7 +1760,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| (void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements), | |||
| [](auto input_stra) { return MakeValue(input_stra); }); | |||
| ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); | |||
| attrs_temp[STRATEGY] = strategy; | |||
| attrs_temp[IN_STRATEGY] = strategy; | |||
| (void)prim->SetAttrs(attrs_temp); | |||
| if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) { | |||
| ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true); | |||
| @@ -1798,7 +1798,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) { | |||
| elements.push_back(MakeValue(input_strategy)); | |||
| } | |||
| ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements); | |||
| attrs_temp[STRATEGY] = strategy; | |||
| attrs_temp[IN_STRATEGY] = strategy; | |||
| (void)prim->SetAttrs(attrs_temp); | |||
| } | |||
| } | |||
| @@ -1954,14 +1954,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { | |||
| } | |||
| bool load_strategy_from_ckpt = | |||
| StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); | |||
| if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(STRATEGY)) { | |||
| if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) { | |||
| MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() | |||
| << " is empty, using batch parallel"; | |||
| strategyPtr = GenerateBatchParallelStrategy(operator_, prim); | |||
| } else if (cnode->HasPrimalAttr(STRATEGY)) { | |||
| strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(STRATEGY)); | |||
| } else if (cnode->HasPrimalAttr(IN_STRATEGY)) { | |||
| strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY)); | |||
| } else if (StrategyFound(attrs)) { | |||
| strategyPtr = ExtractStrategy(attrs[STRATEGY]); | |||
| strategyPtr = ExtractStrategy(attrs[IN_STRATEGY]); | |||
| } else { | |||
| strategyPtr = stra_map[strategy_key_name]; | |||
| } | |||
| @@ -151,7 +151,7 @@ class Primitive(Primitive_): | |||
| self.add_prim_attr("stage", stage) | |||
| return self | |||
| def shard(self, strategy): | |||
| def shard(self, in_strategy=None, out_strategy=None): | |||
| """ | |||
| Add strategies to primitive attribute. | |||
| @@ -160,24 +160,37 @@ class Primitive(Primitive_): | |||
| In other parallel modes, strategies set here will be ignored. | |||
| Args: | |||
| strategy (tuple): Strategy describes the distributed parallel mode of the current primitive. | |||
| in_strategy (tuple): Describe the split strategy of operator input. | |||
| out_strategy (tuple): Describe the split strategy of operator output, | |||
| it is only for certain operators, such as MatMul. | |||
| Examples: | |||
| >>> from mindspore import ops | |||
| >>> add = ops.Add() | |||
| >>> print(add.shard(((1, 1), (1, 1)))) | |||
| Prim[Add]<strategy=((1, 1), (1, 1))> | |||
| Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None> | |||
| """ | |||
| mode = context.get_auto_parallel_context("parallel_mode") | |||
| if strategy is not None: | |||
| if not isinstance(strategy, tuple): | |||
| raise TypeError(f'strategy must be tuple type, but got:{type(strategy)}') | |||
| for ele in strategy: | |||
| if not isinstance(ele, tuple): | |||
| raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}') | |||
| if not _is_in_auto_parallel_mode() and strategy: | |||
| logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. " | |||
| f"Please use semi auto or auto parallel mode.") | |||
| self.add_prim_attr("strategy", strategy) | |||
| if in_strategy is not None: | |||
| if not isinstance(in_strategy, tuple): | |||
| raise TypeError(f'in_strategy must be tuple type, but got:{type(in_strategy)}') | |||
| for in_ele in in_strategy: | |||
| if not isinstance(in_ele, tuple): | |||
| raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}') | |||
| if out_strategy is not None: | |||
| if not isinstance(out_strategy, tuple): | |||
| raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}') | |||
| for out_ele in out_strategy: | |||
| if not isinstance(out_ele, tuple): | |||
| raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}') | |||
| if not _is_in_auto_parallel_mode(): | |||
| if in_strategy is not None: | |||
| logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. " | |||
| f"Please use semi auto or auto parallel mode.") | |||
| if out_strategy is not None: | |||
| logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. " | |||
| f"Please use semi auto or auto parallel mode.") | |||
| self.add_prim_attr("in_strategy", in_strategy) | |||
| self.add_prim_attr("out_strategy", out_strategy) | |||
| return self | |||
| def set_prim_instance_name(self, instance_name): | |||
| @@ -92,7 +92,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape3) { | |||
| StrategyPtr strategy = NewStrategy(0, inputs); | |||
| Status status = onehot_info->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| Shape dev_matrix_shape = onehot_info->dev_matrix_shape(); | |||
| Shape expect = {4, 2}; | |||
| @@ -148,7 +148,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) { | |||
| StrategyPtr strategy = NewStrategy(0, str); | |||
| Status status = onehot_info->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| std::vector<TensorInfo> inputs = onehot_info->inputs_tensor_info(); | |||
| std::vector<TensorInfo> outputs = onehot_info->outputs_tensor_info(); | |||
| @@ -170,7 +170,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) { | |||
| StrategyPtr strategy = NewStrategy(0, str); | |||
| Status status = onehot_info->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| std::vector<TensorInfo> inputs = onehot_info->inputs_tensor_info(); | |||
| std::vector<TensorInfo> outputs = onehot_info->outputs_tensor_info(); | |||
| @@ -92,7 +92,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape3) { | |||
| StrategyPtr strategy = NewStrategy(0, inputs); | |||
| Status status = onehot_info2->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| Shape dev_matrix_shape = onehot_info2->dev_matrix_shape(); | |||
| Shape expect = {4, 2}; | |||
| @@ -148,7 +148,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) { | |||
| StrategyPtr strategy = NewStrategy(0, str); | |||
| Status status = onehot_info2->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| std::vector<TensorInfo> inputs = onehot_info2->inputs_tensor_info(); | |||
| std::vector<TensorInfo> outputs = onehot_info2->outputs_tensor_info(); | |||
| @@ -170,7 +170,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape3) { | |||
| StrategyPtr strategy = NewStrategy(0, str); | |||
| Status status = onehot_info2->Init(strategy); | |||
| ASSERT_EQ(status, FAILED); | |||
| ASSERT_EQ(status, SUCCESS); | |||
| std::vector<TensorInfo> inputs = onehot_info2->inputs_tensor_info(); | |||
| std::vector<TensorInfo> outputs = onehot_info2->outputs_tensor_info(); | |||
| @@ -154,7 +154,7 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) { | |||
| prim1->AddAttr("transpose_a", transpose_a); | |||
| prim1->AddAttr("transpose_b", transpose_b); | |||
| prim1->AddAttr("instance_name", MakeValue("matmul1")); | |||
| prim1->AddAttr("strategy", var); | |||
| prim1->AddAttr("in_strategy", var); | |||
| inputs.clear(); | |||
| Dimensions v3 = {2, 2}; | |||
| Dimensions v4 = {2, 4}; | |||
| @@ -176,16 +176,16 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) { | |||
| prim2->AddAttr("transpose_a", transpose_a); | |||
| prim2->AddAttr("transpose_b", transpose_b); | |||
| prim2->AddAttr("instance_name", MakeValue("matmul2")); | |||
| prim2->AddAttr("strategy", var2); | |||
| prim2->AddAttr("in_strategy", var2); | |||
| switch (condition) { | |||
| case 1: { | |||
| prim1->set_attr("strategy", MakeValue(static_cast<int64_t>(0))); | |||
| prim1->set_attr("in_strategy", MakeValue(static_cast<int64_t>(0))); | |||
| break; | |||
| } | |||
| case 2: { | |||
| std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))}; | |||
| ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t); | |||
| prim1->set_attr("strategy", var_t); | |||
| prim1->set_attr("in_strategy", var_t); | |||
| break; | |||
| } | |||
| case 3: { | |||
| @@ -193,7 +193,7 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) { | |||
| Dimensions vt2 = {2, 4}; | |||
| std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)}; | |||
| ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2); | |||
| prim1->set_attr("strategy", var_t2); | |||
| prim1->set_attr("in_strategy", var_t2); | |||
| break; | |||
| } | |||
| } | |||
| @@ -226,9 +226,9 @@ TEST_F(TestStepParallel, ExtractStrategy) { | |||
| ValuePtr val2 = MakeValue(v2); | |||
| std::vector<ValuePtr> elements = {val1, val2}; | |||
| ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements); | |||
| attrs["strategy"] = strategy_tuple; | |||
| attrs["in_strategy"] = strategy_tuple; | |||
| Strategys strategy_expect = {v1, v2}; | |||
| StrategyPtr strategy = ExtractStrategy(attrs["strategy"]); | |||
| StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]); | |||
| Strategys strategy_test = strategy->GetInputDim(); | |||
| ASSERT_EQ(strategy_expect, strategy_test); | |||
| @@ -29,8 +29,8 @@ grad_all = C.GradOperation(get_all=True) | |||
| class AddRelu(nn.Cell): | |||
| def __init__(self, strategy0=None, strategy1=None): | |||
| super(AddRelu, self).__init__() | |||
| self.add = P.Add().shard(strategy=strategy0) | |||
| self.relu = P.ReLU().shard(strategy=strategy1) | |||
| self.add = P.Add().shard(strategy0) | |||
| self.relu = P.ReLU().shard(strategy1) | |||
| def construct(self, x, z): | |||
| out = self.add(x, z) | |||
| @@ -85,15 +85,15 @@ class SemiAutoOneHotNet(Cell): | |||
| self.d = args.d | |||
| self.e = args.e | |||
| self.cast = P.Cast() | |||
| self.cast.shard(strategy=strategy.twod_strategy) | |||
| self.cast.shard(strategy.twod_strategy) | |||
| self.cast1 = P.Cast() | |||
| self.cast1.shard(strategy=strategy.twod_strategy) | |||
| self.cast1.shard(strategy.twod_strategy) | |||
| self.cast2 = P.Cast() | |||
| self.cast2.shard(strategy=strategy.twod_strategy) | |||
| self.cast2.shard(strategy.twod_strategy) | |||
| self.cast3 = P.Cast() | |||
| self.cast3.shard(strategy=strategy.scalar_strategy) | |||
| self.cast3.shard(strategy.scalar_strategy) | |||
| self.cast4 = P.Cast() | |||
| self.cast4.shard(strategy=strategy.scalar_strategy) | |||
| self.cast4.shard(strategy.scalar_strategy) | |||
| self.a_const = Tensor(self.a, dtype=mstype.float32) | |||
| self.b_const = Tensor(self.b, dtype=mstype.float32) | |||
| self.c_const = Tensor(self.c, dtype=mstype.float32) | |||
| @@ -102,64 +102,64 @@ class SemiAutoOneHotNet(Cell): | |||
| self.m_const_zero = Tensor(0, dtype=mstype.float32) | |||
| self.a_const_one = Tensor(1, dtype=mstype.float32) | |||
| self.onehot = P.OneHot() | |||
| self.onehot.shard(strategy=strategy.onehot_strategy) | |||
| self.onehot.shard(strategy.onehot_strategy) | |||
| self.exp = P.Exp() | |||
| self.exp.shard(strategy=strategy.twod_strategy) | |||
| self.exp.shard(strategy.twod_strategy) | |||
| self.exp2 = P.Exp() | |||
| self.exp2.shard(strategy=strategy.twod_strategy) | |||
| self.exp2.shard(strategy.twod_strategy) | |||
| self.exp3 = P.Exp() | |||
| self.exp3.shard(strategy=strategy.twod_strategy) | |||
| self.exp3.shard(strategy.twod_strategy) | |||
| self.mul_const = P.Mul() | |||
| self.mul_const.shard(strategy=strategy.scalar_twod_strategy) | |||
| self.mul_const.shard(strategy.scalar_twod_strategy) | |||
| self.mul_const2 = P.Add() | |||
| self.mul_const2.shard(strategy=strategy.scalar_twod_strategy) | |||
| self.mul_const2.shard(strategy.scalar_twod_strategy) | |||
| self.mul_const3 = P.Sub() | |||
| self.mul_const3.shard(strategy=strategy.twod_scalar_strategy) | |||
| self.mul_const3.shard(strategy.twod_scalar_strategy) | |||
| self.mul_const4 = P.Sub() | |||
| self.mul_const4.shard(strategy=strategy.scalar_twod_strategy) | |||
| self.mul_const4.shard(strategy.scalar_twod_strategy) | |||
| self.mul_const5 = P.Mul() | |||
| self.mul_const5.shard(strategy=strategy.twod_scalar_strategy) | |||
| self.mul_const5.shard(strategy.twod_scalar_strategy) | |||
| self.mul = P.Mul() | |||
| self.mul.shard(strategy=strategy.twod_twod_strategy) | |||
| self.mul.shard(strategy.twod_twod_strategy) | |||
| self.mul2 = P.Mul() | |||
| self.mul2.shard(strategy=strategy.twod_twod_strategy) | |||
| self.mul2.shard(strategy.twod_twod_strategy) | |||
| self.mul3 = P.Add() | |||
| self.mul3.shard(strategy=strategy.twod_twod_strategy) | |||
| self.mul3.shard(strategy.twod_twod_strategy) | |||
| self.mul4 = P.Sub() | |||
| self.mul4.shard(strategy=strategy.twod_twodbc_strategy) | |||
| self.mul4.shard(strategy.twod_twodbc_strategy) | |||
| self.mul5 = P.RealDiv() | |||
| self.mul5.shard(strategy=strategy.twod_twodbc_strategy) | |||
| self.mul5.shard(strategy.twod_twodbc_strategy) | |||
| self.mul6 = P.Mul() | |||
| self.mul6.shard(strategy=strategy.twod_twod_strategy) | |||
| self.mul6.shard(strategy.twod_twod_strategy) | |||
| self.mul7 = P.Mul() | |||
| self.mul7.shard(strategy=strategy.twod_scalar_strategy) | |||
| self.mul7.shard(strategy.twod_scalar_strategy) | |||
| self.mul8 = P.RealDiv() | |||
| self.mul8.shard(strategy=strategy.scalar_scalar_strategy) | |||
| self.mul8.shard(strategy.scalar_scalar_strategy) | |||
| self.mul9 = P.Add() | |||
| self.mul9.shard(strategy=strategy.twod_scalar_strategy) | |||
| self.mul9.shard(strategy.twod_scalar_strategy) | |||
| self.reduce_max = P.ReduceMax(keep_dims=True) | |||
| self.reduce_max.shard(strategy=strategy.twod_strategy) | |||
| self.reduce_max.shard(strategy.twod_strategy) | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| self.reduce_sum.shard(strategy=strategy.twod_strategy) | |||
| self.reduce_sum.shard(strategy.twod_strategy) | |||
| self.reduce_sum_2 = P.ReduceSum(keep_dims=False) | |||
| self.reduce_sum_2.shard(strategy=strategy.twod_strategy) | |||
| self.reduce_sum_2.shard(strategy.twod_strategy) | |||
| self.reduce_sum_3 = P.ReduceSum(keep_dims=False) | |||
| self.reduce_sum_3.shard(strategy=strategy.oned_strategy) | |||
| self.reduce_sum_3.shard(strategy.oned_strategy) | |||
| self.reshape = P.Reshape() | |||
| self.log = P.Log() | |||
| self.log.shard(strategy=strategy.twod_strategy) | |||
| self.log.shard(strategy.twod_strategy) | |||
| self.on_value = Tensor(1.0, mstype.float32) | |||
| self.off_value = Tensor(0.0, mstype.float32) | |||
| self.normalize = P.L2Normalize(axis=1) | |||
| self.normalize.shard(strategy=strategy.twod_strategy_m) | |||
| self.normalize.shard(strategy.twod_strategy_m) | |||
| self.normalize2 = P.L2Normalize(axis=1) | |||
| self.normalize2.shard(strategy=strategy.twod_strategy_m) | |||
| self.normalize2.shard(strategy.twod_strategy_m) | |||
| self.fc = P.MatMul(transpose_b=True) | |||
| self.fc.shard(strategy=strategy.twodbc_twod_strategy) | |||
| self.fc.shard(strategy.twodbc_twod_strategy) | |||
| weight_shape = [args.num_classes, args.emb_size] | |||
| weight_np = np.zeros(weight_shape, np.float32) | |||
| self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight') | |||
| @@ -55,8 +55,8 @@ class Net2(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net2, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.fc1 = P.MatMul().shard(strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2") | |||
| @@ -70,8 +70,8 @@ class Net3(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net3, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.fc1 = P.MatMul().shard(strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False) | |||
| @@ -42,7 +42,7 @@ def test_sum_as_loss(): | |||
| super().__init__() | |||
| self.fc_nobias = P.MatMul(transpose_b=True).shard(strategy0) | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy1) | |||
| self.mul = P.Mul().shard(strategy=((), ())) | |||
| self.mul = P.Mul().shard(((), ())) | |||
| def construct(self, x, y): | |||
| out = self.fc_nobias(x, y) | |||
| @@ -42,8 +42,8 @@ class Net(nn.Cell): | |||
| class NetWithLoss(nn.Cell): | |||
| def __init__(self, network): | |||
| super(NetWithLoss, self).__init__() | |||
| self.sum = P.ReduceSum(keep_dims=False).shard(strategy=((4, 1, 1, 1),)) | |||
| self.mean = P.ReduceMean(keep_dims=False).shard(strategy=((8, 1, 1, 1),)) | |||
| self.sum = P.ReduceSum(keep_dims=False).shard(((4, 1, 1, 1),)) | |||
| self.mean = P.ReduceMean(keep_dims=False).shard(((8, 1, 1, 1),)) | |||
| self.net = network | |||
| def construct(self, x): | |||
| @@ -29,8 +29,8 @@ class Net1(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net1, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.fc1 = P.MatMul().shard(strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2") | |||
| @@ -45,8 +45,8 @@ class Net2(nn.Cell): | |||
| """Net definition""" | |||
| def __init__(self, strategy1, strategy2): | |||
| super(Net2, self).__init__() | |||
| self.fc1 = P.MatMul().shard(strategy=strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy=strategy2) | |||
| self.fc1 = P.MatMul().shard(strategy1) | |||
| self.fc2 = P.MatMul().shard(strategy2) | |||
| self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1") | |||
| self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2") | |||