/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/common/pattern_engine.h"

#include <functional>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

#include "frontend/optimizer/opt.h"
#include "ir/anf.h"
#include "utils/convert_utils_base.h"
#include "utils/overload.h"

namespace mindspore {
// Returns successive values (0, 1, 2, ...) of a function-local counter; used to give
// untagged Vars a unique tag. The counter is a plain static with no locking.
static int GetNextTag() {
  static int next_id = 0;
  return next_id++;
}

// Assigns an automatic tag of the form "_<n>" if this Var does not have one yet.
void Var::EnsureTag() {
  if (tag_.empty()) {
    tag_ = "_" + std::to_string(GetNextTag());
  }
}

// Equality for Vars. CondVar/CondVar and SeqVar/SeqVar pairs are compared with the
// subclass operator==; any other combination falls back to comparing them as Vars.
// Both pointers are dereferenced without a null check.
bool operator==(const VarPtr &lhs, const VarPtr &rhs) {
  if (lhs->isa<CondVar>() && rhs->isa<CondVar>()) {
    CondVarPtr v1 = dyn_cast<CondVar>(lhs);
    CondVarPtr v2 = dyn_cast<CondVar>(rhs);
    return *v1 == *v2;
  }
  if (lhs->isa<SeqVar>() && rhs->isa<SeqVar>()) {
    SVarPtr v1 = dyn_cast<SeqVar>(lhs);
    SVarPtr v2 = dyn_cast<SeqVar>(rhs);
    return *v1 == *v2;
  }
  return (*lhs == *rhs);
}

// Renders as "SeqVar(<tag>, <subvar>)".
std::string SeqVar::ToString() const {
  std::ostringstream buffer;
  buffer << "SeqVar(" << tag() << ", " << subvar_->ToString() << ")";
  return buffer.str();
}

// Streams a Var via its ToString(); a null VarPtr produces no output.
std::ostream &operator<<(std::ostream &os, const VarPtr &var) {
  if (var == nullptr) {
    os << "";
  } else {
    os << var->ToString();
  }
  return os;
}

// Debug dump of an Equiv map: one "<var>:<description>" line per binding, where the
// description depends on whether the bound value is an AnfNode, Named, Var or FuncGraph.
template <>
std::ostream &operator<<(std::ostream &os, const Equiv &equiv) {
  os << "[Equiv]"
     << "\n";
  for (const auto &equiv_item : equiv) {
    const auto &k = equiv_item.first;
    os << k << ":";
    const BaseRef &x = equiv_item.second;
    if (utils::isa<AnfNodePtr>(x)) {
      auto node = utils::cast<AnfNodePtr>(x);
      os << "TypeString[" << node->type_name() << "]";
      if (IsValueNode<FuncGraph>(node)) {
        os << "IsValueNodeGraph ";
      }
      os << "type " << node->type_name();
      if (node->isa<ValueNode>()) {
        os << " value " << GetValueNode(node);
      }
      os << " addr: " << node;
    } else if (utils::isa<Named>(x)) {
      os << "Named " << x.ToString().c_str();
    } else if (utils::isa<VarPtr>(x)) {
      os << "TypeString[Var]";
      os << utils::cast<VarPtr>(x);
    } else if (utils::isa<FuncGraphPtr>(x)) {
      os << "TypeString[Graph]";
    }
    os << "\n";
  }
  return os;
}

// Unwraps a VarNode into the Var it carries; every other value is returned unchanged.
// The remaining branches only emit debug logging.
static BaseRef GetVar(const BaseRef &x) {
  MS_LOG(DEBUG) << "GetVar start: " + x.ToString();
  if (utils::isa<AnfNodePtr>(x)) {
    auto node = utils::cast<AnfNodePtr>(x);
    MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]";
    if (node->isa<VarNode>()) {
      MS_LOG(DEBUG) << "IsVarNode " + node->cast<VarNodePtr>()->var_->ToString();
      return node->cast<VarNodePtr>()->var_;
    }
    if (node->isa<ValueNode>()) {
      MS_LOG(DEBUG) << "value " + GetValueNode(node)->ToString() + " addr: " + node->ToString();
    } else {
      MS_LOG(DEBUG) << "type " + node->type_name();
    }
  } else if (utils::isa<Named>(x)) {
    MS_LOG(DEBUG) << "Named " + x.ToString();
  } else if (utils::isa<VectorRef>(x)) {
    MS_LOG(DEBUG) << "VectorRef";
  } else if (utils::isa<VarPtr>(x)) {
    MS_LOG(DEBUG) << "TypeString[Var] " + x.ToString();
  }
  MS_LOG(DEBUG) << "GetVar end: " + x.ToString();
  return x;
}

