Browse Source

!25793 add output strategy for shard

Merge pull request !25793 from yangzhenzhang/add-output-strategy
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
7c14f984ed
98 changed files with 141 additions and 1069 deletions
  1. +1
    -1
      mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc
  3. +3
    -69
      mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc
  4. +2
    -8
      mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h
  5. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc
  6. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h
  7. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc
  8. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h
  9. +4
    -26
      mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc
  10. +1
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h
  11. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc
  12. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h
  13. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.cc
  14. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h
  15. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc
  16. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h
  17. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
  18. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
  19. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc
  20. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h
  21. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/dsd_matmul_info.cc
  22. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/dsd_matmul_info.h
  23. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc
  24. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h
  25. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc
  26. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h
  27. +6
    -29
      mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc
  28. +1
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h
  29. +5
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc
  30. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h
  31. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc
  32. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h
  33. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.cc
  34. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.h
  35. +0
    -10
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc
  36. +0
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h
  37. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.cc
  38. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.h
  39. +0
    -23
      mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc
  40. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h
  41. +21
    -76
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc
  42. +3
    -4
      mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h
  43. +1
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
  44. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc
  45. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h
  46. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc
  47. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h
  48. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc
  49. +0
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/range_info.h
  50. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
  51. +0
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h
  52. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc
  53. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h
  54. +0
    -10
      mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc
  55. +0
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h
  56. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc
  57. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h
  58. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc
  59. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h
  60. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/select_info.cc
  61. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/select_info.h
  62. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc
  63. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h
  64. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc
  65. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/split_info.h
  66. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc
  67. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h
  68. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc
  69. +0
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h
  70. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc
  71. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h
  72. +0
    -20
      mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc
  73. +0
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h
  74. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/topk_info.cc
  75. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/topk_info.h
  76. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc
  77. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h
  78. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/uniform_candidate_sampler_info.cc
  79. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/uniform_candidate_sampler_info.h
  80. +0
    -19
      mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc
  81. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h
  82. +0
    -18
      mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc
  83. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h
  84. +0
    -18
      mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc
  85. +0
    -2
      mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h
  86. +1
    -1
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  87. +1
    -1
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  88. +7
    -7
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  89. +26
    -13
      mindspore/ops/primitive.py
  90. +3
    -3
      tests/ut/cpp/parallel/ops_info/onehot_info_test.cc
  91. +3
    -3
      tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc
  92. +7
    -7
      tests/ut/cpp/parallel/step_parallel_test.cc
  93. +2
    -2
      tests/ut/python/parallel/test_add_relu_redistribution.py
  94. +31
    -31
      tests/ut/python/parallel/test_one_hot_net.py
  95. +4
    -4
      tests/ut/python/parallel/test_parallel_optimizer.py
  96. +1
    -1
      tests/ut/python/parallel/test_scalar_loss.py
  97. +2
    -2
      tests/ut/python/parallel/test_semi_auto_two_subgraphs.py
  98. +4
    -4
      tests/ut/python/parallel/test_shared_param_and_mix_precision.py

+ 1
- 1
mindspore/ccsrc/frontend/parallel/auto_parallel/rec_core/rec_generate_strategy.cc View File

@@ -50,7 +50,7 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
// Set user-defined strategy
auto attrs = op->attrs();
if (StrategyFound(attrs)) {
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[STRATEGY]);
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs[IN_STRATEGY]);
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
}
// Set back to raw strategy for special node in predict/eval


+ 1
- 1
mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc View File

@@ -105,7 +105,7 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) {
elements.push_back(MakeValue(input_strategy));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
cnode->AddPrimalAttr(STRATEGY, strategy);
cnode->AddPrimalAttr(IN_STRATEGY, strategy);
}

void InsertVirtualAssignAdd(const std::pair<AnfNodePtr, int> &node_user, const FuncGraphManagerPtr &manager,


+ 3
- 69
mindspore/ccsrc/frontend/parallel/ops_info/activation_info.cc View File

@@ -305,9 +305,9 @@ Status DropoutInfo::InferAsLossDivisor() {
return SUCCESS;
}

Status DropoutInfo::InferReplaceOps() {
void DropoutInfo::InferReplaceOps() {
if ((seed0_ != 0) || (seed1_ != 0) || (repeated_calc_num_ == 1)) {
return SUCCESS;
return;
}
int64_t seed = get_seed();
ValuePtr new_seed0 = MakeValue(seed);
@@ -319,38 +319,6 @@ Status DropoutInfo::InferReplaceOps() {
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(DROPOUT, args)};
return SUCCESS;
}

Status DropoutInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed";
return FAILED;
}
(void)InferReplaceOps();

MS_LOG(INFO) << name_ << " : Init success";
return SUCCESS;
}

Status ActivationBase::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

