Browse Source

Add Python Pass UT

tags/v0.6.0-beta
BowenK 5 years ago
parent
commit
f267a105b8
7 changed files with 130 additions and 9 deletions
  1. +28
    -0
      mindspore/ccsrc/ir/anf_py.cc
  2. +35
    -0
      mindspore/ccsrc/ir/func_graph_py.cc
  3. +1
    -0
      mindspore/ccsrc/optimizer/py_pass.cc
  4. +1
    -1
      mindspore/ccsrc/pipeline/action.cc
  5. +0
    -7
      mindspore/ccsrc/pipeline/init.cc
  6. +1
    -1
      mindspore/ops/primitive.py
  7. +64
    -0
      tests/ut/python/optimizer/test_python_pass.py

+ 28
- 0
mindspore/ccsrc/ir/anf_py.cc View File

@@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "ir/anf.h"

#include "pybind_api/api_register.h"

namespace mindspore {
// Define python 'RefKey' class.
REGISTER_PYBIND_DEFINE(CNode, ([](const pybind11::module *m) {
(void)py::class_<CNode, CNodePtr>(*m, "CNode")
.def("expanded_str", (std::string(CNode::*)(int) const) & CNode::DebugString,
"Get CNode string representation with specified expansion level.");
}));
} // namespace mindspore

+ 35
- 0
mindspore/ccsrc/ir/func_graph_py.cc View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "ir/meta_func_graph.h"
#include "ir/func_graph.h"

#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"

namespace mindspore {
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
// Define python "MetaFuncGraph_" class
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
.def_readonly(PYTHON_METAFUNCGRAPH_FLAG, &MetaFuncGraph::parse_info_)
.def(py::init<std::string &>());
// Define python "FuncGraph" class
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
.def(py::init())
.def("str", &FuncGraph::ToString, "Get FuncGraph string representation.")
.def("get_return", &FuncGraph::get_return, "Get return node of FuncGraph");
}));
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/optimizer/py_pass.cc View File

@@ -54,6 +54,7 @@ void ResolveFuncGraph_(const FuncGraphPtr &fg) {
auto manager = Manage(fg, false); auto manager = Manage(fg, false);
parse::python_adapter::set_use_signature_in_resolve(false); parse::python_adapter::set_use_signature_in_resolve(false);
parse::ResolveAll(manager); parse::ResolveAll(manager);
parse::python_adapter::set_use_signature_in_resolve(true);
} }


bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) { bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {


+ 1
- 1
mindspore/ccsrc/pipeline/action.cc View File

@@ -437,7 +437,7 @@ bool ResolveActionPyStub(const ResourcePtr &res) {
} }


bool OptActionPyStub(const ResourcePtr &res) { bool OptActionPyStub(const ResourcePtr &res) {
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
ActionPyStub(res, opt::python_pass::Phase::OPT);
return true; return true;
} }




+ 0
- 7
mindspore/ccsrc/pipeline/init.cc View File

@@ -38,7 +38,6 @@
#endif #endif
namespace py = pybind11; namespace py = pybind11;


using FuncGraph = mindspore::FuncGraph;
using EnvInstance = mindspore::EnvInstance; using EnvInstance = mindspore::EnvInstance;
using ExecutorPy = mindspore::pipeline::ExecutorPy; using ExecutorPy = mindspore::pipeline::ExecutorPy;
using Pipeline = mindspore::pipeline::Pipeline; using Pipeline = mindspore::pipeline::Pipeline;
@@ -54,10 +53,6 @@ using CostModelContext = mindspore::parallel::CostModelContext;
PYBIND11_MODULE(_c_expression, m) { PYBIND11_MODULE(_c_expression, m) {
m.doc() = "MindSpore c plugin"; m.doc() = "MindSpore c plugin";


(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
.def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_)
.def(py::init<std::string &>());

auto fns = mindspore::PybindDefineRegister::AllFuncs(); auto fns = mindspore::PybindDefineRegister::AllFuncs();
for (auto &item : fns) { for (auto &item : fns) {
item.second(&m); item.second(&m);
@@ -85,8 +80,6 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("broadcast_params") = py::dict(), "Build data graph.") py::arg("broadcast_params") = py::dict(), "Build data graph.")
.def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.") .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
.def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph."); .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph.");
// Class Graph interface
(void)py::class_<FuncGraph, mindspore::FuncGraphPtr>(m, "FuncGraph").def(py::init());


(void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_") (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_")
.def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_) .def_readonly(mindspore::PYTHON_ENVINSTANCE_FLAG, &mindspore::EnvInstance::parse_info_)


+ 1
- 1
mindspore/ops/primitive.py View File

@@ -146,7 +146,7 @@ class Primitive(Primitive_):
Check whether or not certain inputs should go into backend. Subclass in need should override this method. Check whether or not certain inputs should go into backend. Subclass in need should override this method.


Args: Args:
Same as arguments of current Primitive
*args(Primitive args): Same as arguments of current Primitive.


Returns: Returns:
A tuple of two elements, first element indicates whether or not we should filter out current arguments; A tuple of two elements, first element indicates whether or not we should filter out current arguments;


+ 64
- 0
tests/ut/python/optimizer/test_python_pass.py View File

@@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.common.python_pass_register import registe_pass, PyPassManager
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_

context.set_context(mode=context.GRAPH_MODE)

def get_func_graph(obj, *args, phase="predict"):
args_names, args_list = _generate_pip_args(obj, *args)
dic = dict(zip(args_names, args_list))
key = generate_key(phase, dic)
phase_prefix = str(key[1])
if phase == 'export':
phase = phase + '.' + phase_prefix + '.' + str(obj.create_time)
else:
phase = phase_prefix + phase + '.' + str(obj.create_time)
_executor = Executor_.get_instance()
_executor.compile(obj, args_list, phase, False)
return _executor.get_func_graph(phase)

def test_softmax_relu():
"""
Use python pass to transform from Softmax to ReLU.
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()

@registe_pass(run_only_once=True)
def softmax_relu_pass():
softmax = P.Softmax()
relu = P.ReLU()
def pattern(x):
x = softmax(x)
return x
def target(x):
x = relu(x)
return x
return pattern, target

transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr

Loading…
Cancel
Save