| @@ -346,10 +346,6 @@ class TensorAddByZero : public AnfVisitor { | |||||
| } | } | ||||
| void Visit(const AnfNodePtr &node) override { | void Visit(const AnfNodePtr &node) override { | ||||
| if (IsPrimitive(node, prim::kPrimZerosLike)) { | |||||
| is_zero_ = true; | |||||
| return; | |||||
| } | |||||
| if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | ||||
| is_zero_ = true; | is_zero_ = true; | ||||
| return; | return; | ||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "optimizer/pass_group.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| void PassGroup::AddPass(const PythonPassPtr &pass) { | |||||
| if (pass != nullptr) { | |||||
| passes_.push_back(pass); | |||||
| } | |||||
| } | |||||
| bool PassGroup::DeletePass(const std::string &pass_name) { | |||||
| for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { | |||||
| if ((*iter)->name() == pass_name) { | |||||
| *iter = nullptr; | |||||
| passes_.erase(iter); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const { | |||||
| if (func_graph == nullptr) { | |||||
| return false; | |||||
| } | |||||
| bool changed = false; | |||||
| for (const auto &pass : passes) { | |||||
| if (pass != nullptr) { | |||||
| if (pass->Run(func_graph)) { | |||||
| changed = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| bool PassGroup::Run(const FuncGraphPtr &func_graph) const { | |||||
| bool changed = false; | |||||
| // run all passes | |||||
| bool change = true; | |||||
| while (change) { | |||||
| change = Run(func_graph, passes_); | |||||
| changed = change || changed; | |||||
| if (run_only_once_) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| return changed; | |||||
| } | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,61 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "optimizer/py_pass.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| class PassGroup { | |||||
| public: | |||||
| explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) | |||||
| : name_(name), passes_{}, run_only_once_(run_only_once) {} | |||||
| virtual ~PassGroup() = default; | |||||
| // Add graph pass, the pass object will be freed when pass manager freed. | |||||
| void AddPass(const PythonPassPtr &pass); | |||||
| // Delete graph pass before the pass manager is freed. | |||||
| bool DeletePass(const std::string &pass_name); | |||||
| // Run passes added in pass manager on the input graph | |||||
| // @param [inout] graph The graph to be optimized | |||||
| // @return true, graph changed | |||||
| // @return false, graph not changed | |||||
| bool Run(const FuncGraphPtr &func_graph) const; | |||||
| // Run the given graph passes on the input graph | |||||
| // @param [inout] graph The graph to be optimized | |||||
| // @param [in] passes The given graph passes | |||||
| // @return true, graph changed | |||||
| // @return false, graph not changed | |||||
| bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const; | |||||
| std::string name() const { return name_; } | |||||
| private: | |||||
| const std::string name_; | |||||
| std::vector<PythonPassPtr> passes_; | |||||
| bool run_only_once_; | |||||
| }; | |||||
| using PassGroupPtr = std::shared_ptr<PassGroup>; | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||||
| @@ -0,0 +1,236 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "optimizer/py_pass.h" | |||||
| #include <unordered_set> | |||||
| #include <deque> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "ir/func_graph.h" | |||||
| #include "ir/manager.h" | |||||
| #include "pipeline/parse/parse_base.h" | |||||
| #include "pipeline/resource.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| namespace internal { | |||||
| std::string GetNodeRepr(AnfNodePtr node) { | |||||
| if (node != nullptr) { | |||||
| if (node->isa<CNode>()) { | |||||
| std::string repr = "("; | |||||
| auto const &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| for (auto &input : inputs) { | |||||
| repr += " "; | |||||
| repr += GetNodeRepr(input); | |||||
| repr += " "; | |||||
| } | |||||
| repr += ")"; | |||||
| return repr; | |||||
| } | |||||
| if (node->isa<ValueNode>()) { | |||||
| return GetValueNode(node)->ToString(); | |||||
| } | |||||
| return node->ToString(); | |||||
| } | |||||
| return ""; | |||||
| } | |||||
| void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||||
| auto manager = Manage(fg, false); | |||||
| parse::python_adapter::set_use_signature_in_resolve(false); | |||||
| parse::ResolveAll(manager); | |||||
| } | |||||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(pattern); | |||||
| if (pattern->isa<ValueNode>()) { | |||||
| if (!node->isa<ValueNode>()) { | |||||
| return false; | |||||
| } | |||||
| if (GetNodeRepr(pattern) == GetNodeRepr(node)) { | |||||
| // add to equiv_ptr | |||||
| equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } else if (pattern->isa<Parameter>()) { | |||||
| MS_LOG(DEBUG) << pattern->ToString() + "\n"; | |||||
| // add to equiv_ptr | |||||
| equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); | |||||
| return true; | |||||
| } else if (pattern->isa<CNode>()) { | |||||
| // match every single sub ANode | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto pattern_inputs = pattern->cast<CNodePtr>()->inputs(); | |||||
| auto node_inputs = node->cast<CNodePtr>()->inputs(); | |||||
| if (pattern_inputs.size() != node_inputs.size()) { | |||||
| return false; | |||||
| } | |||||
| for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); | |||||
| p_item++, node_item++) { | |||||
| auto res = Match(*p_item, *node_item, equiv_ptr); | |||||
| if (!res) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; | |||||
| } | |||||
| AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, | |||||
| const NodeEquivPtr &equiv_ptr) { | |||||
| if (cur_raw_dst_node_->isa<Parameter>()) { | |||||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); | |||||
| if (sub_pair != equiv_ptr->end()) { | |||||
| return sub_pair->second; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; | |||||
| } else if (cur_raw_dst_node_->isa<ValueNode>()) { | |||||
| // check primitive ValueNode | |||||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString()); | |||||
| if (sub_pair != equiv_ptr->end()) { | |||||
| return sub_pair->second; | |||||
| } | |||||
| return cur_raw_dst_node_; | |||||
| } else if (cur_raw_dst_node_->isa<CNode>()) { | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs(); | |||||
| for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { | |||||
| auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); | |||||
| new_inputs.push_back(subed); | |||||
| } | |||||
| return func_graph->NewCNode(new_inputs); | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); | |||||
| } | |||||
| bool isTraversable(const AnfNodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| return false; | |||||
| } | |||||
| if (node->isa<CNode>() || node->isa<Parameter>()) { | |||||
| return true; | |||||
| } | |||||
| if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace internal | |||||
| void PythonPass::Build(const py::function &src, const py::function &dst) { | |||||
| // 1. get FuncGraph from py::function | |||||
| auto src_fg_ = parse::ParsePythonCode(src); | |||||
| auto dst_fg_ = parse::ParsePythonCode(dst); | |||||
| if (src_fg_ == nullptr || dst_fg_ == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; | |||||
| } | |||||
| // 2. Resolve | |||||
| internal::ResolveFuncGraph_(src_fg_); | |||||
| internal::ResolveFuncGraph_(dst_fg_); | |||||
| // 3. from FuncGraphPtr to ValueNode | |||||
| src_node_ = src_fg_->output(); | |||||
| dst_node_ = dst_fg_->output(); | |||||
| } | |||||
| PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, | |||||
| bool multigraph) | |||||
| : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { | |||||
| Build(src, dst); | |||||
| } | |||||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| auto equiv_ptr = std::make_shared<NodeEquiv>(); | |||||
| bool is_a_match = internal::Match(src_node_, node, equiv_ptr); | |||||
| if (is_a_match) { | |||||
| auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); | |||||
| MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||||
| return new_node; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| bool PythonPass::Run(const FuncGraphPtr &func_graph) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->AddFuncGraph(func_graph); | |||||
| auto seen = NewSeenGeneration(); | |||||
| // 1024 is for the initial capacity of deque | |||||
| std::deque<AnfNodePtr> todo(1024); | |||||
| todo.push_back(func_graph->output()); | |||||
| bool changes = false; | |||||
| auto &all_nodes = manager->all_nodes(); | |||||
| while (!todo.empty()) { | |||||
| AnfNodePtr node = todo.front(); | |||||
| todo.pop_front(); | |||||
| // check whether this node has been matched. | |||||
| if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { | |||||
| continue; | |||||
| } | |||||
| node->seen_ = seen; | |||||
| // select nodes that this transform can be applied. | |||||
| AnfNodePtr new_node = Run(func_graph, node); | |||||
| bool change = (new_node != nullptr); | |||||
| if (new_node != nullptr && new_node != node) { | |||||
| (void)manager->Replace(node, new_node); | |||||
| } else if (new_node == nullptr) { | |||||
| new_node = node; | |||||
| } | |||||
| if (run_only_once_) { | |||||
| return change; | |||||
| } | |||||
| // find success, and add them to todo list | |||||
| if (IsValueNode<FuncGraph>(node)) { | |||||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||||
| } | |||||
| if (node->isa<CNode>()) { | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||||
| } | |||||
| auto &node_users = manager->node_users(); | |||||
| if (change && node_users.find(node) != node_users.end()) { | |||||
| for (auto &use : node_users[node]) { | |||||
| auto use_node = use.first; | |||||
| if (use_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| todo.push_back(use_node); | |||||
| if (use_node->seen_ == seen) { | |||||
| use_node->seen_--; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return changes; | |||||
| } | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "ir/anf.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| #include "pybind_api/export_flags.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| class PythonPass; | |||||
| using PythonPassPtr = std::shared_ptr<PythonPass>; | |||||
| using NodeEquiv = std::unordered_map<std::string, AnfNodePtr>; | |||||
| using NodeEquivPtr = std::shared_ptr<NodeEquiv>; | |||||
| class PythonPass { | |||||
| public: | |||||
| explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, | |||||
| bool run_only_once = false, bool multigraph = true); | |||||
| ~PythonPass() = default; | |||||
| bool Run(const FuncGraphPtr &func_graph); | |||||
| std::string name() const { return name_; } | |||||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | |||||
| private: | |||||
| void Build(const py::function &src, const py::function &dst); | |||||
| AnfNodePtr src_node_ = nullptr; | |||||
| AnfNodePtr dst_node_ = nullptr; | |||||
| const std::string name_; | |||||
| bool run_only_once_; | |||||
| bool multigraph_ = true; | |||||
| }; | |||||
| using PythonPassPtr = std::shared_ptr<PythonPass>; | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||||
| @@ -0,0 +1,84 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "optimizer/py_pass_manager.h" | |||||
| #include <functional> | |||||
| #include <algorithm> | |||||
| #include <utility> | |||||
| #include <initializer_list> | |||||
| #include "ir/manager.h" | |||||
| #include "optimizer/pass_group.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| PyPassManagerPtr PyPassManager::global_instance = nullptr; | |||||
| std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_; | |||||
| PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { | |||||
| auto pm = phase_to_group_.find(phase); | |||||
| if (pm == phase_to_group_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return pm->second; | |||||
| } | |||||
| PyPassManagerPtr PyPassManager::GetInstance() { | |||||
| if (global_instance == nullptr) { | |||||
| global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager()); | |||||
| } | |||||
| return global_instance; | |||||
| } | |||||
| PyPassManager::PyPassManager() { | |||||
| phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>(); | |||||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | |||||
| } | |||||
| void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||||
| Phase phase, bool run_only_once, bool multigraph) { | |||||
| auto cur_pm = GetPassGroup(phase); | |||||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||||
| PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph); | |||||
| cur_pm->AddPass(new_pass); | |||||
| } | |||||
| void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||||
| auto cur_pm = GetPassGroup(phase); | |||||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||||
| if (!cur_pm->DeletePass(pass_name)) { | |||||
| MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; | |||||
| } | |||||
| } | |||||
| void PyPassManager::ClearRes() { | |||||
| MS_LOG(INFO) << "Clear PyPassManager resources!"; | |||||
| global_instance = nullptr; | |||||
| phase_to_group_.clear(); | |||||
| } | |||||
| REGISTER_PYBIND_DEFINE( | |||||
| PyPassManager_, ([](const py::module *m) { | |||||
| (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); | |||||
| (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_") | |||||
| .def(py::init([]() { return PyPassManager::GetInstance(); })) | |||||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | |||||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); | |||||
| })); | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,66 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||||
| #define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include "ir/anf.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "ir/primitive.h" | |||||
| #include "utils/graph_utils.h" | |||||
| #include "common/utils.h" | |||||
| #include "pipeline/parse/resolve.h" | |||||
| #include "optimizer/py_pass.h" | |||||
| #include "optimizer/pass_group.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace python_pass { | |||||
| class PyPassManager; | |||||
| using PyPassManagerPtr = std::shared_ptr<PyPassManager>; | |||||
| enum Phase { RESOLVE, OPT }; | |||||
| class PyPassManager { | |||||
| protected: | |||||
| PyPassManager(); | |||||
| static PyPassManagerPtr global_instance; | |||||
| public: | |||||
| // Singletons should not be cloneable and assignable | |||||
| PyPassManager(const PyPassManager &other) = delete; | |||||
| void operator=(const PyPassManager &) = delete; | |||||
| // Access the only global instance | |||||
| static PyPassManagerPtr GetInstance(); | |||||
| virtual ~PyPassManager() = default; | |||||
| void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||||
| Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); | |||||
| void Unregiste(const std::string &pass_name, Phase phase); | |||||
| PassGroupPtr GetPassGroup(Phase phase); | |||||
| void ClearRes(); | |||||
| private: | |||||
| static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | |||||
| }; | |||||
| } // namespace python_pass | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||||
| @@ -39,6 +39,7 @@ | |||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "vm/transform.h" | #include "vm/transform.h" | ||||
| #include "parse/python_adapter.h" | #include "parse/python_adapter.h" | ||||
| #include "optimizer/py_pass_manager.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| @@ -420,6 +421,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||||
| bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | ||||
| void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||||
| MS_EXCEPTION_IF_NULL(res->manager()); | |||||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | |||||
| if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { | |||||
| MS_LOG(DEBUG) << "No match.\n"; | |||||
| } | |||||
| } | |||||
| bool ResolveActionPyStub(const ResourcePtr &res) { | |||||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||||
| return true; | |||||
| } | |||||
| bool OptActionPyStub(const ResourcePtr &res) { | |||||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||||
| return true; | |||||
| } | |||||
| static std::vector<ActionItem> CommonPipeline() { | static std::vector<ActionItem> CommonPipeline() { | ||||
| std::vector<ActionItem> actions; | std::vector<ActionItem> actions; | ||||
| @@ -432,6 +452,8 @@ static std::vector<ActionItem> CommonPipeline() { | |||||
| if (!multi_graphs) { | if (!multi_graphs) { | ||||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | ||||
| } | } | ||||
| // Add resolve-stage python pass stub | |||||
| actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); | |||||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | ||||
| // Evaluate type and shape, and specialize | // Evaluate type and shape, and specialize | ||||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | ||||
| @@ -443,6 +465,8 @@ std::vector<ActionItem> GePipeline() { | |||||
| auto actions = CommonPipeline(); | auto actions = CommonPipeline(); | ||||
| // optimize | // optimize | ||||
| actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | ||||
| // Add opt-stage python pass stub | |||||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | ||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| return actions; | return actions; | ||||
| @@ -454,6 +478,9 @@ std::vector<ActionItem> VmPipeline() { | |||||
| // optimize | // optimize | ||||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | ||||
| // Add opt-stage python pass stub | |||||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| // compile the ANF graph | // compile the ANF graph | ||||
| @@ -39,6 +39,7 @@ | |||||
| #include "device/kernel_runtime_manager.h" | #include "device/kernel_runtime_manager.h" | ||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "pynative/pynative_execute.h" | #include "pynative/pynative_execute.h" | ||||
| #include "optimizer/py_pass_manager.h" | |||||
| #if (ENABLE_GE || ENABLE_D) | #if (ENABLE_GE || ENABLE_D) | ||||
| #include "pipeline/pipeline_ge.h" | #include "pipeline/pipeline_ge.h" | ||||
| @@ -964,6 +965,7 @@ void ClearResAtexit() { | |||||
| pipeline::ExecutorPy::ClearRes(); | pipeline::ExecutorPy::ClearRes(); | ||||
| pipeline::ReclaimOptimizer(); | pipeline::ReclaimOptimizer(); | ||||
| pynative::PynativeExecutor::GetInstance()->ClearRes(); | pynative::PynativeExecutor::GetInstance()->ClearRes(); | ||||
| opt::python_pass::PyPassManager::GetInstance()->ClearRes(); | |||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| transform::DfGraphManager::GetInstance().ClearGraph(); | transform::DfGraphManager::GetInstance().ClearGraph(); | ||||
| transform::DfGraphConvertor::get_adpt_map().clear(); | transform::DfGraphConvertor::get_adpt_map().clear(); | ||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Python pass register""" | |||||
| from inspect import isfunction | |||||
| from mindspore._c_expression import PyPassManager_ | |||||
| from mindspore._c_expression import phase | |||||
| class PyPassManager(PyPassManager_): | |||||
| r""" | |||||
| Used to registe and unregiste python passes which can be used to alter graphs. | |||||
| Args: | |||||
| pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. | |||||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||||
| multigraph (bool): Whether or not the pattern exists across graphs. Default: True. | |||||
| Raises: | |||||
| TypeError: If argument has invalid type. | |||||
| """ | |||||
| def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||||
| if not isinstance(pipeline_phase, phase): | |||||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||||
| if not isinstance(run_only_once, bool): | |||||
| raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}") | |||||
| if not isinstance(multi_graph, bool): | |||||
| raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}") | |||||
| PyPassManager_.__init__(self) | |||||
| self.phase_ = pipeline_phase | |||||
| self.run_only_once_ = run_only_once | |||||
| self.multi_graph_ = multi_graph | |||||
| def registe(self, py_pass): | |||||
| if not isfunction(py_pass): | |||||
| raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") | |||||
| pattern, target = py_pass() | |||||
| pass_name = py_pass.__name__ | |||||
| if not isfunction(pattern): | |||||
| raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") | |||||
| if not isfunction(target): | |||||
| raise TypeError(f"Expecting function target, got : ({type(target)}){target}") | |||||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) | |||||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | |||||
| if not isinstance(pipeline_phase, phase): | |||||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||||
| if isinstance(py_pass, str): | |||||
| super().unregiste(py_pass, pipeline_phase) | |||||
| return | |||||
| if isfunction(py_pass): | |||||
| super().unregiste(py_pass.__name__, pipeline_phase) | |||||
| return | |||||
| raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}") | |||||
| def __call__(self, py_pass): | |||||
| self.registe(py_pass) | |||||
| return py_pass | |||||
| def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||||
| """ | |||||
| Examples: | |||||
| >>> @registe_pass() | |||||
| >>> def toy_pass(): | |||||
| >>> def pattern(): | |||||
| >>> pass | |||||
| >>> def target(): | |||||
| >>> pass | |||||
| """ | |||||
| return PyPassManager(pipeline_phase, run_only_once, multi_graph) | |||||
| @@ -170,7 +170,8 @@ class Dense(Cell): | |||||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | ||||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | ||||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | ||||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||||
| activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'. | |||||
| Default: None. | |||||
| Raises: | Raises: | ||||
| ValueError: If weight_init or bias_init shape is incorrect. | ValueError: If weight_init or bias_init shape is incorrect. | ||||