Status CastInfo::InferMirrorOps() {
@@ -441,28 +409,6 @@ Status ExpandDimsInfo::InferTensorMap() {
return SUCCESS;
}

Status ExpandDimsInfo::InferTensorStrategy() {
if (strategy_ == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null";
return FAILED;
}

inputs_strategy_ = strategy_->GetInputDim();
if (inputs_strategy_.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}

Shape output_strategy = inputs_strategy_[0];
if ((positive_axis_ < 0) || (positive_axis_ > SizeToLong(output_strategy.size()))) {
MS_LOG(ERROR) << name_ << ": Invalid positive axis " << positive_axis_;
return FAILED;
}
(void)output_strategy.insert(output_strategy.begin() + positive_axis_, NO_SPLIT_STRATEGY);
outputs_strategy_ = {output_strategy};
return SUCCESS;
}

Status ExpandDimsInfo::InferMirrorOps() {
mirror_ops_.clear();

@@ -538,13 +484,12 @@ Status SqueezeInfo::GetAttrs() {
return SUCCESS;
}

Status SqueezeInfo::InferReplaceOps() {
void SqueezeInfo::InferReplaceOps() {
Attr attr = std::make_pair(AXIS, axis_);
OperatorAttrs attrs = {attr};
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(SQUEEZE, args)};
return SUCCESS;
}

Status SqueezeInfo::InferTensorMap() {
@@ -572,16 +517,5 @@ Status SqueezeInfo::InferTensorMap() {

return SUCCESS;
}

Status SqueezeInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
}

(void)InferReplaceOps();

MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 2
- 8
mindspore/ccsrc/frontend/parallel/ops_info/activation_info.h View File

@@ -36,9 +36,6 @@ class ActivationBase : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ActivationBase() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

protected:
Status InferMirrorOps() override;
Status InferForwardCommunication() override;
@@ -222,7 +219,6 @@ class ExpandDimsInfo : public ActivationOther {
Status GetAttrs() override;
Status InferTensorMap() override;
Status InferMirrorOps() override;
Status InferTensorStrategy();

private:
int64_t positive_axis_ = -1;
@@ -240,9 +236,8 @@ class SqueezeInfo : public ActivationOther {
protected:
Status InferAxis(const ValueTuplePtr &value_tuple);
Status GetAttrs() override;
Status InferReplaceOps();
void InferReplaceOps() override;
Status InferTensorMap() override;
Status Init(const StrategyPtr &strategy) override;

private:
ValueTuplePtr axis_;
@@ -271,12 +266,11 @@ class DropoutInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutCost>()) {}
~DropoutInfo() override = default;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status Init(const StrategyPtr &strategy) override;

protected:
Status GetAttrs() override;
Status InferTensorMap() override;
Status InferReplaceOps();
void InferReplaceOps() override;
Status InferAsLossDivisor() override;

private:


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.cc View File

@@ -191,24 +191,5 @@ std::vector<StrategyPtr> ArithmeticBase::GenerateOpStrategies(int64_t stage_id)

return sp_vector;
}

Status ArithmeticBase::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status ArithmeticBase::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/arithmetic_info.h View File

@@ -35,8 +35,6 @@ class ArithmeticBase : public OperatorInfo {
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {}
~ArithmeticBase() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.cc View File

@@ -126,25 +126,6 @@ Status BatchParallelInfo::GetAttrs() {
return SUCCESS;
}

Status BatchParallelInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/batch_parallel_info.h View File

@@ -37,8 +37,6 @@ class BatchParallelInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()), dev_num_(1) {}

~BatchParallelInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
void ReplaceNodeInputOrAttrs() override;


+ 4
- 26
mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.cc View File

@@ -196,17 +196,17 @@ Status BatchNormInfo::InferForwardCommunication() {
return SUCCESS;
}

Status BatchNormInfo::InferReplaceOps() {
void BatchNormInfo::InferReplaceOps() {
replace_op_.clear();

if (!is_training_) {
MS_LOG(INFO) << name_ << ": It is not training, no need to replace op";
return SUCCESS;
return;
}

if (forward_allreduce_group_.empty()) {
MS_LOG(INFO) << name_ << ": The forward allreduce group is empty, no need to replace op";
return SUCCESS;
return;
}

auto ms_context = MsContext::GetInstance();
@@ -215,7 +215,7 @@ Status BatchNormInfo::InferReplaceOps() {

if (backend != kAscendDevice && backend != kDavinciDevice) {
MS_LOG(INFO) << name_ << ": The backend is " << backend << ", it does not support SyncBatchNorm operator";
return SUCCESS;
return;
}

ValuePtr epsilon = MakeValue(epsilon_);
@@ -232,7 +232,6 @@ Status BatchNormInfo::InferReplaceOps() {
OperatorParams params;
OperatorArgs args = std::make_pair(attrs, params);
replace_op_ = {std::make_pair(SYNC_BATCH_NORM, args)};
return SUCCESS;
}

Status BatchNormInfo::InferAsLossDivisor() {
@@ -261,26 +260,5 @@ std::vector<StrategyPtr> BatchNormInfo::GenerateOpStrategies(int64_t stage_id) {
sp_vector.push_back(sp);
return sp_vector;
}

Status BatchNormInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}

(void)InferReplaceOps();
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status BatchNormInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 1
- 3
mindspore/ccsrc/frontend/parallel/ops_info/batchnorm_info.h View File

@@ -36,8 +36,6 @@ class BatchNormInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~BatchNormInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;

@@ -47,7 +45,7 @@ class BatchNormInfo : public OperatorInfo {
Status InferForwardCommunication() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferReplaceOps();
void InferReplaceOps() override;
Status InferAsLossDivisor() override;

private:


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.cc View File

@@ -101,24 +101,5 @@ std::vector<StrategyPtr> BiasAddInfo::GenerateOpStrategies(int64_t stage_id) {
}
return sp_vector;
}

Status BiasAddInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status BiasAddInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/bias_add_info.h View File

@@ -37,8 +37,6 @@ class BiasAddInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BiasAddCost>()) {}
~BiasAddInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.cc View File

@@ -176,24 +176,5 @@ ReplaceGraphPtr BroadcastToInfo::replace_graph(const CNodePtr &cnode) {
}
return replace_graph_;
}

Status BroadcastToInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status BroadcastToInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/broadcast_to_info.h View File

@@ -39,8 +39,6 @@ class BroadcastToInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BroadcastToCost>()) {}
~BroadcastToInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.cc View File

@@ -182,24 +182,5 @@ std::vector<StrategyPtr> ConcatInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status ConcatInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/concat_info.h View File

@@ -36,8 +36,6 @@ class ConcatInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>()) {}
~ConcatInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc View File

@@ -837,25 +837,6 @@ std::vector<StrategyPtr> Conv2DInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}

Status Conv2DInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status Conv2DInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

Status Conv2DBackpropInputInfo::GetOutShape() {
if (input_value_.size() != 3) {
MS_LOG(ERROR) << name_ << ": The size of input value must be 3, but got " << input_value_.size();


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h View File

@@ -37,8 +37,6 @@ class Conv2DInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<BatchParallelCost>()) {}
~Conv2DInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc View File

@@ -117,26 +117,6 @@ std::shared_ptr<Strategys> DropoutDoMaskInfo::GenerateBatchStrategies() {
return std::make_shared<Strategys>(strategy_v);
}

Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}

size_t GetNonMonadInputSize(const CNodePtr &cnode) {
size_t cnode_non_monad_size = cnode->size();
for (auto &input : cnode->inputs()) {


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.h View File

@@ -36,10 +36,8 @@ class DropoutDoMaskInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DropOutDoMaskCost>()) {}
~DropoutDoMaskInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
std::vector<Operator> GetDropoutGenMaskReplaceOp();
void ReplaceNodeInputOrAttrs() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/dsd_matmul_info.cc View File

@@ -157,25 +157,6 @@ Status DSDMatmulInfo::GetAttrs() {
return SUCCESS;
}

Status DSDMatmulInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status DSDMatmulInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

std::vector<StrategyPtr> DSDMatmulInfo::GenerateOpStrategies(int64_t stage_id) {
// to generate the first input's strategy
Shape input0_split = {1, 1, 0, 0, 0, 0, 0};


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/dsd_matmul_info.h View File

@@ -37,8 +37,6 @@ class DSDMatmulInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<DSDMatmulCost>()) {}
~DSDMatmulInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc View File

@@ -215,24 +215,5 @@ std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {
}
return sp_vector;
}

Status GatherDInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status GatherDInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h View File

@@ -36,8 +36,6 @@ class GatherDInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherDCost>()) {}
~GatherDInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.cc View File

