/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ #include #include #include #include #include #include "optimizer/irpass.h" #include "optimizer/optimizer.h" #include "ir/visitor.h" #include "ir/manager.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { namespace internal { class SpecializeTransform { public: SpecializeTransform() : cache_() {} ~SpecializeTransform() = default; FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, std::vector prim_args) { if (cache_.count(func_graph) == 0) { cache_[func_graph] = {}; } auto &cache = cache_[func_graph]; auto key = std::make_pair(graph_args, prim_args); if (cache.count(key) == 0) { auto mng = func_graph->manager(); MS_EXCEPTION_IF_NULL(mng); FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); mng->AddFuncGraph(new_fg); std::vector params = new_fg->parameters(); std::vector new_params; size_t n = graph_args.size(); for (size_t i = 0; i < n; i++) { if (graph_args[i] != nullptr) { auto arg = NewValueNode(graph_args[i]); (void)mng->Replace(params[i], arg); continue; } if (prim_args[i] != nullptr) { auto arg = NewValueNode(prim_args[i]); (void)mng->Replace(params[i], arg); continue; } new_params.push_back(params[i]); } mng->SetParameters(new_fg, new_params); cache[key] = new_fg; } return cache[key]; } private: std::unordered_map, std::vector>, FuncGraphPtr>> cache_; }; } // namespace internal // {G, Xs} class SpecializeOnGraphArguments : public AnfVisitor { public: SpecializeOnGraphArguments() : specialize_transform_() {} ~SpecializeOnGraphArguments() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } auto &inputs = node->cast()->inputs(); if (!IsValueNode(inputs[0])) { return nullptr; } auto inp0_fg = GetValueNode(inputs[0]); if (inp0_fg->recursive()) { return nullptr; } std::vector graph_args; std::vector prim_args; std::vector new_xs; bool hasVNode = false; for (size_t i = 1; i < inputs.size(); i++) { if (IsValueNode(inputs[i])) { auto fg_vnode = GetValueNode(inputs[i]); graph_args.push_back(fg_vnode); prim_args.emplace_back(nullptr); hasVNode = true; } else if (IsValueNode(inputs[i])) { auto p_vnode = GetValueNode(inputs[i]); graph_args.emplace_back(nullptr); prim_args.push_back(p_vnode); hasVNode = true; } else { graph_args.emplace_back(nullptr); prim_args.emplace_back(nullptr); new_xs.push_back(inputs[i]); } } if (!hasVNode) { return nullptr; } auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); return node->func_graph()->NewCNode(new_xs); } private: internal::SpecializeTransform specialize_transform_; }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_