| @@ -346,10 +346,6 @@ class TensorAddByZero : public AnfVisitor { | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (IsPrimitive(node, prim::kPrimZerosLike)) { | |||
| is_zero_ = true; | |||
| return; | |||
| } | |||
| if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) { | |||
| is_zero_ = true; | |||
| return; | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/pass_group.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| void PassGroup::AddPass(const PythonPassPtr &pass) { | |||
| if (pass != nullptr) { | |||
| passes_.push_back(pass); | |||
| } | |||
| } | |||
| bool PassGroup::DeletePass(const std::string &pass_name) { | |||
| for (auto iter = passes_.begin(); iter != passes_.end(); iter++) { | |||
| if ((*iter)->name() == pass_name) { | |||
| *iter = nullptr; | |||
| passes_.erase(iter); | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const { | |||
| if (func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| bool changed = false; | |||
| for (const auto &pass : passes) { | |||
| if (pass != nullptr) { | |||
| if (pass->Run(func_graph)) { | |||
| changed = true; | |||
| } | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| bool PassGroup::Run(const FuncGraphPtr &func_graph) const { | |||
| bool changed = false; | |||
| // run all passes | |||
| bool change = true; | |||
| while (change) { | |||
| change = Run(func_graph, passes_); | |||
| changed = change || changed; | |||
| if (run_only_once_) { | |||
| break; | |||
| } | |||
| } | |||
| return changed; | |||
| } | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,61 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "optimizer/py_pass.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| class PassGroup { | |||
| public: | |||
| explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false) | |||
| : name_(name), passes_{}, run_only_once_(run_only_once) {} | |||
| virtual ~PassGroup() = default; | |||
| // Add graph pass, the pass object will be freed when pass manager freed. | |||
| void AddPass(const PythonPassPtr &pass); | |||
| // Delete graph pass before the pass manager is freed. | |||
| bool DeletePass(const std::string &pass_name); | |||
| // Run passes added in pass manager on the input graph | |||
| // @param [inout] graph The graph to be optimized | |||
| // @return true, graph changed | |||
| // @return false, graph not changed | |||
| bool Run(const FuncGraphPtr &func_graph) const; | |||
| // Run the given graph passes on the input graph | |||
| // @param [inout] graph The graph to be optimized | |||
| // @param [in] passes The given graph passes | |||
| // @return true, graph changed | |||
| // @return false, graph not changed | |||
| bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const; | |||
| std::string name() const { return name_; } | |||
| private: | |||
| const std::string name_; | |||
| std::vector<PythonPassPtr> passes_; | |||
| bool run_only_once_; | |||
| }; | |||
| using PassGroupPtr = std::shared_ptr<PassGroup>; | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_ | |||
| @@ -0,0 +1,236 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/py_pass.h" | |||
| #include <unordered_set> | |||
| #include <deque> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/manager.h" | |||
| #include "pipeline/parse/parse_base.h" | |||
| #include "pipeline/resource.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| namespace internal { | |||
| std::string GetNodeRepr(AnfNodePtr node) { | |||
| if (node != nullptr) { | |||
| if (node->isa<CNode>()) { | |||
| std::string repr = "("; | |||
| auto const &inputs = node->cast<CNodePtr>()->inputs(); | |||
| for (auto &input : inputs) { | |||
| repr += " "; | |||
| repr += GetNodeRepr(input); | |||
| repr += " "; | |||
| } | |||
| repr += ")"; | |||
| return repr; | |||
| } | |||
| if (node->isa<ValueNode>()) { | |||
| return GetValueNode(node)->ToString(); | |||
| } | |||
| return node->ToString(); | |||
| } | |||
| return ""; | |||
| } | |||
| void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||
| auto manager = Manage(fg, false); | |||
| parse::python_adapter::set_use_signature_in_resolve(false); | |||
| parse::ResolveAll(manager); | |||
| } | |||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pattern); | |||
| if (pattern->isa<ValueNode>()) { | |||
| if (!node->isa<ValueNode>()) { | |||
| return false; | |||
| } | |||
| if (GetNodeRepr(pattern) == GetNodeRepr(node)) { | |||
| // add to equiv_ptr | |||
| equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node)); | |||
| return true; | |||
| } | |||
| return false; | |||
| } else if (pattern->isa<Parameter>()) { | |||
| MS_LOG(DEBUG) << pattern->ToString() + "\n"; | |||
| // add to equiv_ptr | |||
| equiv_ptr->insert(std::make_pair(pattern->ToString(), node)); | |||
| return true; | |||
| } else if (pattern->isa<CNode>()) { | |||
| // match every single sub ANode | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto pattern_inputs = pattern->cast<CNodePtr>()->inputs(); | |||
| auto node_inputs = node->cast<CNodePtr>()->inputs(); | |||
| if (pattern_inputs.size() != node_inputs.size()) { | |||
| return false; | |||
| } | |||
| for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end(); | |||
| p_item++, node_item++) { | |||
| auto res = Match(*p_item, *node_item, equiv_ptr); | |||
| if (!res) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n"; | |||
| } | |||
| AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_, | |||
| const NodeEquivPtr &equiv_ptr) { | |||
| if (cur_raw_dst_node_->isa<Parameter>()) { | |||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString()); | |||
| if (sub_pair != equiv_ptr->end()) { | |||
| return sub_pair->second; | |||
| } | |||
| MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n"; | |||
| } else if (cur_raw_dst_node_->isa<ValueNode>()) { | |||
| // check primitive ValueNode | |||
| auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString()); | |||
| if (sub_pair != equiv_ptr->end()) { | |||
| return sub_pair->second; | |||
| } | |||
| return cur_raw_dst_node_; | |||
| } else if (cur_raw_dst_node_->isa<CNode>()) { | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs(); | |||
| for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) { | |||
| auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr); | |||
| new_inputs.push_back(subed); | |||
| } | |||
| return func_graph->NewCNode(new_inputs); | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_); | |||
| } | |||
| bool isTraversable(const AnfNodePtr &node) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| if (node->isa<CNode>() || node->isa<Parameter>()) { | |||
| return true; | |||
| } | |||
| if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace internal | |||
| void PythonPass::Build(const py::function &src, const py::function &dst) { | |||
| // 1. get FuncGraph from py::function | |||
| auto src_fg_ = parse::ParsePythonCode(src); | |||
| auto dst_fg_ = parse::ParsePythonCode(dst); | |||
| if (src_fg_ == nullptr || dst_fg_ == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Failed to parse python code.\n"; | |||
| } | |||
| // 2. Resolve | |||
| internal::ResolveFuncGraph_(src_fg_); | |||
| internal::ResolveFuncGraph_(dst_fg_); | |||
| // 3. from FuncGraphPtr to ValueNode | |||
| src_node_ = src_fg_->output(); | |||
| dst_node_ = dst_fg_->output(); | |||
| } | |||
| PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once, | |||
| bool multigraph) | |||
| : name_(name), run_only_once_(run_only_once), multigraph_(multigraph) { | |||
| Build(src, dst); | |||
| } | |||
| AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| auto equiv_ptr = std::make_shared<NodeEquiv>(); | |||
| bool is_a_match = internal::Match(src_node_, node, equiv_ptr); | |||
| if (is_a_match) { | |||
| auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr); | |||
| MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n"; | |||
| return new_node; | |||
| } | |||
| return nullptr; | |||
| } | |||
| bool PythonPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| auto seen = NewSeenGeneration(); | |||
| // 1024 is for the initial capacity of deque | |||
| std::deque<AnfNodePtr> todo(1024); | |||
| todo.push_back(func_graph->output()); | |||
| bool changes = false; | |||
| auto &all_nodes = manager->all_nodes(); | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.front(); | |||
| todo.pop_front(); | |||
| // check whether this node has been matched. | |||
| if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) { | |||
| continue; | |||
| } | |||
| node->seen_ = seen; | |||
| // select nodes that this transform can be applied. | |||
| AnfNodePtr new_node = Run(func_graph, node); | |||
| bool change = (new_node != nullptr); | |||
| if (new_node != nullptr && new_node != node) { | |||
| (void)manager->Replace(node, new_node); | |||
| } else if (new_node == nullptr) { | |||
| new_node = node; | |||
| } | |||
| if (run_only_once_) { | |||
| return change; | |||
| } | |||
| // find success, and add them to todo list | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||
| } | |||
| auto &node_users = manager->node_users(); | |||
| if (change && node_users.find(node) != node_users.end()) { | |||
| for (auto &use : node_users[node]) { | |||
| auto use_node = use.first; | |||
| if (use_node == nullptr) { | |||
| continue; | |||
| } | |||
| todo.push_back(use_node); | |||
| if (use_node->seen_ == seen) { | |||
| use_node->seen_--; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return changes; | |||
| } | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include "ir/anf.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| class PythonPass; | |||
| using PythonPassPtr = std::shared_ptr<PythonPass>; | |||
| using NodeEquiv = std::unordered_map<std::string, AnfNodePtr>; | |||
| using NodeEquivPtr = std::shared_ptr<NodeEquiv>; | |||
| class PythonPass { | |||
| public: | |||
| explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst, | |||
| bool run_only_once = false, bool multigraph = true); | |||
| ~PythonPass() = default; | |||
| bool Run(const FuncGraphPtr &func_graph); | |||
| std::string name() const { return name_; } | |||
| AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | |||
| private: | |||
| void Build(const py::function &src, const py::function &dst); | |||
| AnfNodePtr src_node_ = nullptr; | |||
| AnfNodePtr dst_node_ = nullptr; | |||
| const std::string name_; | |||
| bool run_only_once_; | |||
| bool multigraph_ = true; | |||
| }; | |||
| using PythonPassPtr = std::shared_ptr<PythonPass>; | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_ | |||
| @@ -0,0 +1,84 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "optimizer/py_pass_manager.h" | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include <initializer_list> | |||
| #include "ir/manager.h" | |||
| #include "optimizer/pass_group.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| PyPassManagerPtr PyPassManager::global_instance = nullptr; | |||
| std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_; | |||
| PassGroupPtr PyPassManager::GetPassGroup(Phase phase) { | |||
| auto pm = phase_to_group_.find(phase); | |||
| if (pm == phase_to_group_.end()) { | |||
| return nullptr; | |||
| } | |||
| return pm->second; | |||
| } | |||
| PyPassManagerPtr PyPassManager::GetInstance() { | |||
| if (global_instance == nullptr) { | |||
| global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager()); | |||
| } | |||
| return global_instance; | |||
| } | |||
| PyPassManager::PyPassManager() { | |||
| phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>(); | |||
| phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>(); | |||
| } | |||
| void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||
| Phase phase, bool run_only_once, bool multigraph) { | |||
| auto cur_pm = GetPassGroup(phase); | |||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||
| PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph); | |||
| cur_pm->AddPass(new_pass); | |||
| } | |||
| void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) { | |||
| auto cur_pm = GetPassGroup(phase); | |||
| MS_EXCEPTION_IF_NULL(cur_pm); | |||
| if (!cur_pm->DeletePass(pass_name)) { | |||
| MS_LOG(WARNING) << "No such pass : " + pass_name + "\n"; | |||
| } | |||
| } | |||
| void PyPassManager::ClearRes() { | |||
| MS_LOG(INFO) << "Clear PyPassManager resources!"; | |||
| global_instance = nullptr; | |||
| phase_to_group_.clear(); | |||
| } | |||
| REGISTER_PYBIND_DEFINE( | |||
| PyPassManager_, ([](const py::module *m) { | |||
| (void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT); | |||
| (void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_") | |||
| .def(py::init([]() { return PyPassManager::GetInstance(); })) | |||
| .def("registe", &PyPassManager::Registe, "Registe python pass") | |||
| .def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass"); | |||
| })); | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||
| #define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| #include "ir/primitive.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "common/utils.h" | |||
| #include "pipeline/parse/resolve.h" | |||
| #include "optimizer/py_pass.h" | |||
| #include "optimizer/pass_group.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace python_pass { | |||
| class PyPassManager; | |||
| using PyPassManagerPtr = std::shared_ptr<PyPassManager>; | |||
| enum Phase { RESOLVE, OPT }; | |||
| class PyPassManager { | |||
| protected: | |||
| PyPassManager(); | |||
| static PyPassManagerPtr global_instance; | |||
| public: | |||
| // Singletons should not be cloneable and assignable | |||
| PyPassManager(const PyPassManager &other) = delete; | |||
| void operator=(const PyPassManager &) = delete; | |||
| // Access the only global instance | |||
| static PyPassManagerPtr GetInstance(); | |||
| virtual ~PyPassManager() = default; | |||
| void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target, | |||
| Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true); | |||
| void Unregiste(const std::string &pass_name, Phase phase); | |||
| PassGroupPtr GetPassGroup(Phase phase); | |||
| void ClearRes(); | |||
| private: | |||
| static std::unordered_map<Phase, PassGroupPtr> phase_to_group_; | |||
| }; | |||
| } // namespace python_pass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_ | |||
| @@ -39,6 +39,7 @@ | |||
| #include "optimizer/optimizer.h" | |||
| #include "vm/transform.h" | |||
| #include "parse/python_adapter.h" | |||
| #include "optimizer/py_pass_manager.h" | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| @@ -420,6 +421,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { | |||
| bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | |||
| void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) { | |||
| MS_EXCEPTION_IF_NULL(res->manager()); | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| auto ppm = opt::python_pass::PyPassManager::GetInstance(); | |||
| if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) { | |||
| MS_LOG(DEBUG) << "No match.\n"; | |||
| } | |||
| } | |||
| bool ResolveActionPyStub(const ResourcePtr &res) { | |||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||
| return true; | |||
| } | |||
| bool OptActionPyStub(const ResourcePtr &res) { | |||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||
| return true; | |||
| } | |||
| static std::vector<ActionItem> CommonPipeline() { | |||
| std::vector<ActionItem> actions; | |||
| @@ -432,6 +452,8 @@ static std::vector<ActionItem> CommonPipeline() { | |||
| if (!multi_graphs) { | |||
| actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); | |||
| } | |||
| // Add resolve-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub)); | |||
| actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); | |||
| // Evaluate type and shape, and specialize | |||
| actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); | |||
| @@ -443,6 +465,8 @@ std::vector<ActionItem> GePipeline() { | |||
| auto actions = CommonPipeline(); | |||
| // optimize | |||
| actions.emplace_back(std::make_pair("optimize", GeOptimizeAction)); | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||
| actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| return actions; | |||
| @@ -454,6 +478,9 @@ std::vector<ActionItem> VmPipeline() { | |||
| // optimize | |||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | |||
| // Add opt-stage python pass stub | |||
| actions.emplace_back(std::make_pair("py_opt", OptActionPyStub)); | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| // compile the ANF graph | |||
| @@ -39,6 +39,7 @@ | |||
| #include "device/kernel_runtime_manager.h" | |||
| #include "debug/trace.h" | |||
| #include "pynative/pynative_execute.h" | |||
| #include "optimizer/py_pass_manager.h" | |||
| #if (ENABLE_GE || ENABLE_D) | |||
| #include "pipeline/pipeline_ge.h" | |||
| @@ -964,6 +965,7 @@ void ClearResAtexit() { | |||
| pipeline::ExecutorPy::ClearRes(); | |||
| pipeline::ReclaimOptimizer(); | |||
| pynative::PynativeExecutor::GetInstance()->ClearRes(); | |||
| opt::python_pass::PyPassManager::GetInstance()->ClearRes(); | |||
| #ifdef ENABLE_GE | |||
| transform::DfGraphManager::GetInstance().ClearGraph(); | |||
| transform::DfGraphConvertor::get_adpt_map().clear(); | |||
| @@ -0,0 +1,80 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Python pass register""" | |||
| from inspect import isfunction | |||
| from mindspore._c_expression import PyPassManager_ | |||
| from mindspore._c_expression import phase | |||
| class PyPassManager(PyPassManager_): | |||
| r""" | |||
| Used to registe and unregiste python passes which can be used to alter graphs. | |||
| Args: | |||
| pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt. | |||
| run_only_once (bool): Specify whether or not to run pass only once. Default: False. | |||
| multigraph (bool): Whether or not the pattern exists across graphs. Default: True. | |||
| Raises: | |||
| TypeError: If argument has invalid type. | |||
| """ | |||
| def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if not isinstance(run_only_once, bool): | |||
| raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}") | |||
| if not isinstance(multi_graph, bool): | |||
| raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}") | |||
| PyPassManager_.__init__(self) | |||
| self.phase_ = pipeline_phase | |||
| self.run_only_once_ = run_only_once | |||
| self.multi_graph_ = multi_graph | |||
| def registe(self, py_pass): | |||
| if not isfunction(py_pass): | |||
| raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}") | |||
| pattern, target = py_pass() | |||
| pass_name = py_pass.__name__ | |||
| if not isfunction(pattern): | |||
| raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}") | |||
| if not isfunction(target): | |||
| raise TypeError(f"Expecting function target, got : ({type(target)}){target}") | |||
| super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_) | |||
| def unregiste(self, py_pass, pipeline_phase=phase.opt): | |||
| if not isinstance(pipeline_phase, phase): | |||
| raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}") | |||
| if isinstance(py_pass, str): | |||
| super().unregiste(py_pass, pipeline_phase) | |||
| return | |||
| if isfunction(py_pass): | |||
| super().unregiste(py_pass.__name__, pipeline_phase) | |||
| return | |||
| raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}") | |||
| def __call__(self, py_pass): | |||
| self.registe(py_pass) | |||
| return py_pass | |||
| def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True): | |||
| """ | |||
| Examples: | |||
| >>> @registe_pass() | |||
| >>> def toy_pass(): | |||
| >>> def pattern(): | |||
| >>> pass | |||
| >>> def target(): | |||
| >>> pass | |||
| """ | |||
| return PyPassManager(pipeline_phase, run_only_once, multi_graph) | |||
| @@ -170,7 +170,8 @@ class Dense(Cell): | |||
| bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is | |||
| same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. | |||
| has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. | |||
| activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. | |||
| activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'. | |||
| Default: None. | |||
| Raises: | |||
| ValueError: If weight_init or bias_init shape is incorrect. | |||