@@ -149,24 +149,5 @@ std::vector<StrategyPtr> GatherNdInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status GatherNdInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status GatherNdInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/gathernd_info.h View File

@@ -36,8 +36,6 @@ class GatherNdInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherNdCost>()) {}
~GatherNdInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 6
- 29
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.cc View File

@@ -103,21 +103,6 @@ Status GetNextInfo::InferDevMatrixShape() {
return SUCCESS;
}

Status GetNextInfo::Init(const StrategyPtr &strategy) {
repeated_num_in_dev_matrix_right_ = false;
if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
repeated_num_in_dev_matrix_right_ = true;
}
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed";
return FAILED;
}
InferReplaceOps(strategy);

MS_LOG(INFO) << name_ << " : Init success";
return SUCCESS;
}

Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) {
Strategys stras = strategy->GetInputDim();
for (Dimensions stra : stras) {
@@ -200,6 +185,11 @@ Status GetNextInfo::GetAttrOutPutNum() {
}

Status GetNextInfo::GetAttrs() {
repeated_num_in_dev_matrix_right_ = false;
if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
repeated_num_in_dev_matrix_right_ = true;
}

if (GetAttrTypes() == FAILED || GetAttrShapes() == FAILED || GetAttrOutPutNum() == FAILED) {
return FAILED;
}
@@ -210,7 +200,7 @@ Status GetNextInfo::GetAttrs() {
return SUCCESS;
}

void GetNextInfo::InferReplaceOps(const StrategyPtr &) {
void GetNextInfo::InferReplaceOps() {
Shapes out_shapes;
(void)std::transform(outputs_tensor_info_.begin(), outputs_tensor_info_.end(), std::back_inserter(out_shapes),
[](auto tensor_info) { return tensor_info.slice_shape(); });
@@ -225,19 +215,6 @@ void GetNextInfo::InferReplaceOps(const StrategyPtr &) {
replace_op_ = {std::make_pair(GET_NEXT, args)};
}

Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) {
repeated_num_in_dev_matrix_right_ = false;
if (ParallelContext::GetInstance()->dataset_repeat_dim_right()) {
repeated_num_in_dev_matrix_right_ = true;
}
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> GetNextInfo::GenerateOpStrategies(int64_t stage_id) {


+ 1
- 3
mindspore/ccsrc/frontend/parallel/ops_info/get_next_info.h View File

@@ -35,9 +35,7 @@ class GetNextInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<GetNextCost>()) {}
~GetNextInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;

protected:
@@ -49,7 +47,7 @@ class GetNextInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferForwardCommunication() override { return SUCCESS; }
void InferReplaceOps(const StrategyPtr &strategy);
void InferReplaceOps() override;
Status GetAttrTypes();
Status GetAttrShapes();
Status GetAttrOutPutNum();


+ 5
- 19
mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.cc View File

@@ -30,6 +30,11 @@ namespace parallel {
// if the begin-norm-axis is 3, the shape of second output is: [A, B, C, 1]
// the shape of third output is the same as the shape of second output
Status LayerNormInfo::GetAttrs() {
if (InitShapes() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init Shape failed";
return FAILED;
}

auto iter = attrs_.find(BEGIN_NORM_AXIS);
if (iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis";
@@ -239,24 +244,5 @@ Status LayerNormInfo::InitShapes() {
beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX];
return SUCCESS;
}

Status LayerNormInfo::Init(const StrategyPtr &strategy) {
if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success";
return SUCCESS;
}

Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) {
if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/layer_norm_info.h View File

@@ -44,8 +44,6 @@ class LayerNormInfo : public OperatorInfo {
begin_norm_axis_(0) {}
~LayerNormInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;



+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.cc View File

@@ -105,26 +105,6 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() {
return SUCCESS;
}

Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

void SoftmaxCrossEntropyWithLogitsInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
split_flag_list_[i] = true;


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/loss_info.h View File

@@ -38,8 +38,6 @@ class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<SoftmaxCrossEntropyWithLogitsCost>()) {}
~SoftmaxCrossEntropyWithLogitsInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.cc View File

@@ -265,25 +265,6 @@ ReplaceGraphPtr MatmulDDSInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_;
}

Status MatmulDDSInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status MatmulDDSInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

std::vector<StrategyPtr> MatmulDDSInfo::GenerateOpStrategies(int64_t stage_id) {
// to generate the first input's strategy
Shape input0_split = {1, 1, 0, 0};


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/matmul_dds_info.h View File

@@ -37,8 +37,6 @@ class MatmulDDSInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<MatmulDDSCost>()) {}
~MatmulDDSInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 0
- 10
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.cc View File

@@ -375,16 +375,6 @@ Status MatMulBase::Init(const StrategyPtr &strategy) {
return SUCCESS;
}

Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) {
if (input->size() < 2) {
MS_LOG(ERROR) << name_ << " : The size of inputs small than 2.";


+ 0
- 1
mindspore/ccsrc/frontend/parallel/ops_info/matmul_info.h View File

@@ -38,7 +38,6 @@ class MatMulBase : public OperatorInfo {
~MatMulBase() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

// Generate all strategies and the corresponding cost for this MatMul operator
Status GenerateStrategies(int64_t stage_id) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.cc View File

@@ -190,24 +190,5 @@ std::vector<StrategyPtr> MaxPoolInfo::GenerateOpStrategies(int64_t stage_id) {
}
return sp_vector;
}

Status MaxPoolInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status MaxPoolInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/maxpool_info.h View File

@@ -36,8 +36,6 @@ class MaxPoolInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
~MaxPoolInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;



+ 0
- 23
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.cc View File

@@ -207,29 +207,6 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_;
}

Status OneHotInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
Status status = ComputeReplaceGraph(cnode_);
if (status != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return status;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

std::vector<StrategyPtr> OneHotInfo::GenerateOpStrategies(int64_t stage_id) {
Shapes splittable_inputs = {{1, 1}, {}, {}};
std::vector<StrategyPtr> sp_vector;


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/onehot_info.h View File

@@ -35,8 +35,6 @@ class OneHotInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<OneHotCost>()) {}
~OneHotInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 21
- 76
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc View File

@@ -710,6 +710,26 @@ Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Str
return SUCCESS;
}

Status OperatorInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status OperatorInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

// method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1
Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) {
if (strategy == nullptr) {
@@ -770,54 +790,6 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strat
return SUCCESS;
}

// method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape
Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) {
if (strategy == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null.";
return FAILED;
}

if (InferAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferAttrs failed.";
return FAILED;
}

// must be after InferAttrs()
if (CheckStrategy(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": CheckStrategy failed.";
return FAILED;
}

// need to clear queues before Init(),
// because Init() may be called multiple times by cost model
ResetQueueMember();

strategy_ = strategy;

if (InferDevMatrixShape() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferDevMatrixShape failed.";
return FAILED;
}

// must be after InferDevMatrixShape
if (InferRepeatedCalcInfo() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferRepeatedCalcInfo failed.";
return FAILED;
}

if (InferTensorMap() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferTensorMap failed.";
return FAILED;
}

if (InferTensorInfo() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferTensorInfo failed.";
return FAILED;
}

return SUCCESS;
}

Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) {
if (strategy == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null.";
@@ -843,34 +815,7 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) {
return FAILED;
}

return SUCCESS;
}

Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) {
if (strategy == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null.";
return FAILED;
}

if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) {
return FAILED;
}

if (InferForwardCommunication() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferForwardCommunication failed.";
return FAILED;
}

if (InferMirrorOps() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferMirrorOps failed.";
return FAILED;
}

if (InferVirtualDivOps() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": InferVirtualDivOps failed.";
return FAILED;
}

InferReplaceOps();
return SUCCESS;
}



+ 3
- 4
mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h View File

@@ -81,8 +81,8 @@ class OperatorInfo {
// If output is tuple, outputs_type.size() is greater than 1.
Status set_outputs_type(const std::vector<TypePtr> &outputs_type);
const std::vector<TypePtr> &outputs_type() const { return outputs_type_; }
virtual Status Init(const StrategyPtr &strategy) = 0;
virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts
virtual Status Init(const StrategyPtr &strategy);
virtual Status InitForCostModel(const StrategyPtr &strategy); // only init the necessary parts

// Given the stage_id (which indicates the number of devices),
// generate all strategies for this operator
@@ -198,6 +198,7 @@ class OperatorInfo {
virtual Status InferDevMatrixShape() = 0;
virtual Status InferMirrorOps();
virtual Status InferTensorInfo();
virtual void InferReplaceOps() {}
Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape);
void SetRepeatedCalcDevMatrix();
void ResetTensorMapIfRepeatedCalc();
@@ -205,9 +206,7 @@ class OperatorInfo {
Status InferAttrs();
void ResetQueueMember();
Status InitWithAutoRepeatCalc(const StrategyPtr &strategy);
Status InitWithManualRepeatCalc(const StrategyPtr &strategy);
Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy);
Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy);
Status InferRepeatedCalcInfo();
Status InferVirtualDivOps();



+ 1
- 1
mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h View File

@@ -88,7 +88,7 @@ constexpr double COST_FACTOR = 2.0;
constexpr char AUTO_PARALLEL_RUN_ONCE_ONLY[] = "auto_parallel_run_once_only";
constexpr char SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY[] = "semi_auto_parallel_run_once_only";
constexpr char CHECK_SET_STRATEGY_VALID_ONCE_ONLY[] = "check_set_strategy_valid_once_only";
constexpr char STRATEGY[] = "strategy";
constexpr char IN_STRATEGY[] = "in_strategy";
constexpr char STAGE_ATTR[] = "stage";
constexpr char GEN_STRATEGY[] = "gen_strategy";
constexpr char REDUCE_OP_SUM[] = "sum";


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/pack_info.cc View File

@@ -160,24 +160,5 @@ std::vector<StrategyPtr> StackInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status StackInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status StackInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/pack_info.h View File

@@ -36,8 +36,6 @@ class StackInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<PackCost>()) {}
~StackInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.cc View File

@@ -92,25 +92,6 @@ Status PReLUInfo::GetAttrs() {
return SUCCESS;
}

Status PReLUInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

std::vector<StrategyPtr> PReLUInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split;
input0_split.emplace_back(1);


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/prelu_info.h View File

@@ -37,8 +37,6 @@ class PReLUInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<PReLUCost>()) {}
~PReLUInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/range_info.cc View File