// If `pattern` is a Var that accepts `expr`, records pattern -> expr in `equiv` and
// returns it; otherwise returns nullptr. Throws if `equiv` is null.
EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) {
  MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString();
  MS_EXCEPTION_IF_NULL(equiv);
  if (utils::isa<VarPtr>(pattern)) {
    VarPtr var = utils::cast<VarPtr>(pattern);
    if (var->matches(expr)) {
      (*equiv)[var] = expr;
      MS_LOG(DEBUG) << "pattern is var match: " + pattern.ToString() + ", " + expr.ToString();
      return equiv;
    }
  }
  return nullptr;
}

// VectorRef overload: copies both sequences straight into the output lists when the
// pattern is a sequence. Returns false otherwise.
bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern,
                             VectorRef *const values_expr) const {
  MS_EXCEPTION_IF_NULL(values_pattern);
  MS_EXCEPTION_IF_NULL(values_expr);
  if (utils::isa<SeqPtr>(pattern_ref)) {
    *values_pattern = pattern_ref;
    *values_expr = expr_ref;
    return true;
  }
  return false;
}

// Flattens `pattern_ref` and `expr_ref` into element lists using `visitor_`.
// Pattern elements go through GetVar so VarNodes contribute their Var; expression
// elements are appended unchanged. Returns false if either side cannot be visited.
bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern,
                             VectorRef *const values_expr) const {
  MS_EXCEPTION_IF_NULL(values_pattern);
  MS_EXCEPTION_IF_NULL(values_expr);

  // Visitor callback that collects the pattern's elements.
  std::function<BaseRef(const BaseRef &)> append_pattern = [values_pattern](const BaseRef &u) {
    values_pattern->push_back(GetVar(u));
    return u;
  };
  visitor_->SetFn(append_pattern);
  MS_LOG(DEBUG) << "visit pattern_ref";
  if (!visitor_->Visit(pattern_ref, nullptr)) {
    return false;
  }

  // Visitor callback that collects the expression's elements.
  std::function<BaseRef(const BaseRef &)> append_expr = [values_expr](const BaseRef &u) {
    values_expr->push_back(u);
    return u;
  };
  visitor_->SetFn(append_expr);
  MS_LOG(DEBUG) << "visit expr_ref";
  return visitor_->Visit(expr_ref, nullptr);
}

// Returns the position of the single SeqVar in `values`, -1 if there is none, or
// kInvalidVarIndex if more than one SeqVar appears.
static int GetSVarStartIndex(const VectorRef &values) {
  int index = -1;
  int count = 0;
  for (const auto &value : values) {
    if (utils::isa<VarPtr>(value) && utils::cast<VarPtr>(value)->isa<SeqVar>()) {
      if (index != -1) {
        MS_LOG(DEBUG) << "Multiple SVars in sequence";
        return kInvalidVarIndex;
      }
      index = count;
    }
    count++;
  }
  return index;
}

// Extra binding after a successful match. If the pattern's head element is a Primitive
// value node registered in `primitive_vars`, and the matched expression is a CNode, it
// binds that primitive's Var to the CNode. Otherwise it does nothing, including when
// `equiv` is null.
void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars,
                    EquivPtr equiv) {
  if (equiv == nullptr || values_pattern.empty() || !utils::isa<AnfNodePtr>(values_pattern[0]) ||
      !utils::isa<AnfNodePtr>(expr_ref)) {
    return;
  }
  auto real_node = utils::cast<AnfNodePtr>(expr_ref);
  MS_EXCEPTION_IF_NULL(real_node);
  if (!real_node->isa<CNode>()) {
    return;
  }
  auto prim_node = utils::cast<AnfNodePtr>(values_pattern[0]);
  MS_EXCEPTION_IF_NULL(prim_node);
  if (!IsValueNode<Primitive>(prim_node)) {
    return;
  }
  ValuePtr value = GetValueNode(prim_node);
  MS_EXCEPTION_IF_NULL(value);
  auto prim = value->cast<PrimitivePtr>();
  MS_EXCEPTION_IF_NULL(prim);
  auto iter = primitive_vars.find(prim);
  if (iter == primitive_vars.end()) {
    return;
  }
  (*equiv)[iter->second] = real_node;
}

