/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include "backend/optimizer/common/visit.h" #include "base/base.h" #include "utils/log_adapter.h" #include "base/base_ref.h" namespace mindspore { class CondVar; class SeqVar; using CondVarPtr = std::shared_ptr; using SVarPtr = std::shared_ptr; const int kInvalidVarIndex = -2; using ConditionFunc = std::function; // Base wildcard variable which could match any anf node. 
class Var : public Base { friend class VarHasher; public: explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { EnsureTag(); } Var(const Var &other) : Base(other), tag_(other.tag_) {} virtual Var &operator=(const Var &other) { if (&other == this) { return *this; } this->tag_ = other.tag_; return *this; } ~Var() override = default; MS_DECLARE_PARENT(Var, Base); virtual bool matches(const BaseRef &) { return true; } virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } bool operator!=(const Var &other) const { return !(&other == this); } std::string tag() const { return tag_; } PrimitivePtr primitive() const { return primitive_; } std::string ToString() const override { std::ostringstream buffer; buffer << "Var(" << tag_ << ")"; return buffer.str(); } std::size_t hash() const override { return std::hash()(tag_); } protected: void EnsureTag(); std::string tag_; PrimitivePtr primitive_; }; // VarNode means variable node, a subclass of AnfNode class VarNode : public AnfNode { public: VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} ~VarNode() override = default; MS_DECLARE_PARENT(VarNode, AnfNode); const VarPtr var_; }; using VarNodePtr = std::shared_ptr; class VarHasher { public: std::size_t operator()(const Var &var) const { return var.hash(); } }; // Condition Var, match an anf node when condition function return true. 
// NOTE(review): template argument lists appear to have been stripped from this copy of the header
// (e.g. `std::shared_ptr`, `std::map`, `std::function`, `utils::isa(value)`, `std::make_shared()`,
// `template <> struct hash`). They must be restored from upstream before this text compiles.
class CondVar : public Var {
 public:
  explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {}
  ~CondVar() override = default;
  MS_DECLARE_PARENT(CondVar, Var);
  // Logs the candidate at DEBUG level, returns false when utils::isa(value) holds, and otherwise
  // returns the result of the stored predicate cond_fn_.
  // NOTE(review): the type tested by utils::isa is not visible in this copy -- confirm upstream.
  bool matches(const BaseRef &value) override {
    MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString();
    if (utils::isa(value)) {
      return false;
    }
    return cond_fn_(value);
  }
  // Predicate consulted by matches().
  ConditionFunc cond_fn_;
};

// Seq is an alias for VectorRef; SeqVar::matches casts matching values to it.
using Seq = VectorRef;
using SeqPtr = std::shared_ptr;
// Sequence Var which could match multiple consecutive input nodes of a CNode.
class SeqVar : public Var {
 public:
  // Creates the per-element matcher with std::make_shared (presumably a plain Var wildcard -- confirm).
  SeqVar() { subvar_ = std::make_shared(); }
  ~SeqVar() override = default;
  MS_DECLARE_PARENT(SeqVar, Var);
  // Uses `subvar` as the per-element matcher. NOTE(review): matches() dereferences subvar_ without
  // a null check, so callers are assumed to pass a non-null VarPtr.
  explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; }
  // True only when `value` is a Seq and every element satisfies subvar_->matches(); an empty Seq
  // therefore matches (std::all_of over an empty range is true). Non-Seq values return false.
  bool matches(const BaseRef &value) override {
    // match Seq.
    if (utils::isa(value)) {
      const Seq &seq = utils::cast(value);
      return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) {
        auto eq = subvar_->matches(v);
        return eq;
      });
    }
    return false;
  }
  // Compares the per-element matchers.
  // NOTE(review): the parameter is `const SeqVar &`, so this does not override the virtual
  // Var::operator==(const Var &); it hides it for lookups on SeqVar, and comparisons made through
  // a Var reference still use the base (tag-based) version.
  bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; }
  // Defined out of line.
  std::string ToString() const override;

 private:
  VarPtr subvar_;
};

// Equality of VarPtr values; defined out of line.
bool operator==(const VarPtr &lhs, const VarPtr &rhs);
// Negation of the VarPtr operator== above.
inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); }
// Stream output for VarPtr; defined out of line.
std::ostream &operator<<(std::ostream &os, const VarPtr &var);
// Result type of PatternEngine::Match. Presumably maps pattern variables to the values they
// matched -- the element types are not visible in this copy (see NOTE at the top of this block).
using Equiv = std::map;
using EquivPtr = std::shared_ptr;
using PrimitiveVarMap = std::unordered_map;
using PrimitiveVarMapPtr = std::shared_ptr;

// Default type-equality check for PatternEngine: true when both refs report the same type().
inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); }

// Matches pattern BaseRefs against expressions. Stores a visitor and two caller-supplied comparison
// callbacks (eq and type_eq, the latter defaulting to DefaultTypeEq); all matching logic is defined
// out of line.
class PatternEngine {
 public:
  PatternEngine(const std::shared_ptr &visitor, const std::function &eq,
                const std::function &type_eq = DefaultTypeEq)
      : visitor_(visitor), eq_(eq), type_eq_(type_eq) {}
  ~PatternEngine() = default;
  // Matches `pattern` against `expr`, starting from the bindings in `equiv`; returns an EquivPtr.
  EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
                 EquivPtr equiv) const;
  // Replace pattern with equivalent
  // (presumably substitutes the bindings held in `equiv` into `pattern` -- definition not visible here).
  BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const;

 private:
  // Private helpers used by the matching implementation; definitions are out of line.
  EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
                     const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const;
  bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern,
                VectorRef *const values_expr) const;
  bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
                VectorRef *const values_expr) const;
  std::shared_ptr visitor_;
  std::function eq_;
  std::function type_eq_;
};
}  // namespace mindspore

namespace std {
// Presumably these bring the names used by the MS_LOG(ERROR) expansion into scope inside
// namespace std -- confirm against utils/log_adapter.h.
using mindspore::ERROR;
using mindspore::LogStream;
using mindspore::NoExceptionType;
// std::hash specialization that hashes a VarPtr by its tag(). A null pointer is logged at ERROR
// level and hashes to 0.
template <> struct hash {
  std::size_t operator()(const mindspore::VarPtr var) const {
    if (var == nullptr) {
      MS_LOG(ERROR) << "Invalid var ptr";
      return 0;
    }
    return std::hash{}(var->tag());
  }
};
}  // namespace std
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_PATTERN_ENGINE_H_