@@ -83,26 +83,6 @@ Status RangeInfo::InferTensorMap() {
return SUCCESS;
}

Status RangeInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init success";
return SUCCESS;
}

Status RangeInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}

std::vector<StrategyPtr> RangeInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};


+ 0
- 3
mindspore/ccsrc/frontend/parallel/ops_info/range_info.h View File

@@ -42,9 +42,6 @@ class RangeInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<RangeCost>()) {}
~RangeInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;



+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc View File

@@ -390,25 +390,6 @@ std::vector<StrategyPtr> ReduceMethod::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}

Status ReduceMethod::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}

return SUCCESS;
}

Status ReduceMethod::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}

std::vector<int64_t> ArgMaxWithValueInfo::reduce_dim() {
std::vector<int64_t> dim_list;
auto iter = attrs_.find(AXIS);


+ 0
- 3
mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.h View File

@@ -37,9 +37,6 @@ class ReduceMethod : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
~ReduceMethod() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;



+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.cc View File

@@ -120,25 +120,5 @@ Status ReLUV2Info::InferAsLossDivisor() {
<< as_loss_divisor_;
return SUCCESS;
}

Status ReLUV2Info::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init success.";
return SUCCESS;
}

Status ReLUV2Info::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/reluv2_info.h View File

@@ -40,8 +40,6 @@ class ReLUV2Info : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ReLUV2Cost>()) {}
~ReLUV2Info() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;



+ 0
- 10
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.cc View File

@@ -365,16 +365,6 @@ Status ReshapeInfo::Init(const StrategyPtr &strategy) {
return SUCCESS;
}

Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}


+ 0
- 1
mindspore/ccsrc/frontend/parallel/ops_info/reshape_info.h View File

@@ -61,7 +61,6 @@ class ReshapeInfo : public OperatorInfo {
Status GenetateStrategyCosts(const std::vector<std::shared_ptr<StrategyWithCost>> &pre_stra_costs,
const std::vector<std::shared_ptr<StrategyWithCost>> &next_stra_costs, int64_t out_index,
int64_t in_index, bool is_prev_param, bool is_next_reshape);
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int64_t stage_id) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.cc View File

@@ -118,25 +118,6 @@ std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_
return sp_vector;
}

Status ResizeBilinearInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status ResizeBilinearInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0));
prim->set_attr(SIZE, MakeValue(slice_size_));


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/resizebilinear_info.h View File

@@ -36,8 +36,6 @@ class ResizeBilinearInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ResizeBilinearCost>()) {}
~ResizeBilinearInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;



+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.cc View File

@@ -176,24 +176,5 @@ std::vector<StrategyPtr> ScatterUpdateInfo::GenerateOpStrategies(int64_t stage_i

return sp_vector;
}

Status ScatterUpdateInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status ScatterUpdateInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/scatter_update_info.h View File

@@ -36,8 +36,6 @@ class ScatterUpdateInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ScatterUpdateCost>()) {}
~ScatterUpdateInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/select_info.cc View File

