From 06d35d8d1825ce02c642f40c22c33d786d63bcec Mon Sep 17 00:00:00 2001 From: yao_yf Date: Fri, 15 May 2020 15:39:23 +0800 Subject: [PATCH] fix gatherv2 replace graph in auto parallel --- .../ccsrc/parallel/ops_info/gather_v2_p_info.cc | 14 +++++++++----- .../ccsrc/parallel/ops_info/gather_v2_p_info.h | 1 + mindspore/ccsrc/parallel/ops_info/reshape_info.cc | 8 ++++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc index 15d34c6677..5c7473cc90 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.cc @@ -258,16 +258,20 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { return SUCCESS; } +ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { + auto param_strategy = strategy_->GetInputDim().at(0); + if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { + MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; + return nullptr; + } + return replace_graph_; +} + Status GatherV2PInfo::Init(const StrategyPtr &strategy) { - auto param_strategy = strategy->GetInputDim().at(0); if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; } - if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) { - MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; - return FAILED; - } MS_LOG(INFO) << name_ << ": Init success."; return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h index 62553b5588..f05c3c171c 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_p_info.h @@ -43,6 +43,7 @@ class GatherV2PInfo : public OperatorInfo { Status GenerateStrategies(int32_t stage_id) override; Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc index f663eaa6a2..c470c379ff 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc @@ -138,9 +138,9 @@ Status ReshapeInfo::ComputeReplaceOp() { MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; return FAILED; } - MS_LOG(INFO) << name_ << ": input " << input_layout_.ToString(); - MS_LOG(INFO) << name_ << ": output " << output_layout_.ToString(); - MS_LOG(INFO) << name_ << ": dev_list " << dev_list.size(); + MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); + MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); if (redistribution_oplist_ptr == nullptr) { MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; @@ -148,7 +148,7 @@ Status ReshapeInfo::ComputeReplaceOp() { } replace_op_ = redistribution_oplist_ptr->first; replace_op_info_ = redistribution_oplist_ptr->second; - MS_LOG(INFO) << name_ << ": replace op size = " << replace_op_.size(); + MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); return SUCCESS; }