Browse Source

reduce pass traversals if no pre-ad python pass exists

tags/v1.0.0
BowenK 5 years ago
parent
commit
e482e4e8bf
5 changed files with 29 additions and 23 deletions
  1. +1
    -0
      mindspore/ccsrc/frontend/optimizer/pass_group.h
  2. +6
    -1
      mindspore/ccsrc/pipeline/jit/action.cc
  3. +4
    -4
      mindspore/graph_utils/python_pass/__init__.py
  4. +4
    -4
      mindspore/graph_utils/python_pass/python_pass_register.py
  5. +14
    -14
      tests/ut/python/optimizer/test_python_pass.py

+ 1
- 0
mindspore/ccsrc/frontend/optimizer/pass_group.h View File

@@ -49,6 +49,7 @@ class PassGroup {
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const; bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
std::string name() const { return name_; } std::string name() const { return name_; }
void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; } void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
size_t size() { return passes_.size(); }


private: private:
const std::string name_; const std::string name_;


+ 6
- 1
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -301,7 +301,12 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return true; return true;
} }


bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
bool OptInlineAction(const ResourcePtr &res) {
if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
return OptimizeAction(res, kInlinePasses);
}
return true;
}


bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }




+ 4
- 4
mindspore/graph_utils/python_pass/__init__.py View File

@@ -13,14 +13,14 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Reference for python pass registration.""" """Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
_set_reopt
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt


__all__ = [ __all__ = [
"registe_pass", "registe_pass",
"unregiste_pass", "unregiste_pass",
"gen_new_parameter", "gen_new_parameter",
"cancel_new_parameter", "cancel_new_parameter",
"_set_renorm",
"_set_reopt"
"set_renorm",
"set_reopt"
] ]

+ 4
- 4
mindspore/graph_utils/python_pass/python_pass_register.py View File

@@ -23,8 +23,8 @@ __all__ = [
"unregiste_pass", "unregiste_pass",
"gen_new_parameter", "gen_new_parameter",
"cancel_new_parameter", "cancel_new_parameter",
"_set_renorm",
"_set_reopt"
"set_renorm",
"set_reopt"
] ]
class PyPassManager(PyPassManager_): class PyPassManager(PyPassManager_):
r""" r"""
@@ -162,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm = PyPassManager() ppm = PyPassManager()
ppm.unregiste(pattern.para_name) ppm.unregiste(pattern.para_name)


def _set_renorm(should_renorm):
def set_renorm(should_renorm):
""" """
Set whether or not to do renormalization after modified graph in python pass(es). Set whether or not to do renormalization after modified graph in python pass(es).


@@ -176,7 +176,7 @@ def _set_renorm(should_renorm):
ppm = PyPassManager() ppm = PyPassManager()
ppm.set_renorm(should_renorm) ppm.set_renorm(should_renorm)


def _set_reopt(do_reopt):
def set_reopt(do_reopt):
""" """
Set whether or not to do optimization after modified graph in python pass(es). Set whether or not to do optimization after modified graph in python pass(es).




+ 14
- 14
tests/ut/python/optimizer/test_python_pass.py View File

@@ -19,8 +19,8 @@ import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
cancel_new_parameter, _set_reopt
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt
from mindspore.common.api import _generate_pip_args from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_ from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics. Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match Case: IsNot pass failed to match
""" """
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
class ConvBN(nn.Cell): class ConvBN(nn.Cell):
def __init__(self): def __init__(self):
super(ConvBN, self).__init__() super(ConvBN, self).__init__()
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass(bn_pass) unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr assert "Softmax" in transformed_repr
_set_renorm(True)
set_renorm(True)


def test_isnot_pattern_1(): def test_isnot_pattern_1():
""" """
@@ -234,8 +234,8 @@ def test_newtensor_pattern():
""" """
Test NewTensor pattern in the target Test NewTensor pattern in the target
""" """
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()


@@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass(softmax_addn_pass) unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr assert "Softmax" not in transformed_repr
_set_renorm(True)
set_renorm(True)


def test_newparameter_pattern(): def test_newparameter_pattern():
""" """
@@ -261,8 +261,8 @@ def test_newparameter_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()


_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True) @registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass(): def softmax_addn_pass():
x = Any() x = Any()
@@ -288,8 +288,8 @@ def test_imm_target():
inputs = Tensor(np.ones([42]), mindspore.float16) inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax() softmax_model = nn.Softmax()


_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True) @registe_pass(requires_grad=False, run_only_once=True)
def softmax_pass(): def softmax_pass():
x = Any() x = Any()
@@ -313,8 +313,8 @@ def test_gen_new_parameter():


default_tensor = Tensor(np.ones((4, 4)), mindspore.float32) default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor) new_para = NewParameter("Merlin", default_tensor)
_set_renorm(False)
_set_reopt(False)
set_renorm(False)
set_reopt(False)
gen_new_parameter(new_para) gen_new_parameter(new_para)
@registe_pass(requires_grad=False, run_only_once=True) @registe_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass(): def softmax_make_tuple_pass():


Loading…
Cancel
Save