@@ -125,24 +125,5 @@ std::vector<StrategyPtr> SelectInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status SelectInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status SelectInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/select_info.h View File

@@ -36,8 +36,6 @@ class SelectInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SelectCost>()) {}
~SelectInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/slice_info.cc View File

@@ -174,25 +174,6 @@ std::vector<StrategyPtr> SliceInfo::GenerateOpStrategies(int64_t stage_id) {
return sp_vector;
}

Status SliceInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status SliceInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) {
auto input_strategy = strategy_->GetInputDim().at(0);
if (std::any_of(input_strategy.begin(), input_strategy.end(), [](const int64_t &shard) { return shard > 1; })) {


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/slice_info.h View File

@@ -38,8 +38,6 @@ class SliceInfo : public OperatorInfo {
slice_axis_(-1) {}
~SliceInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/split_info.cc View File

@@ -196,24 +196,5 @@ Status SplitInfo::InferAsLossDivisor() {
<< as_loss_divisor_;
return SUCCESS;
}

Status SplitInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status SplitInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/split_info.h View File

@@ -35,8 +35,6 @@ class SplitInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<SplitCost>()) {}
~SplitInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.cc View File

@@ -234,24 +234,5 @@ std::vector<StrategyPtr> StridedSliceInfo::GenerateOpStrategies(int64_t stage_id

return sp_vector;
}

Status StridedSliceInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status StridedSliceInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/strided_slice_info.h View File

@@ -37,8 +37,6 @@ class StridedSliceInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<StridedSliceCost>()) {}
~StridedSliceInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;


+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.cc View File

@@ -306,26 +306,6 @@ Status TensorDotInfo::InferTensorMap() {
return SUCCESS;
}

Status TensorDotInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init success";
return SUCCESS;
}

Status TensorDotInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}

std::shared_ptr<Strategys> TensorDotInfo::GenerateBatchStrategies() {
if (GetAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Get attr failed";


+ 0
- 3
mindspore/ccsrc/frontend/parallel/ops_info/tensordot_info.h View File

@@ -44,9 +44,6 @@ class TensorDotInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TensorDotCost>()) {}
~TensorDotInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size,


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.cc View File

@@ -208,24 +208,5 @@ std::vector<StrategyPtr> TileInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status TileInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status TileInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/tile_info.h View File

@@ -37,8 +37,6 @@ class TileInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TileCost>()) {}
~TileInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;


+ 0
- 20
mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.cc View File

@@ -48,26 +48,6 @@ Status TmpIdentityInfo::InferTensorMap() {
return SUCCESS;
}

Status TmpIdentityInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status TmpIdentityInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

Status TmpIdentityInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> TmpIdentityInfo::GenerateOpStrategies(int64_t stage_id) {


+ 0
- 3
mindspore/ccsrc/frontend/parallel/ops_info/tmp_identity_info.h View File

@@ -37,9 +37,6 @@ class TmpIdentityInfo : public OperatorInfo {
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TmpIdentityCost>()) {}
~TmpIdentityInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;



+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/topk_info.cc View File

@@ -164,24 +164,5 @@ std::vector<StrategyPtr> TopKInfo::GenerateOpStrategies(int64_t stage_id) {

return sp_vector;
}

Status TopKInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status TopKInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore

+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/topk_info.h View File

@@ -37,8 +37,6 @@ class TopKInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<TopKCost>()) {}
~TopKInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;



+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.cc View File

@@ -111,25 +111,6 @@ Status TransposeInfo::InferTensorMap() {
// compute axis_v_ during this method
Status TransposeInfo::GetAttrs() { return ComputeAxis(); }

Status TransposeInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) {
return SetCostUnderStrategyBase(strategy);
}


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/transpose_info.h View File

@@ -37,8 +37,6 @@ class TransposeInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<TransposeCost>()) {}
~TransposeInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;



+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/uniform_candidate_sampler_info.cc View File

@@ -181,25 +181,6 @@ std::shared_ptr<Strategys> UniformCandidateSamplerInfo::GenerateBatchStrategies(
return std::make_shared<Strategys>(strategy_v);
}

Status UniformCandidateSamplerInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status UniformCandidateSamplerInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

ReplaceGraphPtr UniformCandidateSamplerInfo::replace_graph(const CNodePtr &cnode) {
auto input_strategy = strategy_->GetInputDim().at(0);
// Only when the axis-1 is sharded, we need to modify the attribute


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/uniform_candidate_sampler_info.h View File

@@ -43,8 +43,6 @@ class UniformCandidateSamplerInfo : public OperatorInfo {
remove_accidental_hits_(false) {}
~UniformCandidateSamplerInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
std::shared_ptr<Strategys> GenerateBatchStrategies() override;
Status SetCostUnderStrategy(const StrategyPtr &) override;


+ 0
- 19
mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.cc View File

@@ -102,25 +102,6 @@ std::vector<StrategyPtr> UniformRealInfo::GenerateOpStrategies(int64_t stage_id)
return sp_vector;
}

Status UniformRealInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status UniformRealInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}

MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

void UniformRealInfo::UpdateShape(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto input_node = cnode->input(1)->cast<ValueNodePtr>();


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/uniform_real_info.h View File

@@ -36,8 +36,6 @@ class UniformRealInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<MaxPoolCost>()) {}
~UniformRealInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void UpdateShape(const CNodePtr &cnode);


+ 0
- 18
mindspore/ccsrc/frontend/parallel/ops_info/unique_info.cc View File

@@ -58,15 +58,6 @@ Status UniqueInfo::InferDevMatrixShape() {
return SUCCESS;
}

Status UniqueInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init failed";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init success";
return SUCCESS;
}

Status UniqueInfo::CheckStrategy(const StrategyPtr &strategy) {
Strategys stras = strategy->GetInputDim();
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
@@ -96,15 +87,6 @@ Status UniqueInfo::GetAttrs() {
return SUCCESS;
}

Status UniqueInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << " : Init for cost model success.";
return SUCCESS;
}

Status UniqueInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }

std::vector<StrategyPtr> UniqueInfo::GenerateOpStrategies(int64_t stage_id) {


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/unique_info.h View File

@@ -35,9 +35,7 @@ class UniqueInfo : public OperatorInfo {
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<UniqueCost>()) {}
~UniqueInfo() override = default;

Status Init(const StrategyPtr &strategy) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;



+ 0
- 18
mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc View File

@@ -163,24 +163,6 @@ Status UnsortedSegmentOpInfo::InferTensorMap() {
return SUCCESS;
}

Status UnsortedSegmentOpInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}

Status UnsortedSegmentOpInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}

