| @@ -24,6 +24,7 @@ | |||||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | #include "frontend/optimizer/irpass/gradient_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/inline.h" | #include "frontend/optimizer/irpass/inline.h" | ||||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | #include "frontend/optimizer/irpass/updatestate_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/load_eliminate.h" | |||||
| #include "frontend/optimizer/irpass/stopgrad_eliminate.h" | #include "frontend/optimizer/irpass/stopgrad_eliminate.h" | ||||
| #include "frontend/optimizer/irpass/incorporate_call.h" | #include "frontend/optimizer/irpass/incorporate_call.h" | ||||
| #include "frontend/optimizer/irpass/incorporate_getitem.h" | #include "frontend/optimizer/irpass/incorporate_getitem.h" | ||||
| @@ -156,6 +157,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||||
| switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(), | switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(), | ||||
| "switch_call_monad_eliminater", IsCNodeDup); | "switch_call_monad_eliminater", IsCNodeDup); | ||||
| // Load eliminate | |||||
| load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad); | |||||
| // StopGradient eliminate | // StopGradient eliminate | ||||
| stopgrad_eliminater_ = | stopgrad_eliminater_ = | ||||
| MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient); | MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient); | ||||
| @@ -96,6 +96,7 @@ class OptimizeIRPassLib { | |||||
| SubstitutionPtr updatestate_eliminater_; | SubstitutionPtr updatestate_eliminater_; | ||||
| SubstitutionPtr switch_call_monad_eliminater_; | SubstitutionPtr switch_call_monad_eliminater_; | ||||
| SubstitutionPtr stopgrad_eliminater_; | SubstitutionPtr stopgrad_eliminater_; | ||||
| SubstitutionPtr load_eliminater_; | |||||
| // Incorporation | // Incorporation | ||||
| SubstitutionPtr incorporate_getitem_set_; | SubstitutionPtr incorporate_getitem_set_; | ||||
| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "frontend/optimizer/irpass/load_eliminate.h" | |||||
| #include <algorithm> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <vector> | |||||
| #include "frontend/operator/ops.h" | |||||
| namespace mindspore::opt::irpass { | |||||
| namespace { | |||||
| // Return true if the node has Ref abstract. | |||||
| bool HasAbstractRef(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto &abs = node->abstract(); | |||||
| return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); | |||||
| } | |||||
| } // namespace | |||||
| AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||||
| auto load_node = dyn_cast<CNode>(node); | |||||
| if (load_node == nullptr || load_node->inputs().empty()) { | |||||
| MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); | |||||
| return nullptr; | |||||
| } | |||||
| constexpr size_t kFirstInputIndex = 1; | |||||
| auto ¶m = load_node->inputs().at(kFirstInputIndex); | |||||
| if (!HasAbstractRef(param)) { | |||||
| return param; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace mindspore::opt::irpass | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||||
| #include "frontend/optimizer/optimizer.h" | |||||
| #include "frontend/optimizer/anf_visitor.h" | |||||
| namespace mindspore::opt::irpass { | |||||
| // | |||||
| // LoadEliminater eliminates redundant Load related nodes. | |||||
| // | |||||
| class LoadEliminater : public AnfVisitor { | |||||
| public: | |||||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||||
| }; | |||||
| } // namespace mindspore::opt::irpass | |||||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||||
| @@ -103,6 +103,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| // Safe inlining | // Safe inlining | ||||
| irpass.inline_, | irpass.inline_, | ||||
| irpass.updatestate_eliminater_, | irpass.updatestate_eliminater_, | ||||
| irpass.load_eliminater_, | |||||
| irpass.stopgrad_eliminater_, | irpass.stopgrad_eliminater_, | ||||
| irpass.partial_eliminate_, | irpass.partial_eliminate_, | ||||
| irpass.replace_applicator_, | irpass.replace_applicator_, | ||||
| @@ -130,6 +131,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||||
| // Safe inlining | // Safe inlining | ||||
| irpass.inline_, | irpass.inline_, | ||||
| irpass.updatestate_eliminater_, | irpass.updatestate_eliminater_, | ||||
| irpass.load_eliminater_, | |||||
| irpass.stopgrad_eliminater_, | irpass.stopgrad_eliminater_, | ||||
| irpass.sparse_tensor_eliminate_, | irpass.sparse_tensor_eliminate_, | ||||
| }); | }); | ||||
| @@ -195,6 +197,7 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp | |||||
| // Safe inlining, | // Safe inlining, | ||||
| irpass.inline_, | irpass.inline_, | ||||
| irpass.updatestate_eliminater_, | irpass.updatestate_eliminater_, | ||||
| irpass.load_eliminater_, | |||||
| irpass.switch_call_monad_eliminater_, | irpass.switch_call_monad_eliminater_, | ||||
| irpass.stopgrad_eliminater_, | irpass.stopgrad_eliminater_, | ||||
| irpass.partial_eliminate_, | irpass.partial_eliminate_, | ||||
| @@ -220,9 +223,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib | |||||
| OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | ||||
| opt::OptPassConfig b_1 = opt::OptPassConfig( | opt::OptPassConfig b_1 = opt::OptPassConfig( | ||||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, | {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, | ||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_, | |||||
| irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, | |||||
| irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, | |||||
| irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | |||||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | ||||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | opt::OptPassConfig b_2 = opt::OptPassConfig({ | ||||
| irpass.replace_refkey_by_param_, | irpass.replace_refkey_by_param_, | ||||
| @@ -86,6 +86,10 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p | |||||
| // Inputs: one tensor. | // Inputs: one tensor. | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| CheckArgsSize(op_name, args_spec_list, 1); | CheckArgsSize(op_name, args_spec_list, 1); | ||||
| auto ref = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||||
| if (ref != nullptr) { | |||||
| return ref->CloneAsTensor(); | |||||
| } | |||||
| return args_spec_list[0]->Broaden(); | return args_spec_list[0]->Broaden(); | ||||
| } | } | ||||