/** * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "backend/optimizer/common/visit.h" #include #include #include #include "backend/optimizer/common/pattern_engine.h" #include "utils/any.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "utils/log_adapter.h" /* namespace to support utils definition */ namespace mindspore { bool CheckIfNeedExpand(const std::vector &list) { return std::any_of(list.begin(), list.end(), [](const BaseRef &any) { return utils::isa(any); }); } std::shared_ptr ExpandList(const std::vector &list) { std::shared_ptr new_list = std::make_shared(); for (auto &item : list) { if (utils::isa(item)) { const Seq &seq = utils::cast(item); new_list->insert(new_list->end(), seq.begin(), seq.end()); } else { new_list->push_back(item); } } return new_list; } static BaseRef GetVar(const BaseRef &x) { if (utils::isa(x)) { auto node = utils::cast(x); MS_LOG(DEBUG) << "TypeString [" + node->type_name() + "]"; if (node->isa()) { MS_LOG(DEBUG) << "IsVarNode " + node->cast()->var_->ToString(); return node->cast()->var_; } } return x; } bool Visitor::Visit(const VectorRef &v_any, VectorRef *const values_ref, BaseRef *const visit_out) const { std::vector out; for (const auto &element : v_any) { out.push_back(element); values_ref->push_back(GetVar(element)); } if (visit_out != nullptr) { *visit_out = ExpandList(out); } return true; } bool Visitor::Visit(const BaseRef &any, VectorRef *const values_ref, BaseRef *const visit_out) const { if (utils::isa(any)) { return Visit(utils::cast(any), values_ref, visit_out); } else if (utils::isa(any)) { auto nodeptr = utils::cast(any); AnfNodePtr output; AnfNodePtr *p_output = &output; if (visit_out == nullptr) { p_output = nullptr; } Visit(nodeptr, values_ref, p_output); if (visit_out != nullptr) { *visit_out = output; } return true; } MS_LOG(DEBUG) << "VisitError, not support type to Visit: " + any.ToString(); return false; } void Visitor::Visit(const AnfNodePtr &node, VectorRef *const values_ref, AnfNodePtr *output) const { if (node->isa()) { Visit(node->cast(), values_ref, output); return; } if (node->isa()) { Visit(node->cast(), values_ref, output); return; } if (output != nullptr) { *output = node; } } void Visitor::Visit(const CNodePtr &cnode, VectorRef *const values_ref, AnfNodePtr *output) const { // if output is nullptr, it's not required to make the new CNode node. if (output == nullptr) { for (auto &inp : cnode->inputs()) { auto var = GetVar(inp); values_ref->push_back(var); } if (cnode->func_graph() != nullptr) { values_ref->push_back(GetVar(cnode->func_graph())); } else { values_ref->push_back(GetVar(cnode->func_graph_as_var())); } return; } std::vector new_inputs; std::vector after_cnode_fn; std::shared_ptr out; for (auto &input : cnode->inputs()) { after_cnode_fn.push_back(input); values_ref->push_back(GetVar(input)); } if (CheckIfNeedExpand(after_cnode_fn)) { out = ExpandList(after_cnode_fn); } std::vector &outs = after_cnode_fn; if (out != nullptr) { outs = out->elements(); } for (auto &any_item : outs) { if (!utils::isa(any_item)) { MS_LOG(EXCEPTION) << "VisitError, fn not return the same type AnfNodePtr"; } new_inputs.push_back(utils::cast(any_item)); } BaseRef any_fg; AnfNodePtr new_cnode = nullptr; if (cnode->func_graph() != nullptr) { any_fg = cnode->func_graph(); values_ref->push_back(GetVar(any_fg)); if (!utils::isa(any_fg)) { MS_LOG(EXCEPTION) << "VisitError, fn not return the same type FuncGraphPtr"; } new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); } else { any_fg = cnode->func_graph_as_var(); values_ref->push_back(GetVar(any_fg)); if (utils::isa(any_fg)) { new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); } else if (utils::isa(any_fg)) { new_cnode = std::make_shared(new_inputs, utils::cast(any_fg)); } else { MS_LOG(EXCEPTION) << "VisitError, fn not return VarPtr or FuncGraphPtr"; } } new_cnode->set_abstract(cnode->abstract()); *output = new_cnode; } void Visitor::Visit(const ValueNodePtr &vnode, VectorRef *const values_ref, AnfNodePtr *output) const { values_ref->push_back(GetVar(vnode->value())); const BaseRef &value = utils::cast(vnode->value()); if (utils::isa(value)) { if (output != nullptr) { auto ct = NewValueNode(utils::cast(value)); ct->set_abstract(vnode->abstract()); *output = ct; } return; } MS_LOG(EXCEPTION) << "Visit result is not ValuePtr."; } } // namespace mindspore