| @@ -24,6 +24,7 @@ | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "frontend/optimizer/irpass/inline.h" | |||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | |||
| #include "frontend/optimizer/irpass/load_eliminate.h" | |||
| #include "frontend/optimizer/irpass/stopgrad_eliminate.h" | |||
| #include "frontend/optimizer/irpass/incorporate_call.h" | |||
| #include "frontend/optimizer/irpass/incorporate_getitem.h" | |||
| @@ -156,6 +157,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(), | |||
| "switch_call_monad_eliminater", IsCNodeDup); | |||
| // Load eliminate | |||
| load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad); | |||
| // StopGradient eliminate | |||
| stopgrad_eliminater_ = | |||
| MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient); | |||
| @@ -96,6 +96,7 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr updatestate_eliminater_; | |||
| SubstitutionPtr switch_call_monad_eliminater_; | |||
| SubstitutionPtr stopgrad_eliminater_; | |||
| SubstitutionPtr load_eliminater_; | |||
| // Incorporation | |||
| SubstitutionPtr incorporate_getitem_set_; | |||
| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/irpass/load_eliminate.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore::opt::irpass { | |||
| namespace { | |||
| // Return true if the node has Ref abstract. | |||
| bool HasAbstractRef(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| auto &abs = node->abstract(); | |||
| return (abs != nullptr) && abs->isa<abstract::AbstractRef>(); | |||
| } | |||
| } // namespace | |||
| AnfNodePtr LoadEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { | |||
| auto load_node = dyn_cast<CNode>(node); | |||
| if (load_node == nullptr || load_node->inputs().empty()) { | |||
| MS_LOG(WARNING) << "LoadEliminater encounter invalid node: " << node->DebugString(); | |||
| return nullptr; | |||
| } | |||
| constexpr size_t kFirstInputIndex = 1; | |||
| auto ¶m = load_node->inputs().at(kFirstInputIndex); | |||
| if (!HasAbstractRef(param)) { | |||
| return param; | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace mindspore::opt::irpass | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| namespace mindspore::opt::irpass { | |||
| // | |||
| // LoadEliminater eliminates redundant Load related nodes. | |||
| // | |||
| class LoadEliminater : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; | |||
| }; | |||
| } // namespace mindspore::opt::irpass | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_LOAD_ELIMINATE_H_ | |||
| @@ -103,6 +103,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| // Safe inlining | |||
| irpass.inline_, | |||
| irpass.updatestate_eliminater_, | |||
| irpass.load_eliminater_, | |||
| irpass.stopgrad_eliminater_, | |||
| irpass.partial_eliminate_, | |||
| irpass.replace_applicator_, | |||
| @@ -130,6 +131,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| // Safe inlining | |||
| irpass.inline_, | |||
| irpass.updatestate_eliminater_, | |||
| irpass.load_eliminater_, | |||
| irpass.stopgrad_eliminater_, | |||
| irpass.sparse_tensor_eliminate_, | |||
| }); | |||
| @@ -195,6 +197,7 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp | |||
| // Safe inlining, | |||
| irpass.inline_, | |||
| irpass.updatestate_eliminater_, | |||
| irpass.load_eliminater_, | |||
| irpass.switch_call_monad_eliminater_, | |||
| irpass.stopgrad_eliminater_, | |||
| irpass.partial_eliminate_, | |||
| @@ -220,9 +223,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib | |||
| OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig b_1 = opt::OptPassConfig( | |||
| {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, | |||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_, | |||
| irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, | |||
| irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||
| irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.load_eliminater_, | |||
| irpass.stopgrad_eliminater_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, | |||
| irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, | |||
| irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); | |||
| opt::OptPassConfig b_2 = opt::OptPassConfig({ | |||
| irpass.replace_refkey_by_param_, | |||
| @@ -86,6 +86,10 @@ AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &p | |||
| // Inputs: one tensor. | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 1); | |||
| auto ref = dyn_cast<abstract::AbstractRef>(args_spec_list[0]); | |||
| if (ref != nullptr) { | |||
| return ref->CloneAsTensor(); | |||
| } | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||