
reduce pass traversals if no pre-AD python pass exists

tags/v1.0.0
BowenK 5 years ago
parent commit e482e4e8bf
5 changed files with 29 additions and 23 deletions
  1. mindspore/ccsrc/frontend/optimizer/pass_group.h  (+1 -0)
  2. mindspore/ccsrc/pipeline/jit/action.cc  (+6 -1)
  3. mindspore/graph_utils/python_pass/__init__.py  (+4 -4)
  4. mindspore/graph_utils/python_pass/python_pass_register.py  (+4 -4)
  5. tests/ut/python/optimizer/test_python_pass.py  (+14 -14)

mindspore/ccsrc/frontend/optimizer/pass_group.h  (+1 -0)

@@ -49,6 +49,7 @@ class PassGroup {
   bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
   std::string name() const { return name_; }
   void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
+  size_t size() { return passes_.size(); }
 
  private:
   const std::string name_;
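
The new size() accessor is the hook for the early exit in action.cc below: callers can now ask a PassGroup how many python passes it holds before paying for a graph traversal. A minimal Python analogue of the idea (hypothetical class, for illustration only, not MindSpore API):

    class PassGroup:
        """Hypothetical Python analogue of the C++ PassGroup above."""
        def __init__(self, name):
            self._name = name
            self._passes = []  # registered python passes

        def size(self):
            # mirrors the new C++ accessor: passes_.size()
            return len(self._passes)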


mindspore/ccsrc/pipeline/jit/action.cc  (+6 -1)

@@ -301,7 +301,12 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
   return true;
 }
 
-bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
+bool OptInlineAction(const ResourcePtr &res) {
+  if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
+    return OptimizeAction(res, kInlinePasses);
+  }
+  return true;
+}
 
 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
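
In words: OptInlineAction now runs the inline pass pipeline only when at least one python pass is registered for the pre-AD phase; with nothing registered it returns immediately, and a whole graph traversal is skipped. A hedged Python sketch of the same control flow (names are illustrative, not MindSpore API):

    def opt_inline_action(pre_ad_passes, run_inline_passes):
        """Run the inline traversal only when pre-AD python passes exist."""
        if len(pre_ad_passes) != 0:
            return run_inline_passes()  # pay the traversal cost only when needed
        return True                     # nothing registered: skip the traversal

    # With no pre-AD python passes the action is a no-op:
    assert opt_inline_action([], lambda: True) is True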



mindspore/graph_utils/python_pass/__init__.py  (+4 -4)

@@ -13,14 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """Reference for python pass registration."""
-from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
-    _set_reopt
+from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
+    set_reopt
 
 __all__ = [
     "registe_pass",
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "_set_renorm",
-    "_set_reopt"
+    "set_renorm",
+    "set_reopt"
 ]
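
With the leading underscores dropped, set_renorm and set_reopt become part of the package's public surface. Assuming a MindSpore build that contains this commit, usage is simply:

    from mindspore.graph_utils.python_pass import set_renorm, set_reopt

    set_renorm(False)  # skip renormalization after python-pass graph edits
    set_reopt(False)   # skip re-optimization after python-pass graph edits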

mindspore/graph_utils/python_pass/python_pass_register.py  (+4 -4)

@@ -23,8 +23,8 @@ __all__ = [
     "unregiste_pass",
     "gen_new_parameter",
     "cancel_new_parameter",
-    "_set_renorm",
-    "_set_reopt"
+    "set_renorm",
+    "set_reopt"
 ]
 class PyPassManager(PyPassManager_):
     r"""
@@ -162,7 +162,7 @@ def cancel_new_parameter(pattern):
     ppm = PyPassManager()
     ppm.unregiste(pattern.para_name)
 
-def _set_renorm(should_renorm):
+def set_renorm(should_renorm):
     """
     Set whether or not to do renormalization after modified graph in python pass(es).
 
@@ -176,7 +176,7 @@ def _set_renorm(should_renorm):
     ppm = PyPassManager()
     ppm.set_renorm(should_renorm)
 
-def _set_reopt(do_reopt):
+def set_reopt(do_reopt):
     """
     Set whether or not to do optimization after modified graph in python pass(es).
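
The two toggles are meant to bracket pass registration, as the updated tests below do: switch them off, register and run the pass, then restore the defaults. A sketch of that pattern under the same build assumption (the pattern/target construction mirrors the tests in this commit and should be read as illustrative, not canonical):

    from mindspore.ops import operations as P
    from mindspore.graph_utils.graph_pattern import Any, Call
    from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, set_reopt

    set_renorm(False)  # don't renormalize after the pass rewrites the graph
    set_reopt(False)   # don't re-optimize afterwards either

    @registe_pass(requires_grad=False, run_only_once=True)
    def softmax_addn_pass():
        x = Any()
        return Call(P.Softmax(), [x]), Call(P.AddN(), [x])  # (pattern, target)

    # ... compile a network here so the pass actually fires ...
    unregiste_pass(softmax_addn_pass)
    set_renorm(True)  # restore the default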



tests/ut/python/optimizer/test_python_pass.py  (+14 -14)

@@ -19,8 +19,8 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
-from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
-    cancel_new_parameter, _set_reopt
+from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
+    cancel_new_parameter, set_reopt
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
 from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
     Test IsNot pattern which expresses the IsNot semantics.
     Case: IsNot pass failed to match
     """
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     class ConvBN(nn.Cell):
         def __init__(self):
             super(ConvBN, self).__init__()
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
     unregiste_pass(bn_pass)
     assert "ReLU6" not in transformed_repr
     assert "Softmax" in transformed_repr
-    _set_renorm(True)
+    set_renorm(True)
 
 def test_isnot_pattern_1():
     """
@@ -234,8 +234,8 @@ def test_newtensor_pattern():
     """
     Test NewTensor pattern in the target
     """
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
     unregiste_pass(softmax_addn_pass)
     assert "AddN" in transformed_repr
     assert "Softmax" not in transformed_repr
-    _set_renorm(True)
+    set_renorm(True)
 
 def test_newparameter_pattern():
     """
@@ -261,8 +261,8 @@ def test_newparameter_pattern():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_addn_pass():
         x = Any()
@@ -288,8 +288,8 @@ def test_imm_target():
     inputs = Tensor(np.ones([42]), mindspore.float16)
     softmax_model = nn.Softmax()
 
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_pass():
         x = Any()
@@ -313,8 +313,8 @@ def test_gen_new_parameter():
 
     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
     new_para = NewParameter("Merlin", default_tensor)
-    _set_renorm(False)
-    _set_reopt(False)
+    set_renorm(False)
+    set_reopt(False)
     gen_new_parameter(new_para)
     @registe_pass(requires_grad=False, run_only_once=True)
     def softmax_make_tuple_pass():
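
For the NewParameter path, the last hunk boils down to: build a NewParameter, switch off renorm/re-opt, inject it with gen_new_parameter, register the pass, and clean up afterwards. A condensed sketch under the same assumptions (the cleanup call uses cancel_new_parameter as defined in this commit):

    import numpy as np
    import mindspore
    from mindspore.common.tensor import Tensor
    from mindspore.graph_utils.graph_pattern import NewParameter
    from mindspore.graph_utils.python_pass import (gen_new_parameter,
                                                   cancel_new_parameter,
                                                   set_renorm, set_reopt)

    default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
    new_para = NewParameter("Merlin", default_tensor)

    set_renorm(False)
    set_reopt(False)
    gen_new_parameter(new_para)     # make "Merlin" available to passes
    # ... register a pass that uses new_para, compile, inspect the IR ...
    cancel_new_parameter(new_para)  # remove the injected parameter
    set_renorm(True)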

