Browse Source

fix gatherv2 replace graph in auto parallel

tags/v0.3.0-alpha
yao_yf 5 years ago
parent
commit
06d35d8d18
3 changed files with 14 additions and 9 deletions
  1. +9
    -5
      mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc
  2. +1
    -0
      mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h
  3. +4
    -4
      mindspore/ccsrc/parallel/ops_info/reshape_info.cc

+ 9
- 5
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc View File

@@ -258,16 +258,20 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS; return SUCCESS;
} }


ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
auto param_strategy = strategy_->GetInputDim().at(0);
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
}
return replace_graph_;
}

Status GatherV2PInfo::Init(const StrategyPtr &strategy) { Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0);
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed."; MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED; return FAILED;
} }
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success."; MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS; return SUCCESS;
} }


+ 1
- 0
mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h View File

@@ -43,6 +43,7 @@ class GatherV2PInfo : public OperatorInfo {


Status GenerateStrategies(int32_t stage_id) override; Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;


protected: protected:


+ 4
- 4
mindspore/ccsrc/parallel/ops_info/reshape_info.cc View File

@@ -138,9 +138,9 @@ Status ReshapeInfo::ComputeReplaceOp() {
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
return FAILED; return FAILED;
} }
MS_LOG(INFO) << name_ << ": input " << input_layout_.ToString();
MS_LOG(INFO) << name_ << ": output " << output_layout_.ToString();
MS_LOG(INFO) << name_ << ": dev_list " << dev_list.size();
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) { if (redistribution_oplist_ptr == nullptr) {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
@@ -148,7 +148,7 @@ Status ReshapeInfo::ComputeReplaceOp() {
} }
replace_op_ = redistribution_oplist_ptr->first; replace_op_ = redistribution_oplist_ptr->first;
replace_op_info_ = redistribution_oplist_ptr->second; replace_op_info_ = redistribution_oplist_ptr->second;
MS_LOG(INFO) << name_ << ": replace op size = " << replace_op_.size();
MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
return SUCCESS; return SUCCESS;
} }




Loading…
Cancel
Save