/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ #define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_ #include #include #include #include #include #include #include "optimizer/irpass.h" #include "optimizer/optimizer.h" #include "ir/visitor.h" #include "ir/manager.h" #include "ir/func_graph.h" #include "ir/func_graph_cloner.h" #include "operator/ops.h" namespace mindspore { namespace opt { namespace irpass { namespace internal { class SpecializeTransform { public: SpecializeTransform() : cache_() {} ~SpecializeTransform() = default; FuncGraphPtr operator()(const FuncGraphPtr &func_graph, std::vector graph_args, std::vector prim_args, std::vector value_args) { if (cache_.count(func_graph) == 0) { cache_[func_graph] = {}; } auto &cache = cache_[func_graph]; auto key = std::make_pair(graph_args, prim_args); if (cache.count(key) == 0) { auto mng = func_graph->manager(); MS_EXCEPTION_IF_NULL(mng); FuncGraphPtr new_fg = TransformableClone(func_graph, std::make_shared("sp")); mng->AddFuncGraph(new_fg); std::vector params = new_fg->parameters(); std::vector new_params; size_t n = graph_args.size(); for (size_t i = 0; i < n; i++) { if (graph_args[i] != nullptr) { auto arg = NewValueNode(graph_args[i]); (void)mng->Replace(params[i], arg); continue; } if (prim_args[i] != nullptr) { auto arg = NewValueNode(prim_args[i]); (void)mng->Replace(params[i], arg); continue; } if (value_args[i] != nullptr) { auto &const_tensor = *value_args[i]; auto const_tensor_ptr = std::make_shared(const_tensor); AnfNodePtr arg = NewValueNode(const_tensor_ptr); (void)mng->Replace(params[i], arg); continue; } new_params.push_back(params[i]); } mng->SetParameters(new_fg, new_params); cache[key] = new_fg; } return cache[key]; } private: std::unordered_map, std::vector>, FuncGraphPtr>> cache_; }; } // namespace internal // {G, Xs} class SpecializeOnGraphArguments : public AnfVisitor { public: SpecializeOnGraphArguments() : specialize_transform_() {} ~SpecializeOnGraphArguments() override = default; AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } auto &inputs = node->cast()->inputs(); if (!IsValueNode(inputs[0])) { return nullptr; } auto inp0_fg = GetValueNode(inputs[0]); if (inp0_fg->recursive()) { return nullptr; } std::vector graph_args; std::vector prim_args; std::vector value_node_args; std::vector new_xs; bool hasVNode = false; for (size_t i = 1; i < inputs.size(); i++) { if (IsValueNode(inputs[i])) { auto fg_vnode = GetValueNode(inputs[i]); graph_args.push_back(fg_vnode); prim_args.emplace_back(nullptr); value_node_args.emplace_back(nullptr); hasVNode = true; } else if (IsValueNode(inputs[i])) { auto p_vnode = GetValueNode(inputs[i]); graph_args.emplace_back(nullptr); prim_args.push_back(p_vnode); value_node_args.emplace_back(nullptr); hasVNode = true; } else if (IsValueNode(inputs[i])) { tensor::TensorPtr t_vnode = GetValueNode(inputs[i]); graph_args.emplace_back(nullptr); prim_args.emplace_back(nullptr); value_node_args.emplace_back(t_vnode); hasVNode = true; } else { graph_args.emplace_back(nullptr); prim_args.emplace_back(nullptr); value_node_args.emplace_back(nullptr); new_xs.push_back(inputs[i]); } } if (!hasVNode) { return nullptr; } auto new_fg = specialize_transform_(inp0_fg, graph_args, prim_args, value_node_args); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); return node->func_graph()->NewCNode(new_xs); } private: internal::SpecializeTransform specialize_transform_; }; // Eliminate unused parameters. // {G, Xs} class UnusedParasEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); auto &inputs = cnode->inputs(); auto fg = GetValueNode(inputs[0]); MS_EXCEPTION_IF_NULL(fg); std::vector parameters = fg->parameters(); size_t size = parameters.size(); if (size != inputs.size() - 1) { return nullptr; } std::vector new_xs; std::vector keep_parameters; auto mng = fg->manager(); MS_EXCEPTION_IF_NULL(mng); auto &node_users = mng->node_users(); bool has_unused_para = false; for (size_t i = 0; i < size; ++i) { auto iter = node_users.find(parameters[i]); if (iter != node_users.end() && !iter->second.empty()) { keep_parameters.push_back(true); new_xs.push_back(inputs[i + 1]); continue; } keep_parameters.push_back(false); has_unused_para = true; } if (!has_unused_para) { return nullptr; } FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared("sp")); mng->AddFuncGraph(new_fg); std::vector new_fg_parameters = new_fg->parameters(); std::vector new_parameters; for (size_t i = 0; i < size; i++) { if (keep_parameters[i]) { if (parameters[i]->abstract() != nullptr) { new_fg_parameters[i]->set_abstract(parameters[i]->abstract()); } new_parameters.push_back(new_fg_parameters[i]); } } mng->SetParameters(new_fg, new_parameters); (void)new_xs.insert(new_xs.begin(), NewValueNode(new_fg)); return node->func_graph()->NewCNode(new_xs); } }; // Eliminate unused outputs. // {G, Xs} class UnusedOutputEliminater : public AnfVisitor { public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } auto &inputs = node->cast()->inputs(); auto fg = GetValueNode(inputs[0]); MS_EXCEPTION_IF_NULL(fg); auto mng = fg->manager(); MS_EXCEPTION_IF_NULL(mng); if (fg->recursive()) { return nullptr; } auto new_fg = TransformableClone(fg, std::make_shared("fg")); mng->AddFuncGraph(new_fg); auto new_fg_output = new_fg->output(); if (!IsPrimitiveCNode(new_fg_output, prim::kPrimMakeTuple)) { return nullptr; } auto output_cnode = new_fg_output->cast(); auto &node_users = mng->node_users(); if (node_users.count(node) == 0 || node_users[node].empty()) { return nullptr; } std::unordered_set used_output_idx; std::vector> all_users; for (auto &node_user : node_users[node]) { if (!IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) { return nullptr; } auto user_cnode = node_user.first->cast(); size_t used_idx = GetValue(user_cnode->input(2)->cast()->value()); used_output_idx.insert(used_idx); all_users.push_back(std::make_pair(node_user.first, used_idx)); } if (used_output_idx.size() >= output_cnode->inputs().size() - 1) { // all output has users. return nullptr; } if (used_output_idx.empty()) { // we do not process this case. return nullptr; } else if (used_output_idx.size() == 1) { // after eliminate, only one output left. new_fg->set_output(output_cnode->input(*used_output_idx.begin() + 1)); // update users. for (auto &ret_user : all_users) { (void)mng->Replace(ret_user.first, node); } } else { // after eliminate, create new multi output. std::vector new_output_inputs{output_cnode->input(0)}; std::unordered_map new_idx_map; for (auto idx : used_output_idx) { new_idx_map[idx] = SizeToInt(new_output_inputs.size() - 1); new_output_inputs.push_back(output_cnode->input(idx + 1)); } new_fg->set_output(new_fg->NewCNode(new_output_inputs)); // update users. for (auto &ret_user : all_users) { auto ret_user_cnode = ret_user.first->cast(); ret_user_cnode->set_input(2, NewValueNode(new_idx_map[ret_user.second])); } } auto new_sx = inputs; new_sx[0] = NewValueNode(new_fg); return node->func_graph()->NewCNode(new_sx); } }; } // namespace irpass } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPECIALIZE_TRANSFORM_H_