// Set the default strategy
std::vector<StrategyPtr> UnsortedSegmentOpInfo::GenerateOpStrategies(int64_t stage_id) {
Shape input0_split(inputs_shape_[0].size(), 1);


+ 0
- 2
mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h View File

@@ -44,8 +44,6 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
const PrimitiveAttrs &attrs, OperatorCostPtr cost)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost) {}
~UnsortedSegmentOpInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;

std::vector<StrategyPtr> GenerateOpStrategies(int64_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;


+ 1
- 1
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -335,7 +335,7 @@ OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode, int tup
if (!StrategyFound(attrs)) {
strategy = GenerateBatchParallelStrategy(op_info, prim);
} else {
strategy = ExtractStrategy(attrs[STRATEGY]);
strategy = ExtractStrategy(attrs[IN_STRATEGY]);
}
MS_EXCEPTION_IF_NULL(strategy);
if (op_info->Init(strategy) == FAILED) {


+ 1
- 1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -253,7 +253,7 @@ void SetStrategyToOperator(const OperatorInfoPtr &operator_info, const Primitive
// In this case, the configured strategy should be extracted to help setting cost
StrategyPtr strategyPtr;
if (StrategyFound(attrs)) {
strategyPtr = parallel::ExtractStrategy(attrs[STRATEGY]);
strategyPtr = parallel::ExtractStrategy(attrs[IN_STRATEGY]);
} else {
strategyPtr = (*stra_map)[strategy_key_name];
}


+ 7
- 7
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -484,7 +484,7 @@ void Redistribution(const std::pair<AnfNodePtr, int64_t> &node_pair, const Opera
}

bool StrategyFound(const std::unordered_map<std::string, ValuePtr> &attrs) {
auto iter = attrs.find(STRATEGY);
auto iter = attrs.find(IN_STRATEGY);
return !((iter == attrs.end()) || (iter->second->type_name() == NONE));
}

@@ -1760,7 +1760,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
(void)std::transform(dataset_strategy.begin(), dataset_strategy.end(), std::back_inserter(elements),
[](auto input_stra) { return MakeValue(input_stra); });
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
attrs_temp[STRATEGY] = strategy;
attrs_temp[IN_STRATEGY] = strategy;
(void)prim->SetAttrs(attrs_temp);
if (prim->HasAttr(REPEAT_DIM_DIRECT) && GetValue<std::string>(prim->GetAttr(REPEAT_DIM_DIRECT)) == RIGHT) {
ParallelContext::GetInstance()->set_dataset_repeat_dim_right(true);
@@ -1798,7 +1798,7 @@ void SetVirtualDatasetStrategy(const CNodePtr &node) {
elements.push_back(MakeValue(input_strategy));
}
ValueTuplePtr strategy = std::make_shared<ValueTuple>(elements);
attrs_temp[STRATEGY] = strategy;
attrs_temp[IN_STRATEGY] = strategy;
(void)prim->SetAttrs(attrs_temp);
}
}
@@ -1954,14 +1954,14 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(STRATEGY)) {
if ((!StrategyFound(attrs) && !load_strategy_from_ckpt) && !cnode->HasPrimalAttr(IN_STRATEGY)) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
strategyPtr = GenerateBatchParallelStrategy(operator_, prim);
} else if (cnode->HasPrimalAttr(STRATEGY)) {
strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(STRATEGY));
} else if (cnode->HasPrimalAttr(IN_STRATEGY)) {
strategyPtr = ExtractStrategy(cnode->GetPrimalAttr(IN_STRATEGY));
} else if (StrategyFound(attrs)) {
strategyPtr = ExtractStrategy(attrs[STRATEGY]);
strategyPtr = ExtractStrategy(attrs[IN_STRATEGY]);
} else {
strategyPtr = stra_map[strategy_key_name];
}


+ 26
- 13
mindspore/ops/primitive.py View File

@@ -151,7 +151,7 @@ class Primitive(Primitive_):
self.add_prim_attr("stage", stage)
return self

