| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include "ir/anf.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| namespace mindspore { | |||||
| // Define python 'RefKey' class. | |||||
| REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) { | |||||
| (void)py::class_<CNode, CNodePtr>(*m, "CNode") | |||||
| .def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString, | |||||
| "Get CNode string representation with specified expansion level."); | |||||
| })); | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include "ir/meta_func_graph.h" | |||||
| #include "ir/func_graph.h" | |||||
| #include "pybind_api/api_register.h" | |||||
| #include "pybind_api/export_flags.h" | |||||
| namespace mindspore { | |||||
| REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | |||||
| // Define python "MetaFuncGraph_" class | |||||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | |||||
| .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) | |||||
| .def(py::init<std::string &>()); | |||||
| // Define python "FuncGraph" class | |||||
| (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | |||||
| .def(py::init()) | |||||
| .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") | |||||
| .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); | |||||
| })); | |||||
| } // namespace mindspore | |||||
| @@ -54,6 +54,7 @@ void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||||
| auto manager = Manage(fg, false); | auto manager = Manage(fg, false); | ||||
| parse::python_adapter::set_use_signature_in_resolve(false); | parse::python_adapter::set_use_signature_in_resolve(false); | ||||
| parse::ResolveAll(manager); | parse::ResolveAll(manager); | ||||
| parse::python_adapter::set_use_signature_in_resolve(true); | |||||
| } | } | ||||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | ||||
| @@ -437,7 +437,7 @@ bool ResolveActionPyStub(const ResourcePtr &res) { | |||||
| } | } | ||||
| bool OptActionPyStub(const ResourcePtr &res) { | bool OptActionPyStub(const ResourcePtr &res) { | ||||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||||
| ActionPyStub(res, opt::python_pass::Phase::OPT); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -38,7 +38,6 @@ | |||||
| #endif | #endif | ||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| using FuncGraph = mindspore::FuncGraph; | |||||
| using EnvInstance = mindspore::EnvInstance; | using EnvInstance = mindspore::EnvInstance; | ||||
| using ExecutorPy = mindspore::pipeline::ExecutorPy; | using ExecutorPy = mindspore::pipeline::ExecutorPy; | ||||
| using Pipeline = mindspore::pipeline::Pipeline; | using Pipeline = mindspore::pipeline::Pipeline; | ||||
| @@ -54,10 +53,6 @@ using CostModelContext = mindspore::parallel::CostModelContext; | |||||
| PYBIND11_MODULE(_c_expression, m) { | PYBIND11_MODULE(_c_expression, m) { | ||||
| m.doc() = "MindSpore c plugin"; | m.doc() = "MindSpore c plugin"; | ||||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | |||||
| .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) | |||||
| .def(py::init<std::string &>()); | |||||
| auto fns = mindspore::PybindDefineRegister::AllFuncs(); | auto fns = mindspore::PybindDefineRegister::AllFuncs(); | ||||
| for (auto &item : fns) { | for (auto &item : fns) { | ||||
| item.second(&m); | item.second(&m); | ||||
| @@ -85,8 +80,6 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| py::arg("broadcast_params") = py::dict(), "Build data graph.") | py::arg("broadcast_params") = py::dict(), "Build data graph.") | ||||
| .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") | .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") | ||||
| .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); | .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); | ||||
| // Class Graph interface | |||||
| (void)py::class_<FuncGraph, mindspore::FuncGraphPtr>(m, "FuncGraph").def(py::init()); | |||||
| (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_") | (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_") | ||||
| .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) | .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) | ||||
| @@ -146,7 +146,7 @@ class Primitive(Primitive_): | |||||
| Check whether or not certain inputs should go into backend. Subclass in need should override this method. | Check whether or not certain inputs should go into backend. Subclass in need should override this method. | ||||
| Args: | Args: | ||||
| Same as arguments of current Primitive | |||||
| *args(Primitive args): Same as arguments of current Primitive. | |||||
| Returns: | Returns: | ||||
| A tuple of two elements, first element indicates whether or not we should filter out current arguments; | A tuple of two elements, first element indicates whether or not we should filter out current arguments; | ||||
| @@ -0,0 +1,64 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import mindspore | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common.python_pass_register import registe_pass, PyPassManager | |||||
| from mindspore.common.api import _generate_pip_args | |||||
| from mindspore._c_expression import generate_key, Executor_ | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def get_func_graph(obj, *args, phase="predict"): | |||||
| args_names, args_list = _generate_pip_args(obj, *args) | |||||
| dic = dict(zip(args_names, args_list)) | |||||
| key = generate_key(phase, dic) | |||||
| phase_prefix = str(key[1]) | |||||
| if phase == 'export': | |||||
| phase = phase + '.' + phase_prefix + '.' + str(obj.create_time) | |||||
| else: | |||||
| phase = phase_prefix + phase + '.' + str(obj.create_time) | |||||
| _executor = Executor_.get_instance() | |||||
| _executor.compile(obj, args_list, phase, False) | |||||
| return _executor.get_func_graph(phase) | |||||
| def test_softmax_relu(): | |||||
| """ | |||||
| Use python pass to transform from Softmax to ReLU. | |||||
| """ | |||||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||||
| softmax_model = nn.Softmax() | |||||
| @registe_pass(run_only_once=True) | |||||
| def softmax_relu_pass(): | |||||
| softmax = P.Softmax() | |||||
| relu = P.ReLU() | |||||
| def pattern(x): | |||||
| x = softmax(x) | |||||
| return x | |||||
| def target(x): | |||||
| x = relu(x) | |||||
| return x | |||||
| return pattern, target | |||||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||||
| ppm = PyPassManager() | |||||
| ppm.unregiste(softmax_relu_pass) | |||||
| assert "ReLU" in transformed_repr | |||||
| assert "Softmax" not in transformed_repr | |||||