Browse Source

!5692 Add requires_grad option for python pass

Merge pull request !5692 from BowenK/pre_ad
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
88f5cbe5db
11 changed files with 129 additions and 79 deletions
  1. +9
    -1
      mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
  2. +38
    -28
      mindspore/ccsrc/frontend/optimizer/py_pass.cc
  3. +2
    -1
      mindspore/ccsrc/frontend/optimizer/py_pass.h
  4. +17
    -11
      mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc
  5. +3
    -3
      mindspore/ccsrc/frontend/optimizer/py_pass_manager.h
  6. +12
    -3
      mindspore/ccsrc/pipeline/jit/action.cc
  7. +10
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc
  8. +1
    -0
      mindspore/ccsrc/pipeline/jit/pass.h
  9. +4
    -4
      mindspore/graph_utils/python_pass/__init__.py
  10. +13
    -8
      mindspore/graph_utils/python_pass/python_pass_register.py
  11. +20
    -20
      tests/ut/python/optimizer/test_python_pass.py

+ 9
- 1
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc View File

@@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
py::function fn;
if (prim->is_base()) {
fn = GetBpropFunction(prim->name());
} else {
fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
if (py::isinstance<py::none>(fn)) {
fn = GetBpropFunction(prim->name());
}
}
if (!fn || py::isinstance<py::none>(fn)) {
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
return nullptr;


+ 38
- 28
mindspore/ccsrc/frontend/optimizer/py_pass.cc View File

@@ -35,8 +35,10 @@ namespace internal {
const char PARAMETER_MODULE[] = "mindspore.common.parameter";
const char PARAMETER_CLASS[] = "Parameter";
const char SET_PARAM[] = "__setattr__";
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
const FuncGraphPtr &top_graph);
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
const MatchResultPtr &res);
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel);

@@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return std::make_shared<ValueNode>(input_tensor);
}

AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
const FuncGraphPtr &top_graph) {
auto call_pattern = pattern->cast<CallPtr>();
MS_EXCEPTION_IF_NULL(call_pattern);
auto prim = call_pattern->prim_value();
@@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto prim_pattern = call_pattern->prim_pattern();
MS_EXCEPTION_IF_NULL(prim_pattern);
return ProcessSinglePattern(prim_pattern, res, fg);
return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
}

AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
auto new_para_pattern = pattern->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
if (!new_para_pattern->built()) {
static int parameter_id = 0;
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
auto para_node = std::make_shared<Parameter>(func_graph);
auto para_node = std::make_shared<Parameter>(top_graph);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(para_name);
// Set function graph
para_node->set_func_graph(func_graph);
para_node->set_func_graph(top_graph);
// Set Debug Info
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
para_node->set_debug_info(debug_info);
@@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
MS_EXCEPTION_IF_NULL(default_value);
para_node->set_abstract(default_value->ToAbstract()->Broaden());
res->add_entry(pattern, para_node);
func_graph->add_parameter(para_node);
top_graph->add_parameter(para_node);
// Reflect back to Cell._params
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
new_para_pattern->layerwise_parallel());
@@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
return std::make_shared<ValueNode>(scalar_value_ptr);
}

AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
const FuncGraphPtr &top_graph) {
auto target_node = res->get_node(pattern);
if (target_node != nullptr) {
// If pattern is NewParameter, check whether it shouldn't last and is not built
@@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
} else if (pattern->isa<NewTensor>()) {
return BuildNewTensor(pattern, res);
} else if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
} else if (pattern->isa<NewParameter>()) {
return BuildNewParameter(pattern, res, func_graph);
// Add new parameter to top graph instead of current graph
return BuildNewParameter(pattern, res, top_graph);
} else if (pattern->isa<Imm>()) {
return BuildImmNode(pattern, res);
} else {
@@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}

AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
const FuncGraphPtr &func_graph) {
const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
}
return nullptr;
}

AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
const MatchResultPtr &res) {
auto target_inputs = pattern->inputs();
if (target_inputs.size() == 0) {
auto new_node = ProcessSinglePattern(pattern, res, func_graph);
auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
if (new_node != nullptr) {
res->add_entry(pattern, new_node);
}
@@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
}
// Build up the AnfNode in a recursive manner
std::vector<AnfNodePtr> new_inputs;
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
MS_EXCEPTION_IF_NULL(prim_value_node);
new_inputs.push_back(prim_value_node);
for (auto &iter : target_inputs) {
if (iter == pattern) {
MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
}
auto input_node = BuildTarget(iter, func_graph, res);
auto input_node = BuildTarget(iter, func_graph, top_graph, res);
if (input_node == nullptr) {
MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
}
@@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {

} // namespace internal

AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
const MatchResultPtr &res) {
auto match_res = src_pattern_->match(node);
if (match_res != nullptr) {
res->merge(match_res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
internal::Reset(dst_pattern());
return new_node;
}
@@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
auto graph_nodes_sorted = TopoSort(func_graph->output());
auto func_graphs = manager->func_graphs();
bool changes = false;

// Traverse once
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(func_graph, node, res);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
changes = true;
for (auto &fg : func_graphs) {
manager->AddFuncGraph(fg);
auto graph_nodes_sorted = TopoSort(fg->output());
// Traverse once
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(fg, func_graph, node, res);
if (new_node != nullptr && new_node != node) {
MS_LOG(WARNING) << "Matched";
(void)manager->Replace(node, new_node);
changes = true;
}
}
}
return changes;


+ 2
- 1
mindspore/ccsrc/frontend/optimizer/py_pass.h View File

@@ -39,7 +39,8 @@ class PythonPass {
~PythonPass() = default;
bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
std::string name() const { return name_; }
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
AnfNodePtr Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
const MatchResultPtr &res);
PatternPtr src_pattern() { return src_pattern_; }
PatternPtr dst_pattern() { return dst_pattern_; }



+ 17
- 11
mindspore/ccsrc/frontend/optimizer/py_pass_manager.cc View File

@@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
}

PyPassManager::PyPassManager() {
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
res_ = std::make_shared<MatchResult>();
}

void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool run_only_once) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pg = GetPassGroup(Phase::OPT);
bool requires_grad, bool run_only_once) {
PassGroupPtr cur_pg;
if (requires_grad) {
cur_pg = GetPassGroup(Phase::PREAD);
} else {
cur_pg = GetPassGroup(Phase::OPT);
}
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(run_only_once);
MS_EXCEPTION_IF_NULL(pattern);
@@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
}

void PyPassManager::Unregiste(const std::string &pass_name) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pm = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pm);
if (!cur_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
auto opt_pm = GetPassGroup(Phase::OPT);
if (!opt_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
}
auto pre_ad_pm = GetPassGroup(Phase::PREAD);
if (!pre_ad_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
}
}

@@ -92,7 +98,7 @@ void PyPassManager::ClearRes() {

REGISTER_PYBIND_DEFINE(
PyPassManager_, ([](const py::module *m) {
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")


+ 3
- 3
mindspore/ccsrc/frontend/optimizer/py_pass_manager.h View File

@@ -38,7 +38,7 @@ namespace python_pass {
class PyPassManager;
using PyPassManagerPtr = std::shared_ptr<PyPassManager>;

enum Phase { RESOLVE, OPT };
enum Phase { PREAD, OPT };

class PyPassManager {
protected:
@@ -52,8 +52,8 @@ class PyPassManager {
// Access the only global instance
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool run_only_once = false);
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
bool run_only_once);
void Unregiste(const std::string &pass_name);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);


+ 12
- 3
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -301,6 +301,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return true;
}

bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }

bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }

bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
@@ -473,7 +475,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
return ppm->GetPassGroup(phase)->Run(res->func_graph());
}

bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
bool PreAdActionPyStub(const ResourcePtr &res) {
if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
MS_LOG(DEBUG) << "No Match.";
}
return true;
}

