| @@ -258,16 +258,20 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) { | |||||
| auto param_strategy = strategy_->GetInputDim().at(0); | |||||
| if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; | |||||
| return nullptr; | |||||
| } | |||||
| return replace_graph_; | |||||
| } | |||||
| Status GatherV2PInfo::Init(const StrategyPtr &strategy) { | Status GatherV2PInfo::Init(const StrategyPtr &strategy) { | ||||
| auto param_strategy = strategy->GetInputDim().at(0); | |||||
| if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { | ||||
| MS_LOG(ERROR) << name_ << ": Init failed."; | MS_LOG(ERROR) << name_ << ": Init failed."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) { | |||||
| MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; | |||||
| return FAILED; | |||||
| } | |||||
| MS_LOG(INFO) << name_ << ": Init success."; | MS_LOG(INFO) << name_ << ": Init success."; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -43,6 +43,7 @@ class GatherV2PInfo : public OperatorInfo { | |||||
| Status GenerateStrategies(int32_t stage_id) override; | Status GenerateStrategies(int32_t stage_id) override; | ||||
| Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | Status SetCostUnderStrategy(const StrategyPtr &strategy) override; | ||||
| ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; | |||||
| std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; | std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override; | ||||
| protected: | protected: | ||||
| @@ -138,9 +138,9 @@ Status ReshapeInfo::ComputeReplaceOp() { | |||||
| MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; | MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed."; | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| MS_LOG(INFO) << name_ << ": input " << input_layout_.ToString(); | |||||
| MS_LOG(INFO) << name_ << ": output " << output_layout_.ToString(); | |||||
| MS_LOG(INFO) << name_ << ": dev_list " << dev_list.size(); | |||||
| MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString(); | |||||
| MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString(); | |||||
| MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size(); | |||||
| RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); | RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList(); | ||||
| if (redistribution_oplist_ptr == nullptr) { | if (redistribution_oplist_ptr == nullptr) { | ||||
| MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; | MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed."; | ||||
| @@ -148,7 +148,7 @@ Status ReshapeInfo::ComputeReplaceOp() { | |||||
| } | } | ||||
| replace_op_ = redistribution_oplist_ptr->first; | replace_op_ = redistribution_oplist_ptr->first; | ||||
| replace_op_info_ = redistribution_oplist_ptr->second; | replace_op_info_ = redistribution_oplist_ptr->second; | ||||
| MS_LOG(INFO) << name_ << ": replace op size = " << replace_op_.size(); | |||||
| MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||