def shard(self, strategy):
def shard(self, in_strategy=None, out_strategy=None):
"""
Add strategies to primitive attribute.

@@ -160,24 +160,37 @@ class Primitive(Primitive_):
In other parallel modes, strategies set here will be ignored.

Args:
strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
in_strategy (tuple): Describe the split strategy of operator input.
out_strategy (tuple): Describe the split strategy of operator output,
it is only for certain operators, such as MatMul.
Examples:
>>> from mindspore import ops
>>> add = ops.Add()
>>> print(add.shard(((1, 1), (1, 1))))
Prim[Add]<strategy=((1, 1), (1, 1))>
Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None>
"""
mode = context.get_auto_parallel_context("parallel_mode")
if strategy is not None:
if not isinstance(strategy, tuple):
raise TypeError(f'strategy must be tuple type, but got:{type(strategy)}')
for ele in strategy:
if not isinstance(ele, tuple):
raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
if not _is_in_auto_parallel_mode() and strategy:
logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
self.add_prim_attr("strategy", strategy)
if in_strategy is not None:
if not isinstance(in_strategy, tuple):
raise TypeError(f'in_strategy must be tuple type, but got:{type(in_strategy)}')
for in_ele in in_strategy:
if not isinstance(in_ele, tuple):
raise TypeError(f'The element of strategy must be tuple type, but got:{type(in_ele)}')
if out_strategy is not None:
if not isinstance(out_strategy, tuple):
raise TypeError(f'out strategy must be tuple type, but got:{type(out_strategy)}')
for out_ele in out_strategy:
if not isinstance(out_ele, tuple):
raise TypeError(f'The element of strategy must be tuple type, but got:{type(out_ele)}')
if not _is_in_auto_parallel_mode():
if in_strategy is not None:
logger.warning(f"The in_strategy: {in_strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
if out_strategy is not None:
logger.warning(f"The out_strategy: {out_strategy} of {self.name} is not valid in {mode}. "
f"Please use semi auto or auto parallel mode.")
self.add_prim_attr("in_strategy", in_strategy)
self.add_prim_attr("out_strategy", out_strategy)
return self

def set_prim_instance_name(self, instance_name):


+ 3
- 3
tests/ut/cpp/parallel/ops_info/onehot_info_test.cc View File

@@ -92,7 +92,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape3) {
StrategyPtr strategy = NewStrategy(0, inputs);

Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info->dev_matrix_shape();

Shape expect = {4, 2};
@@ -148,7 +148,7 @@ TEST_F(TestOneHotInfo, InferSliceShape2) {
StrategyPtr strategy = NewStrategy(0, str);

Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
std::vector<TensorInfo> inputs = onehot_info->inputs_tensor_info();
std::vector<TensorInfo> outputs = onehot_info->outputs_tensor_info();

@@ -170,7 +170,7 @@ TEST_F(TestOneHotInfo, InferSliceShape3) {
StrategyPtr strategy = NewStrategy(0, str);

Status status = onehot_info->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
std::vector<TensorInfo> inputs = onehot_info->inputs_tensor_info();
std::vector<TensorInfo> outputs = onehot_info->outputs_tensor_info();



+ 3
- 3
tests/ut/cpp/parallel/ops_info/onehot_info_test_axis_0.cc View File

@@ -92,7 +92,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape3) {
StrategyPtr strategy = NewStrategy(0, inputs);

Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();

Shape expect = {4, 2};
@@ -148,7 +148,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape2) {
StrategyPtr strategy = NewStrategy(0, str);

Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
std::vector<TensorInfo> inputs = onehot_info2->inputs_tensor_info();
std::vector<TensorInfo> outputs = onehot_info2->outputs_tensor_info();

@@ -170,7 +170,7 @@ TEST_F(TestOneHotInfo2, InferSliceShape3) {
StrategyPtr strategy = NewStrategy(0, str);

Status status = onehot_info2->Init(strategy);
ASSERT_EQ(status, FAILED);
ASSERT_EQ(status, SUCCESS);
std::vector<TensorInfo> inputs = onehot_info2->inputs_tensor_info();
std::vector<TensorInfo> outputs = onehot_info2->outputs_tensor_info();



+ 7
- 7
tests/ut/cpp/parallel/step_parallel_test.cc View File

@@ -154,7 +154,7 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
prim1->AddAttr("transpose_a", transpose_a);
prim1->AddAttr("transpose_b", transpose_b);
prim1->AddAttr("instance_name", MakeValue("matmul1"));
prim1->AddAttr("strategy", var);
prim1->AddAttr("in_strategy", var);
inputs.clear();
Dimensions v3 = {2, 2};
Dimensions v4 = {2, 4};
@@ -176,16 +176,16 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
prim2->AddAttr("transpose_a", transpose_a);
prim2->AddAttr("transpose_b", transpose_b);
prim2->AddAttr("instance_name", MakeValue("matmul2"));
prim2->AddAttr("strategy", var2);
prim2->AddAttr("in_strategy", var2);
switch (condition) {
case 1: {
prim1->set_attr("strategy", MakeValue(static_cast<int64_t>(0)));
prim1->set_attr("in_strategy", MakeValue(static_cast<int64_t>(0)));
break;
}
case 2: {
std::vector<ValuePtr> elements_t = {MakeValue(static_cast<int64_t>(0))};
ValueTuplePtr var_t = std::make_shared<ValueTuple>(elements_t);
prim1->set_attr("strategy", var_t);
prim1->set_attr("in_strategy", var_t);
break;
}
case 3: {
@@ -193,7 +193,7 @@ FuncGraphManagerPtr Make_Manager(int64_t condition = 0) {
Dimensions vt2 = {2, 4};
std::vector<ValuePtr> elements_t2 = {MakeValue(vt1), MakeValue(vt2)};
ValueTuplePtr var_t2 = std::make_shared<ValueTuple>(elements_t2);
prim1->set_attr("strategy", var_t2);
prim1->set_attr("in_strategy", var_t2);
break;
}
}
@@ -226,9 +226,9 @@ TEST_F(TestStepParallel, ExtractStrategy) {
ValuePtr val2 = MakeValue(v2);
std::vector<ValuePtr> elements = {val1, val2};
ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(elements);
attrs["strategy"] = strategy_tuple;
attrs["in_strategy"] = strategy_tuple;
Strategys strategy_expect = {v1, v2};
StrategyPtr strategy = ExtractStrategy(attrs["strategy"]);
StrategyPtr strategy = ExtractStrategy(attrs["in_strategy"]);
Strategys strategy_test = strategy->GetInputDim();

ASSERT_EQ(strategy_expect, strategy_test);


+ 2
- 2
tests/ut/python/parallel/test_add_relu_redistribution.py View File

@@ -29,8 +29,8 @@ grad_all = C.GradOperation(get_all=True)
class AddRelu(nn.Cell):
def __init__(self, strategy0=None, strategy1=None):
super(AddRelu, self).__init__()
self.add = P.Add().shard(strategy=strategy0)
self.relu = P.ReLU().shard(strategy=strategy1)
self.add = P.Add().shard(strategy0)
self.relu = P.ReLU().shard(strategy1)

def construct(self, x, z):
out = self.add(x, z)


+ 31
- 31
tests/ut/python/parallel/test_one_hot_net.py View File

@@ -85,15 +85,15 @@ class SemiAutoOneHotNet(Cell):
self.d = args.d
self.e = args.e
self.cast = P.Cast()
self.cast.shard(strategy=strategy.twod_strategy)
self.cast.shard(strategy.twod_strategy)
self.cast1 = P.Cast()
self.cast1.shard(strategy=strategy.twod_strategy)
self.cast1.shard(strategy.twod_strategy)
self.cast2 = P.Cast()
self.cast2.shard(strategy=strategy.twod_strategy)
self.cast2.shard(strategy.twod_strategy)
self.cast3 = P.Cast()
self.cast3.shard(strategy=strategy.scalar_strategy)
self.cast3.shard(strategy.scalar_strategy)
self.cast4 = P.Cast()
self.cast4.shard(strategy=strategy.scalar_strategy)
self.cast4.shard(strategy.scalar_strategy)
self.a_const = Tensor(self.a, dtype=mstype.float32)
self.b_const = Tensor(self.b, dtype=mstype.float32)
self.c_const = Tensor(self.c, dtype=mstype.float32)
@@ -102,64 +102,64 @@ class SemiAutoOneHotNet(Cell):
self.m_const_zero = Tensor(0, dtype=mstype.float32)
self.a_const_one = Tensor(1, dtype=mstype.float32)
self.onehot = P.OneHot()
self.onehot.shard(strategy=strategy.onehot_strategy)
self.onehot.shard(strategy.onehot_strategy)
self.exp = P.Exp()
self.exp.shard(strategy=strategy.twod_strategy)
self.exp.shard(strategy.twod_strategy)
self.exp2 = P.Exp()
self.exp2.shard(strategy=strategy.twod_strategy)
self.exp2.shard(strategy.twod_strategy)
self.exp3 = P.Exp()
self.exp3.shard(strategy=strategy.twod_strategy)
self.exp3.shard(strategy.twod_strategy)
self.mul_const = P.Mul()
self.mul_const.shard(strategy=strategy.scalar_twod_strategy)
self.mul_const.shard(strategy.scalar_twod_strategy)
self.mul_const2 = P.Add()
self.mul_const2.shard(strategy=strategy.scalar_twod_strategy)
self.mul_const2.shard(strategy.scalar_twod_strategy)
self.mul_const3 = P.Sub()
self.mul_const3.shard(strategy=strategy.twod_scalar_strategy)
self.mul_const3.shard(strategy.twod_scalar_strategy)
self.mul_const4 = P.Sub()
self.mul_const4.shard(strategy=strategy.scalar_twod_strategy)
self.mul_const4.shard(strategy.scalar_twod_strategy)
self.mul_const5 = P.Mul()
self.mul_const5.shard(strategy=strategy.twod_scalar_strategy)
self.mul_const5.shard(strategy.twod_scalar_strategy)
self.mul = P.Mul()
self.mul.shard(strategy=strategy.twod_twod_strategy)
self.mul.shard(strategy.twod_twod_strategy)
self.mul2 = P.Mul()
self.mul2.shard(strategy=strategy.twod_twod_strategy)
self.mul2.shard(strategy.twod_twod_strategy)
self.mul3 = P.Add()
self.mul3.shard(strategy=strategy.twod_twod_strategy)
self.mul3.shard(strategy.twod_twod_strategy)
self.mul4 = P.Sub()
self.mul4.shard(strategy=strategy.twod_twodbc_strategy)
self.mul4.shard(strategy.twod_twodbc_strategy)
self.mul5 = P.RealDiv()
self.mul5.shard(strategy=strategy.twod_twodbc_strategy)
self.mul5.shard(strategy.twod_twodbc_strategy)
self.mul6 = P.Mul()
self.mul6.shard(strategy=strategy.twod_twod_strategy)
self.mul6.shard(strategy.twod_twod_strategy)
self.mul7 = P.Mul()
self.mul7.shard(strategy=strategy.twod_scalar_strategy)
self.mul7.shard(strategy.twod_scalar_strategy)
self.mul8 = P.RealDiv()
self.mul8.shard(strategy=strategy.scalar_scalar_strategy)
self.mul8.shard(strategy.scalar_scalar_strategy)
self.mul9 = P.Add()
self.mul9.shard(strategy=strategy.twod_scalar_strategy)
self.mul9.shard(strategy.twod_scalar_strategy)

self.reduce_max = P.ReduceMax(keep_dims=True)
self.reduce_max.shard(strategy=strategy.twod_strategy)
self.reduce_max.shard(strategy.twod_strategy)

self.reduce_sum = P.ReduceSum(keep_dims=False)
self.reduce_sum.shard(strategy=strategy.twod_strategy)
self.reduce_sum.shard(strategy.twod_strategy)
self.reduce_sum_2 = P.ReduceSum(keep_dims=False)
self.reduce_sum_2.shard(strategy=strategy.twod_strategy)
self.reduce_sum_2.shard(strategy.twod_strategy)
self.reduce_sum_3 = P.ReduceSum(keep_dims=False)
self.reduce_sum_3.shard(strategy=strategy.oned_strategy)
self.reduce_sum_3.shard(strategy.oned_strategy)

self.reshape = P.Reshape()
self.log = P.Log()
self.log.shard(strategy=strategy.twod_strategy)
self.log.shard(strategy.twod_strategy)

self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.normalize = P.L2Normalize(axis=1)
self.normalize.shard(strategy=strategy.twod_strategy_m)
self.normalize.shard(strategy.twod_strategy_m)
self.normalize2 = P.L2Normalize(axis=1)
self.normalize2.shard(strategy=strategy.twod_strategy_m)
self.normalize2.shard(strategy.twod_strategy_m)
self.fc = P.MatMul(transpose_b=True)
self.fc.shard(strategy=strategy.twodbc_twod_strategy)
self.fc.shard(strategy.twodbc_twod_strategy)
weight_shape = [args.num_classes, args.emb_size]
weight_np = np.zeros(weight_shape, np.float32)
self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight')


+ 4
- 4
tests/ut/python/parallel/test_parallel_optimizer.py View File

@@ -55,8 +55,8 @@ class Net2(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2):
super(Net2, self).__init__()
self.fc1 = P.MatMul().shard(strategy=strategy1)
self.fc2 = P.MatMul().shard(strategy=strategy2)
self.fc1 = P.MatMul().shard(strategy1)
self.fc2 = P.MatMul().shard(strategy2)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2")

@@ -70,8 +70,8 @@ class Net3(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2):
super(Net3, self).__init__()
self.fc1 = P.MatMul().shard(strategy=strategy1)
self.fc2 = P.MatMul().shard(strategy=strategy2)
self.fc1 = P.MatMul().shard(strategy1)
self.fc2 = P.MatMul().shard(strategy2)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 16]).astype(np.float32)), name="weight2", parallel_optimizer=False)



+ 1
- 1
tests/ut/python/parallel/test_scalar_loss.py View File

@@ -42,7 +42,7 @@ def test_sum_as_loss():
super().__init__()
self.fc_nobias = P.MatMul(transpose_b=True).shard(strategy0)
self.reduce_sum = P.ReduceSum(keep_dims=False).shard(strategy1)
self.mul = P.Mul().shard(strategy=((), ()))
self.mul = P.Mul().shard(((), ()))

def construct(self, x, y):
out = self.fc_nobias(x, y)


+ 2
- 2
tests/ut/python/parallel/test_semi_auto_two_subgraphs.py View File

@@ -42,8 +42,8 @@ class Net(nn.Cell):
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.sum = P.ReduceSum(keep_dims=False).shard(strategy=((4, 1, 1, 1),))
self.mean = P.ReduceMean(keep_dims=False).shard(strategy=((8, 1, 1, 1),))
self.sum = P.ReduceSum(keep_dims=False).shard(((4, 1, 1, 1),))
self.mean = P.ReduceMean(keep_dims=False).shard(((8, 1, 1, 1),))
self.net = network

def construct(self, x):


+ 4
- 4
tests/ut/python/parallel/test_shared_param_and_mix_precision.py View File

@@ -29,8 +29,8 @@ class Net1(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2):
super(Net1, self).__init__()
self.fc1 = P.MatMul().shard(strategy=strategy1)
self.fc2 = P.MatMul().shard(strategy=strategy2)
self.fc1 = P.MatMul().shard(strategy1)
self.fc2 = P.MatMul().shard(strategy2)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2")

@@ -45,8 +45,8 @@ class Net2(nn.Cell):
"""Net definition"""
def __init__(self, strategy1, strategy2):
super(Net2, self).__init__()
self.fc1 = P.MatMul().shard(strategy=strategy1)
self.fc2 = P.MatMul().shard(strategy=strategy2)
self.fc1 = P.MatMul().shard(strategy1)
self.fc2 = P.MatMul().shard(strategy2)
self.p1 = Parameter(Tensor(np.ones([48, 64]).astype(np.float32)), name="weight1")
self.p2 = Parameter(Tensor(np.ones([64, 48]).astype(np.float32)), name="weight2")



Loading…
Cancel
Save