
!1979 fix partial primitive poly node

Merge pull request !1979 from fary86/fix_patial_primitive_poly_code
Tag: v0.5.0-beta
Committed by mindspore-ci-bot 5 years ago
commit ea96cbcef2

2 changed files with 54 additions and 12 deletions:
  1. mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc (+20, -10)
  2. tests/ut/python/ops/test_ops_attr_infer.py (+34, -2)

mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc (+20, -10)

@@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr
   }
   auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval);

-  if (func->context() != nullptr) {
-    if (!IsVisible(func_graph_, func->context()->func_graph())) {
-      MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
-    }
-  } else {
+  if (func->context() == nullptr) {
     MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info());
   }
   AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals);
@@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
     // First element is partial, second is func so arg is start from 2
     (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end());
     func = inputs[1];
-    new_inputs = args;
-    (void)new_inputs.insert(new_inputs.begin(), func);
   }
+  new_inputs = args;
+  (void)new_inputs.insert(new_inputs.begin(), func);

   AbstractBasePtrList argvals;
   MS_EXCEPTION_IF_NULL(new_inputs[0]);
@@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) {
                   << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString();
   }

-  if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) {
-    auto wrapped_node = BuildSpecializedParameterNode(new_node);
-    new_inputs[0] = wrapped_node;
+  if (!func->isa<ValueNode>()) {
+    MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString();
+    if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) {
+      auto func_abs = func->abstract()->cast<AbstractFunctionPtr>();
+      EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs);
+      std::pair<AbstractBasePtrList, AbstractBasePtr> result;
+      AbstractBasePtrList empty_args;
+      auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result);
+      MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status;
+      // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early
+      if (status == kSpecializeFindUniqueArgvalPoly ||
+          (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) ||
+                                      func->abstract()->isa<PartialAbstractClosure>()))) {
+        auto wrapped_node = BuildSpecializedParameterNode(new_node);
+        new_inputs[0] = wrapped_node;
+      }
+    }
   }

   if (CanSpecializeNode(func)) {
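
For readers outside the specializer code: the new branch above runs FindUniqueArgvals with an empty argument list on the callee's abstract function and, when the status is kSpecializeFindUniqueArgvalPoly (roughly, more than one candidate argument configuration for the callee) or the callee is a parameter carrying a PartialAbstractClosure, wraps the call with BuildSpecializedParameterNode so the partial is expanded early instead of failing during specialization. The sketch below is a pure-Python analogue of the triggering pattern, mirroring the new unit test in the next file; it is illustrative only (plain numpy, no MindSpore semantics implied), and the names and shapes are assumptions.

# Pure-Python analogue (illustrative only): fv_fn is a parameter whose possible
# callees are closures over the free variable w, and each closure is called with
# arguments of two different shapes, so no single specialization of the callee exists.
import numpy as np

def inner_loop(data, fv_func_list):
    # This call site can dispatch to either closure and sees two input shapes.
    return tuple(fv_fn(data) for fv_fn in fv_func_list)

def out_loop(w, data_tuple):
    fv_funcs = [lambda y: w * y, lambda y: w - y]  # closures over the free variable w
    return tuple(inner_loop(d, fv_funcs) for d in data_tuple)

w = np.ones((3, 3))
print(out_loop(w, (np.ones((3, 3)), np.ones((3, 1)))))

In graph mode the corresponding construct goes through the branch added above; see test_remove_and_fv_2 in the test diff below for the real MindSpore version.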


tests/ut/python/ops/test_ops_attr_infer.py (+34, -2)

@@ -14,9 +14,12 @@
 # ============================================================================
 """ test nn ops """
 import numpy as np
+from numpy.random import normal

 import mindspore.nn as nn
 import mindspore.context as context
+from mindspore.ops.composite import core
+from mindspore.common.api import ms_function

 from mindspore import Tensor
 from mindspore.ops import functional as F
@@ -59,10 +62,39 @@ def test_conv2d_same_primitive():
     net(t1, t2)


+# test free variable function list as parameter
+def test_remove_and_fv_2():
+    @core(loop_can_uroll=True)
+    def inner_loop(x, input_data, fv_func_list):
+        ret = ()
+        for fv_fn in fv_func_list:
+            ele = fv_fn(input_data)
+            ret += (ele,)
+        return ret
+
+    @ms_function
+    def out_loop(input1, input_data):
+        ret = ()
+
+        def fv_func1(y):
+            return input1 * y
+        def fv_func2(y):
+            return input1 - y
+        fv_func_list = [fv_func1, fv_func2]
+        ele0 = inner_loop(input1, input_data[0], fv_func_list)
+        ele1 = inner_loop(input1, input_data[1], fv_func_list)
+        ret = (ele0, ele1)
+        return ret
+
+    input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
+    input1 = Tensor(normal(0, 0.1, (3, 3)))
+    out_loop(input1, input_data)
+
+
 # test cell as high order argument
 # The graph with free variables used as argument is not supported yet
 # because of the limit of inference specialize system
-def Xtest_conv2d_op_with_arg():
+def test_conv2d_op_with_argi_1():
     class Conv2dNet(nn.Cell):
         def __init__(self):
             super(Conv2dNet, self).__init__()
@@ -279,7 +311,7 @@ def test_op_with_arg_as_input():

 # The partial application used as argument is not supported yet
 # because of the limit of inference specialize system
-def Xtest_partial_as_arg():
+def test_partial_as_arg():
     class PartialArgNet(nn.Cell):
         def __init__(self):
             super(PartialArgNet, self).__init__()
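
On the test side, besides adding test_remove_and_fv_2, the commit re-enables two previously skipped cases by dropping their Xtest_ prefix (test_conv2d_op_with_argi_1 and test_partial_as_arg). A hedged sketch for running the new case directly, outside pytest: the import path is an assumption based on the file path in the diff header, and graph mode is set explicitly.

# Hypothetical direct invocation (assumes MindSpore is installed and that the
# test module is importable from the repository root; path not verified here).
import mindspore.context as context
from tests.ut.python.ops.test_ops_attr_infer import test_remove_and_fv_2

context.set_context(mode=context.GRAPH_MODE)
test_remove_and_fv_2()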

