| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "ir/anf.h" | |||
| #include "pybind_api/api_register.h" | |||
| namespace mindspore { | |||
| // Define python 'RefKey' class. | |||
| REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) { | |||
| (void)py::class_<CNode, CNodePtr>(*m, "CNode") | |||
| .def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString, | |||
| "Get CNode string representation with specified expansion level."); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include "ir/meta_func_graph.h" | |||
| #include "ir/func_graph.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "pybind_api/export_flags.h" | |||
| namespace mindspore { | |||
| REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { | |||
| // Define python "MetaFuncGraph_" class | |||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | |||
| .def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_) | |||
| .def(py::init<std::string &>()); | |||
| // Define python "FuncGraph" class | |||
| (void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph") | |||
| .def(py::init()) | |||
| .def("str", &FuncGraph::ToString, "Get FuncGraph string representation.") | |||
| .def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph"); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -54,6 +54,7 @@ void ResolveFuncGraph_(const FuncGraphPtr &fg) { | |||
| auto manager = Manage(fg, false); | |||
| parse::python_adapter::set_use_signature_in_resolve(false); | |||
| parse::ResolveAll(manager); | |||
| parse::python_adapter::set_use_signature_in_resolve(true); | |||
| } | |||
| bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { | |||
| @@ -437,7 +437,7 @@ bool ResolveActionPyStub(const ResourcePtr &res) { | |||
| } | |||
| bool OptActionPyStub(const ResourcePtr &res) { | |||
| ActionPyStub(res, opt::python_pass::Phase::RESOLVE); | |||
| ActionPyStub(res, opt::python_pass::Phase::OPT); | |||
| return true; | |||
| } | |||
| @@ -38,7 +38,6 @@ | |||
| #endif | |||
| namespace py = pybind11; | |||
| using FuncGraph = mindspore::FuncGraph; | |||
| using EnvInstance = mindspore::EnvInstance; | |||
| using ExecutorPy = mindspore::pipeline::ExecutorPy; | |||
| using Pipeline = mindspore::pipeline::Pipeline; | |||
| @@ -54,10 +53,6 @@ using CostModelContext = mindspore::parallel::CostModelContext; | |||
| PYBIND11_MODULE(_c_expression, m) { | |||
| m.doc() = "MindSpore c plugin"; | |||
| (void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_") | |||
| .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) | |||
| .def(py::init<std::string &>()); | |||
| auto fns = mindspore::PybindDefineRegister::AllFuncs(); | |||
| for (auto &item : fns) { | |||
| item.second(&m); | |||
| @@ -85,8 +80,6 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| py::arg("broadcast_params") = py::dict(), "Build data graph.") | |||
| .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") | |||
| .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); | |||
| // Class Graph interface | |||
| (void)py::class_<FuncGraph, mindspore::FuncGraphPtr>(m, "FuncGraph").def(py::init()); | |||
| (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_") | |||
| .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) | |||
| @@ -146,7 +146,7 @@ class Primitive(Primitive_): | |||
| Check whether or not certain inputs should go into backend. Subclass in need should override this method. | |||
| Args: | |||
| Same as arguments of current Primitive | |||
| *args(Primitive args): Same as arguments of current Primitive. | |||
| Returns: | |||
| A tuple of two elements, first element indicates whether or not we should filter out current arguments; | |||
| @@ -0,0 +1,64 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.common.python_pass_register import registe_pass, PyPassManager | |||
| from mindspore.common.api import _generate_pip_args | |||
| from mindspore._c_expression import generate_key, Executor_ | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def get_func_graph(obj, *args, phase="predict"): | |||
| args_names, args_list = _generate_pip_args(obj, *args) | |||
| dic = dict(zip(args_names, args_list)) | |||
| key = generate_key(phase, dic) | |||
| phase_prefix = str(key[1]) | |||
| if phase == 'export': | |||
| phase = phase + '.' + phase_prefix + '.' + str(obj.create_time) | |||
| else: | |||
| phase = phase_prefix + phase + '.' + str(obj.create_time) | |||
| _executor = Executor_.get_instance() | |||
| _executor.compile(obj, args_list, phase, False) | |||
| return _executor.get_func_graph(phase) | |||
| def test_softmax_relu(): | |||
| """ | |||
| Use python pass to transform from Softmax to ReLU. | |||
| """ | |||
| inputs = Tensor(np.ones([42]), mindspore.float16) | |||
| softmax_model = nn.Softmax() | |||
| @registe_pass(run_only_once=True) | |||
| def softmax_relu_pass(): | |||
| softmax = P.Softmax() | |||
| relu = P.ReLU() | |||
| def pattern(x): | |||
| x = softmax(x) | |||
| return x | |||
| def target(x): | |||
| x = relu(x) | |||
| return x | |||
| return pattern, target | |||
| transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2) | |||
| ppm = PyPassManager() | |||
| ppm.unregiste(softmax_relu_pass) | |||
| assert "ReLU" in transformed_repr | |||
| assert "Softmax" not in transformed_repr | |||