bool OptActionVmPyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
@@ -529,12 +536,14 @@ static std::vector<ActionItem> CommonPipeline() {
if (!multi_graphs) {
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
// Add resolve-stage python pass stub
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));

actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
// Do data structure simplifications and inline
actions.emplace_back(std::make_pair("inline", OptInlineAction));
// Add pre-ad, post-inline python pass stub
actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));

return actions;
}


+ 10
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return map_a;
}

OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
OptPassGroupMap a1_a2({opt_a[0], opt_a[1]});
return a1_a2;
}

OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining,
@@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts =
void InitOpt(const ResourcePtr &res) {
if (g_pass_opts.size() == 0) {
opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass));
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
@@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
return true;
}

bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); }
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
@@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},
{"transform_graph", OptPassTransformGraphGroup}};

std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}};
} // namespace pipeline
} // namespace mindspore

+ 1
- 0
mindspore/ccsrc/pipeline/jit/pass.h View File

@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;

extern std::vector<PassItem> kGePasses;
extern std::vector<PassItem> kVmPasses;
extern std::vector<PassItem> kInlinePasses;
extern std::vector<PassItem> kPynativePasses;

bool CconvPass(const ResourcePtr &res);


+ 4
- 4
mindspore/graph_utils/python_pass/__init__.py View File

@@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
_set_reopt

__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",
"set_reopt"
"_set_renorm",
"_set_reopt"
]

+ 13
- 8
mindspore/graph_utils/python_pass/python_pass_register.py View File

@@ -23,22 +23,26 @@ __all__ = [
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",
"set_reopt"
"_set_renorm",
"_set_reopt"
]
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.

Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once (bool): Specify whether or not to run pass only once. Default: False.

Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, run_only_once=False):
def __init__(self, requires_grad=True, run_only_once=False):
if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
self.requires_grad = requires_grad
self.run_only_once_ = run_only_once
PyPassManager_.__init__(self)

@@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.run_only_once_)
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)

def unregiste(self, py_pass):
if isinstance(py_pass, str):
@@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)

def registe_pass(run_only_once=False):
def registe_pass(requires_grad=True, run_only_once=False):
"""
Registe python pass to specified pipeline phase which would be used in compilation.

Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.

Returns:
@@ -99,7 +104,7 @@ def registe_pass(run_only_once=False):
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return PyPassManager(run_only_once)
return PyPassManager(requires_grad, run_only_once)

def unregiste_pass(py_pass):
"""
@@ -157,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)

def set_renorm(should_renorm):
def _set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).

@@ -171,7 +176,7 @@ def set_renorm(should_renorm):
ppm = PyPassManager()
ppm.set_renorm(should_renorm)

def set_reopt(do_reopt):
def _set_reopt(do_reopt):
"""
Set whether or not to do optimization after modified graph in python pass(es).



+ 20
- 20
tests/ut/python/optimizer/test_python_pass.py View File

@@ -19,8 +19,8 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
cancel_new_parameter, _set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
@@ -176,7 +176,7 @@ def test_isnot_pattern_0():
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
conv_bn_model = ConvBN()

@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def single_bn_pass():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
@@ -188,7 +188,7 @@ def test_isnot_pattern_0():
target = Call(P.ReLU6(), [pattern_0])
return pattern, target

@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def bn_pass():
"""
Sub a BN to Softmax.
@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
set_renorm(True)
_set_renorm(True)

def test_isnot_pattern_1():
"""
@@ -234,12 +234,12 @@ def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()

@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
set_renorm(True)
_set_renorm(True)

def test_newparameter_pattern():
"""
@@ -261,9 +261,9 @@ def test_newparameter_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()

set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
_set_renorm(False)
_set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@@ -288,9 +288,9 @@ def test_imm_target():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()

set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
_set_renorm(False)
_set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
@@ -313,10 +313,10 @@ def test_gen_new_parameter():

default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor)
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass():
x = Any()
softmax = P.Softmax()


Loading…
Cancel
Save