/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_ #include #include #include #include #include #include "optimizer/irpass.h" #include "optimizer/optimizer.h" #include "ir/visitor.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" #include "utils/symbolic.h" namespace mindspore { namespace opt { namespace irpass { namespace internal { class EnvGetitemTransform { public: EnvGetitemTransform() : cache_() {} ~EnvGetitemTransform() = default; FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) { if (cache_.find(fg) == cache_.end()) { cache_[fg] = {}; } auto &cache = cache_[fg]; auto hash_key = std::make_pair(key, default_node); if (cache.find(hash_key) == cache.end()) { std::ostringstream ss("env", std::ostringstream::app); if (key->node() != nullptr) { ss << key->node()->ToString(); } auto new_fg = TransformableClone(fg, std::make_shared(ss.str())); auto env = new_fg->output(); while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); if (inputs.size() != 4 || !IsValueNode(inputs[2])) { MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance."; } env = inputs[1]; auto value = inputs[3]; auto key2 = GetValueNode(inputs[2]); if (*key2 == *key) { new_fg->set_output(value); cache[hash_key] = new_fg; cache_[fg] = cache; return new_fg; } } new_fg->set_output(new_fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node})); cache[hash_key] = new_fg; } return cache[hash_key]; } private: std::unordered_map, FuncGraphPtr, PairHasher>> cache_; }; } // namespace internal // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y class NewEnvGetItem : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { Reset(); auto gety = [this](const AnfNodePtr &node) -> bool { this->y_ = node; return true; }; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode, IsVNode, gety})(node); if (env_ != nullptr && env_->Len() == 0) { return y_; } return nullptr; } void Visit(const ValueNodePtr &vnode) override { if (env_ == nullptr) { env_ = GetValueNode(vnode); } } void Reset() { y_ = nullptr; env_ = nullptr; } private: AnfNodePtr y_{nullptr}; EnvInstancePtr env_{nullptr}; }; // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} -> // {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}} class AddEnvGetItem : public AnfVisitor { public: AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {} ~AddEnvGetItem() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsAddCNode = [](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim::kPrimEnvAdd) && node->cast()->size() == 3; }; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsAddCNode, IsVNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } // {prim::kPrimEnvGetItem, {...}, C, Z} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto c = cnode->input(2); auto z = cnode->input(3); // {prim::kPrimEnvAdd, X, Y} auto x = inp1->input(1); auto y = inp1->input(2); auto fg = node->func_graph(); auto xcz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), x, c, z}); auto ycz = fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), y, c, z}); return fg->NewCNode({NewValueNode(PrimHyperAdd_), xcz, ycz}); } void Visit(const AnfNodePtr &) override { is_match_ = true; } private: bool is_match_{false}; ValuePtr PrimHyperAdd_; }; // {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z} class EnvGetSetItem : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsSetCNode = [](const AnfNodePtr &node) -> bool { if (!IsPrimitiveCNode(node, prim::kPrimEnvSetItem)) { return false; } // {prim::kPrimEnvSetItem, X, C1, Y} auto &inputs = node->cast()->inputs(); if (inputs.size() != 4) { return false; } return IsValueNode(inputs[2]); }; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSetCNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } // {prim::kPrimEnvGetItem, {...}, C2, Z} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key2 = cnode->input(2); auto c2 = GetValueNode(key2); auto default_v = cnode->input(3); // {prim::kPrimEnvSetItem, X, C1, Y} auto env = inp1->input(1); auto c1 = GetValueNode(inp1->input(2)); auto last_set = inp1->input(3); if (*c1 == *c2) { return last_set; } while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) { // {prim::kPrimEnvSetItem, env, symbolickey, value} auto &inputs = env->cast()->inputs(); if (inputs.size() != 4 || !IsValueNode(inputs[2])) { MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance."; } env = inputs[1]; last_set = inputs[3]; auto symbolic_c1 = GetValueNode(inputs[2]); if (*symbolic_c1 == *c2) { return last_set; } } return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key2, default_v}); } void Visit(const AnfNodePtr &) override { is_match_ = true; } private: bool is_match_{false}; }; // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} class IncorporateEnvGetitem : public AnfVisitor { public: IncorporateEnvGetitem() : env_get_item_transform_() {} ~IncorporateEnvGetitem() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsGCNode = [](const AnfNodePtr &node) -> bool { auto cnode = node->cast(); if (cnode == nullptr || cnode->size() < 1) { return false; } return IsValueNode(cnode->input(0)); }; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode, IsNode})(node); if (!is_match_) { return nullptr; } // {prim::kPrimEnvGetItem, {...}, C, Y} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key = GetValueNode(cnode->input(2)); auto default_v = cnode->input(3); // {G, Xs} auto inputs = inp1->inputs(); auto fg = GetValueNode(inputs[0]); auto new_fg = env_get_item_transform_(fg, key, default_v); std::vector args; args.push_back(NewValueNode(new_fg)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); return node->func_graph()->NewCNode(args); } void Visit(const AnfNodePtr &) override { is_match_ = true; } private: bool is_match_{false}; internal::EnvGetitemTransform env_get_item_transform_; }; // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y} class IncorporateEnvGetitemSwitch : public AnfVisitor { public: IncorporateEnvGetitemSwitch() : env_get_item_transform_() {} ~IncorporateEnvGetitemSwitch() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { is_match_ = false; auto IsSwNode = [](const AnfNodePtr &node) -> bool { auto cnode = node->cast(); if (cnode == nullptr || cnode->size() < 1) { return false; } return IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch); }; AnfVisitor::Match(prim::kPrimEnvGetItem, {IsSwNode, IsValueNode, IsNode})(node); if (!is_match_ || node->func_graph() == nullptr) { return nullptr; } // {prim::kPrimEnvGetItem, {...}, C, Y} auto cnode = node->cast(); auto inp1 = cnode->input(1)->cast(); auto key = GetValueNode(cnode->input(2)); auto default_v = cnode->input(3); // {{prim::kPrimSwitch, X, G1, G2}, Xs} auto inputs = inp1->inputs(); is_match_ = false; AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode, IsValueNode})(inputs[0]); if (!is_match_) { return nullptr; } // {prim::kPrimSwitch, X, G1, G2} auto sw = inputs[0]->cast(); auto x = sw->input(1); auto g1 = GetValueNode(sw->input(2)); auto g2 = GetValueNode(sw->input(3)); auto new_g1 = env_get_item_transform_(g1, key, default_v); auto new_g2 = env_get_item_transform_(g2, key, default_v); auto fg = node->func_graph(); auto new_sw = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x, NewValueNode(new_g1), NewValueNode(new_g2)}); std::vector args{new_sw}; (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); return fg->NewCNode(args); } void Visit(const AnfNodePtr &) override { is_match_ = true; } private: bool is_match_{false}; internal::EnvGetitemTransform env_get_item_transform_; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_