// Matches two flattened sequences element by element. At most one SeqVar may appear in
// the pattern. It is matched against the run of expression elements at its position
// whose length is expr_len - pattern_len + 1 (possibly empty). Elements after it are
// shifted accordingly. Returns the updated equiv, or nullptr on any mismatch.
EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr,
                                  const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const {
  int svar_index = GetSVarStartIndex(values_pattern);
  if (svar_index == kInvalidVarIndex) {
    return nullptr;
  }

  size_t values_pattern_len = values_pattern.size();
  size_t values_expr_len = values_expr.size();

  if (svar_index == -1) {
    if (values_pattern_len != values_expr_len) {
      MS_LOG(DEBUG) << "Structures of differing size: pattern len " << values_pattern_len << ", expr len "
                    << values_expr_len;
      return nullptr;
    }
  }
  // Phrased as `expr + 1 < pattern` instead of `expr < pattern - 1`, which wraps around
  // (size_t) for an empty pattern and wrongly rejected an empty-vs-empty match.
  if (values_expr_len + 1 < values_pattern_len) {
    MS_LOG(DEBUG) << "invalid size: pattern len " << values_pattern_len << ", expr len " << values_expr_len;
    return nullptr;
  }
  // Length of the expression run absorbed by the SeqVar (may be 0). Without a SeqVar the
  // lengths are equal, so diff is 1 and no index shift is applied below.
  size_t diff = values_expr_len + 1 - values_pattern_len;
  for (size_t i = 0; i < values_pattern_len; i++) {
    if (svar_index != -1 && i == IntToSize(svar_index)) {
      auto seq = std::vector<BaseRef>(values_expr.begin() + svar_index,
                                      values_expr.begin() + svar_index + SizeToInt(diff));
      equiv = Match(values_pattern[i], seq, primitive_vars, equiv);
    } else {
      size_t expr_i = i;
      if (svar_index != -1 && i > IntToSize(svar_index)) {
        // Skip past the elements consumed by the SeqVar.
        expr_i = i + diff - 1;
      }
      equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv);
    }
    if (equiv == nullptr) {
      return nullptr;
    }
  }
  return equiv;
}

// Unifies `pattern` with `expr`, accumulating Var bindings in `equiv`. Returns the
// (possibly updated) equiv on success or nullptr on mismatch. Throws if `equiv` is null.
EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars,
                              EquivPtr equiv) const {
  MS_LOG(DEBUG) << "-----[in Match]";
  MS_LOG(DEBUG) << "GetVar w";
  BaseRef pattern_ref = GetVar(pattern);
  MS_LOG(DEBUG) << "GetVar v";
  BaseRef expr_ref = expr;

  if (equiv == nullptr) {
    MS_LOG(EXCEPTION) << "Equiv pointer is null";
  }

  MS_LOG(DEBUG) << "Pattern ref " + pattern_ref.ToString() + ", expr ref" + expr_ref.ToString();
  // 1. If pattern_ref is a Var that is already bound, match against its bound value.
  if (utils::isa<VarPtr>(pattern_ref)) {
    VarPtr var = utils::cast<VarPtr>(pattern_ref);
    auto iter = equiv->find(var);
    if (iter != equiv->end()) {
      pattern_ref = iter->second;
    }
  }

  // 2. Check equality.
  if (eq_(pattern_ref, expr_ref)) {
    return equiv;
  }

  // 3. Match an (unbound) Var.
  EquivPtr ret_equiv = MatchOnVar(pattern_ref, expr_ref, equiv);
  if (ret_equiv) {
    return ret_equiv;
  }

  // 4. Here the type can be std::vector, std::list, or a CNode.
  if (!type_eq_(pattern_ref, expr_ref)) {
    MS_LOG(DEBUG) << "Type mismatch";
    return nullptr;
  }

  // 5. Flatten both containers into element lists via the visitor.
  VectorRef values_pattern;
  VectorRef values_expr;
  if (!ToVector(pattern_ref, expr_ref, &values_pattern, &values_expr)) {
    return nullptr;
  }

  // 6. Align a possible SeqVar in the pattern and match the elements one by one, then
  //    record any primitive-level binding for the matched node.
  equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv);
  UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv);
  return equiv;
}

// Builds a new value from `pattern` by substituting every Var bound in `equiv` with its
// bound value, recursing into containers through `visitor_`. If the visitor cannot visit
// `pattern`, it is returned unchanged. Throws if `equiv` is null.
BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const {
  MS_EXCEPTION_IF_NULL(equiv);
  MS_LOG(DEBUG) << "-----[in Replace]";
  BaseRef ref = GetVar(pattern);

  // A bound Var is replaced directly by its value.
  if (utils::isa<VarPtr>(ref)) {
    auto iter = equiv->find(utils::cast<VarPtr>(ref));
    if (iter != equiv->end()) {
      return iter->second;
    }
  }

  // Otherwise rebuild the pattern, replacing each visited child recursively.
  // `equiv` is captured by value because the callback is stored inside `visitor_`.
  std::function<BaseRef(const BaseRef &)> fn = [this, equiv](const BaseRef &u) { return Replace(u, equiv); };
  visitor_->SetFn(fn);
  BaseRef visit_out;
  if (!visitor_->Visit(pattern, &visit_out)) {
    return pattern;
  }
  return visit_out;
}
}  // namespace mindspore