|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
- #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
-
- #include <algorithm>
- #include <memory>
- #include <sstream>
- #include <unordered_map>
- #include <utility>
- #include <vector>
-
- #include "optimizer/irpass.h"
- #include "optimizer/optimizer.h"
- #include "ir/visitor.h"
- #include "ir/func_graph.h"
- #include "ir/func_graph_cloner.h"
- #include "operator/ops.h"
- #include "utils/symbolic.h"
-
- namespace mindspore {
- namespace opt {
- namespace irpass {
- namespace internal {
- // Produces clones of a graph whose output env has been "read" at a given symbolic key.
- class EnvGetitemTransform {
-  public:
-   EnvGetitemTransform() : cache_() {}
-   ~EnvGetitemTransform() = default;
-
-   // Returns a clone of `fg` whose output is rewritten to read `key` from the env `fg` returns:
-   //  - if the output is a chain of env_setitem nodes and one of them sets `key`, the clone's
-   //    output becomes that stored value;
-   //  - otherwise the clone's output becomes env_getitem(<env left after the chain>, key, default_node).
-   // Results are memoized per (fg, key, default_node). The (key, default_node) pair is hashed and
-   // compared as a pair of shared pointers, so repeated requests with the same objects return the
-   // same cloned graph.
-   // Raises via MS_LOG(EXCEPTION) if an env_setitem in the chain is malformed.
-   FuncGraphPtr operator()(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key, const AnfNodePtr &default_node) {
-     // operator[] default-constructs the per-graph cache the first time `fg` is seen.
-     auto &cache = cache_[fg];
-     auto hash_key = std::make_pair(key, default_node);
-     auto hit = cache.find(hash_key);
-     if (hit != cache.end()) {
-       return hit->second;
-     }
-
-     // Trace name: "env" followed by the key node's string form, when the key has a node.
-     std::ostringstream ss("env", std::ostringstream::app);
-     if (key->node() != nullptr) {
-       ss << key->node()->ToString();
-     }
-
-     auto new_fg = TransformableClone(fg, std::make_shared<TraceTransform>(ss.str()));
-     new_fg->set_output(ResolveOutput(new_fg, key, default_node));
-     cache[hash_key] = new_fg;
-     return new_fg;
-   }
-
-  private:
-   // Walks the env_setitem chain that ends at `fg`'s output.
-   // Returns the value stored under `key` if some env_setitem in the chain sets it. Otherwise it
-   // builds, inside `fg`, env_getitem(<env remaining after stripping the chain>, key, default_node).
-   static AnfNodePtr ResolveOutput(const FuncGraphPtr &fg, const SymbolicKeyInstancePtr &key,
-                                   const AnfNodePtr &default_node) {
-     auto env = fg->output();
-     while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
-       // {prim::kPrimEnvSetItem, env, symbolickey, value}
-       const auto &inputs = env->cast<CNodePtr>()->inputs();
-       if (inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(inputs[2])) {
-         MS_LOG(EXCEPTION) << "It should be SymbolicKeyInstance.";
-       }
-       // Read everything needed from this node before `env` is advanced past it.
-       if (*GetValueNode<SymbolicKeyInstancePtr>(inputs[2]) == *key) {
-         return inputs[3];
-       }
-       env = inputs[1];
-     }
-     return fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, NewValueNode(key), default_node});
-   }
-
-   // fg -> ((key, default_node) -> transformed clone of fg).
-   std::unordered_map<FuncGraphPtr,
-                      std::unordered_map<std::pair<SymbolicKeyInstancePtr, AnfNodePtr>, FuncGraphPtr, PairHasher>>
-     cache_;
- };
- } // namespace internal
-
- // {prim::kPrimEnvGetItem, C1, C2, Y} -> Y
- // Fires when C1 is an EnvInstance value whose Len() is 0: the lookup can only yield the default Y.
- class NewEnvGetItem : public AnfVisitor {
-  public:
-   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-     Reset();
-     // Last predicate of the match: records Y (the default value) and always accepts.
-     auto capture_default = [this](const AnfNodePtr &default_node) -> bool {
-       y_ = default_node;
-       return true;
-     };
-
-     AnfVisitor::Match(prim::kPrimEnvGetItem, {IsValueNode<EnvInstance>, IsVNode, capture_default})(node);
-     const bool env_is_empty = (env_ != nullptr) && (env_->Len() == 0);
-     if (!env_is_empty) {
-       return nullptr;
-     }
-     return y_;
-   }
-
-   // Keeps the EnvInstance from the first value node visited; later value nodes are ignored.
-   void Visit(const ValueNodePtr &vnode) override {
-     if (env_ != nullptr) {
-       return;
-     }
-     env_ = GetValueNode<EnvInstancePtr>(vnode);
-   }
-
-   // Clears per-invocation state.
-   void Reset() {
-     env_ = nullptr;
-     y_ = nullptr;
-   }
-
-  private:
-   AnfNodePtr y_{nullptr};
-   EnvInstancePtr env_{nullptr};
- };
-
- // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z} ->
- // {prim::GetPythonOps("hyper_add"), {prim::kPrimEnvGetItem, X, C, Z}, {prim::kPrimEnvGetItem, Y, C, Z}}
- class AddEnvGetItem : public AnfVisitor {
-  public:
-   AddEnvGetItem() : PrimHyperAdd_(prim::GetPythonOps("hyper_add")) {}
-   ~AddEnvGetItem() override = default;
-
-   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-     is_match_ = false;
-     // Accepts an env_add node with exactly two operands: {prim::kPrimEnvAdd, X, Y}.
-     auto is_binary_env_add = [](const AnfNodePtr &candidate) -> bool {
-       if (!IsPrimitiveCNode(candidate, prim::kPrimEnvAdd)) {
-         return false;
-       }
-       return candidate->cast<CNodePtr>()->size() == 3;
-     };
-     AnfVisitor::Match(prim::kPrimEnvGetItem, {is_binary_env_add, IsVNode, IsNode})(node);
-     if (!is_match_) {
-       return nullptr;
-     }
-     auto fg = node->func_graph();
-     if (fg == nullptr) {
-       return nullptr;
-     }
-
-     // {prim::kPrimEnvGetItem, {prim::kPrimEnvAdd, X, Y}, C, Z}
-     auto get_item = node->cast<CNodePtr>();
-     auto env_add = get_item->input(1)->cast<CNodePtr>();
-     auto key = get_item->input(2);
-     auto dflt = get_item->input(3);
-
-     // Builds {prim::kPrimEnvGetItem, env, C, Z} in the node's own graph.
-     auto read_from = [&fg, &key, &dflt](const AnfNodePtr &env) {
-       return fg->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, key, dflt});
-     };
-     auto lhs = read_from(env_add->input(1));
-     auto rhs = read_from(env_add->input(2));
-     return fg->NewCNode({NewValueNode(PrimHyperAdd_), lhs, rhs});
-   }
-
-   // Called by Match for the inputs of an accepted node (AnfVisitor contract — confirm); marks a hit.
-   void Visit(const AnfNodePtr &) override { is_match_ = true; }
-
-  private:
-   bool is_match_{false};
-   ValuePtr PrimHyperAdd_;
- };
-
- // {prim::kPrimEnvGetItem, {prim::kPrimEnvSetItem, X, C1, Y}, C2, Z}
- // Walks the env_setitem chain and returns the value stored under C2 if one of the setters matches.
- // Otherwise it returns {prim::kPrimEnvGetItem, <env below the whole chain>, C2, Z}.
- class EnvGetSetItem : public AnfVisitor {
-  public:
-   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-     is_match_ = false;
-     // Accepts {prim::kPrimEnvSetItem, X, C1, Y} whose C1 is a SymbolicKeyInstance value node.
-     auto is_keyed_set = [](const AnfNodePtr &candidate) -> bool {
-       if (!IsPrimitiveCNode(candidate, prim::kPrimEnvSetItem)) {
-         return false;
-       }
-       const auto &set_inputs = candidate->cast<CNodePtr>()->inputs();
-       return set_inputs.size() == 4 && IsValueNode<SymbolicKeyInstance>(set_inputs[2]);
-     };
-     AnfVisitor::Match(prim::kPrimEnvGetItem, {is_keyed_set, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
-     if (!is_match_ || node->func_graph() == nullptr) {
-       return nullptr;
-     }
-
-     // {prim::kPrimEnvGetItem, {...}, C2, Z}
-     auto get_item = node->cast<CNodePtr>();
-     auto wanted_key_node = get_item->input(2);
-     auto wanted_key = GetValueNode<SymbolicKeyInstancePtr>(wanted_key_node);
-     auto fallback = get_item->input(3);
-
-     // Start at the matched env_setitem itself. The matcher already validated it, so the
-     // exception below can only be raised by setters further down the chain.
-     AnfNodePtr env = get_item->input(1);
-     while (IsPrimitiveCNode(env, prim::kPrimEnvSetItem)) {
-       // {prim::kPrimEnvSetItem, env, symbolickey, value}
-       const auto &set_inputs = env->cast<CNodePtr>()->inputs();
-       if (set_inputs.size() != 4 || !IsValueNode<SymbolicKeyInstance>(set_inputs[2])) {
-         MS_LOG(EXCEPTION) << "Input 2 should be a SymbolicKeyInstance.";
-       }
-       if (*GetValueNode<SymbolicKeyInstancePtr>(set_inputs[2]) == *wanted_key) {
-         return set_inputs[3];
-       }
-       env = set_inputs[1];
-     }
-
-     return node->func_graph()->NewCNode({NewValueNode(prim::kPrimEnvGetItem), env, wanted_key_node, fallback});
-   }
-
-   // Called by Match for the inputs of an accepted node (AnfVisitor contract — confirm); marks a hit.
-   void Visit(const AnfNodePtr &) override { is_match_ = true; }
-
-  private:
-   bool is_match_{false};
- };
-
- // {prim::kPrimEnvGetItem, {G, Xs}, C, Y} -> {G', Xs}
- // G' is a (memoized) clone of G whose output has env_getitem(., C, Y) pushed into it.
- class IncorporateEnvGetitem : public AnfVisitor {
-  public:
-   IncorporateEnvGetitem() : env_get_item_transform_() {}
-   ~IncorporateEnvGetitem() override = default;
-
-   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-     is_match_ = false;
-     // Accepts a call {G, Xs} whose callee G is a FuncGraph value node.
-     auto IsGCNode = [](const AnfNodePtr &input) -> bool {
-       auto cnode = input->cast<CNodePtr>();
-       if (cnode == nullptr || cnode->size() < 1) {
-         return false;
-       }
-       return IsValueNode<FuncGraph>(cnode->input(0));
-     };
-     AnfVisitor::Match(prim::kPrimEnvGetItem, {IsGCNode, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
-
-     // The replacement is created in node's graph. Bail out before cloning G if there is none,
-     // as the sibling passes in this file do; otherwise the NewCNode call below would dereference null.
-     if (!is_match_ || node->func_graph() == nullptr) {
-       return nullptr;
-     }
-
-     // {prim::kPrimEnvGetItem, {...}, C, Y}
-     auto cnode = node->cast<CNodePtr>();
-     auto inp1 = cnode->input(1)->cast<CNodePtr>();
-     auto key = GetValueNode<SymbolicKeyInstancePtr>(cnode->input(2));
-     auto default_v = cnode->input(3);
-
-     // {G, Xs} -- borrowed by reference; `inp1` keeps the node alive and nothing below mutates it.
-     const auto &inputs = inp1->inputs();
-     auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
-     auto new_fg = env_get_item_transform_(fg, key, default_v);
-
-     std::vector<AnfNodePtr> args;
-     args.reserve(inputs.size());
-     args.push_back(NewValueNode(new_fg));
-     (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
-
-     return node->func_graph()->NewCNode(args);
-   }
-
-   // Called by Match for the inputs of an accepted node (AnfVisitor contract — confirm); marks a hit.
-   void Visit(const AnfNodePtr &) override { is_match_ = true; }
-
-  private:
-   bool is_match_{false};
-   internal::EnvGetitemTransform env_get_item_transform_;
- };
-
- // {prim::kPrimEnvGetItem, {{prim::kPrimSwitch, X, G1, G2}, Xs}, C, Y}
- //   -> {{prim::kPrimSwitch, X, G1', G2'}, Xs}
- // Each Gi' is a (memoized) clone of Gi whose output has env_getitem(., C, Y) pushed into it.
- class IncorporateEnvGetitemSwitch : public AnfVisitor {
-  public:
-   IncorporateEnvGetitemSwitch() : env_get_item_transform_() {}
-   ~IncorporateEnvGetitemSwitch() override = default;
-
-   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-     is_match_ = false;
-     // Accepts a call whose callee is itself a switch CNode: {{prim::kPrimSwitch, ...}, Xs}.
-     auto is_switch_call = [](const AnfNodePtr &candidate) -> bool {
-       auto call = candidate->cast<CNodePtr>();
-       return call != nullptr && call->size() >= 1 && IsPrimitiveCNode(call->input(0), prim::kPrimSwitch);
-     };
-     AnfVisitor::Match(prim::kPrimEnvGetItem, {is_switch_call, IsValueNode<SymbolicKeyInstance>, IsNode})(node);
-     if (!is_match_ || node->func_graph() == nullptr) {
-       return nullptr;
-     }
-
-     // {prim::kPrimEnvGetItem, {...}, C, Y}
-     auto get_item = node->cast<CNodePtr>();
-     auto call = get_item->input(1)->cast<CNodePtr>();
-     auto key = GetValueNode<SymbolicKeyInstancePtr>(get_item->input(2));
-     auto default_v = get_item->input(3);
-
-     // {{prim::kPrimSwitch, X, G1, G2}, Xs}
-     auto call_inputs = call->inputs();
-     auto switch_node = call_inputs[0];
-
-     // Reuse the match flag to check the switch's own shape: both branches must be FuncGraph constants.
-     is_match_ = false;
-     AnfVisitor::Match(prim::kPrimSwitch, {IsNode, IsValueNode<FuncGraph>, IsValueNode<FuncGraph>})(switch_node);
-     if (!is_match_) {
-       return nullptr;
-     }
-
-     // {prim::kPrimSwitch, X, G1, G2}
-     auto sw = switch_node->cast<CNodePtr>();
-     auto rewritten_g1 = env_get_item_transform_(GetValueNode<FuncGraphPtr>(sw->input(2)), key, default_v);
-     auto rewritten_g2 = env_get_item_transform_(GetValueNode<FuncGraphPtr>(sw->input(3)), key, default_v);
-
-     auto fg = node->func_graph();
-     auto new_switch = fg->NewCNode(
-       {NewValueNode(prim::kPrimSwitch), sw->input(1), NewValueNode(rewritten_g1), NewValueNode(rewritten_g2)});
-
-     // Re-issue the original call with the new switch as callee and the same arguments Xs.
-     std::vector<AnfNodePtr> args{new_switch};
-     (void)args.insert(args.end(), call_inputs.begin() + 1, call_inputs.end());
-     return fg->NewCNode(args);
-   }
-
-   // Called by Match for the inputs of an accepted node (AnfVisitor contract — confirm); marks a hit.
-   void Visit(const AnfNodePtr &) override { is_match_ = true; }
-
-  private:
-   bool is_match_{false};
-   internal::EnvGetitemTransform env_get_item_transform_;
- };
- } // namespace irpass
- } // namespace opt
- } // namespace mindspore
- #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_ENV_ITEM_ELIMINATE_H_
|