Merge pull request !1979 from fary86/fix_patial_primitive_poly_codetags/v0.5.0-beta
| @@ -378,11 +378,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr | |||||
| } | } | ||||
| auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval); | auto real_eval = dyn_cast<BaseFuncGraphEvaluator>(eval); | ||||
| if (func->context() != nullptr) { | |||||
| if (!IsVisible(func_graph_, func->context()->func_graph())) { | |||||
| MS_LOG(EXCEPTION) << "Func is not visible NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); | |||||
| } | |||||
| } else { | |||||
| if (func->context() == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); | MS_LOG(EXCEPTION) << "Func context is nullptr NodeInfo: " << trace::GetDebugInfo(func_graph_->debug_info()); | ||||
| } | } | ||||
| AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); | AnalysisContextPtr context = real_eval->MakeContext(engine_, argvals); | ||||
| @@ -507,9 +503,9 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||||
| // First element is partial, second is func so arg is start from 2 | // First element is partial, second is func so arg is start from 2 | ||||
| (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); | (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); | ||||
| func = inputs[1]; | func = inputs[1]; | ||||
| new_inputs = args; | |||||
| (void)new_inputs.insert(new_inputs.begin(), func); | |||||
| } | } | ||||
| new_inputs = args; | |||||
| (void)new_inputs.insert(new_inputs.begin(), func); | |||||
| AbstractBasePtrList argvals; | AbstractBasePtrList argvals; | ||||
| MS_EXCEPTION_IF_NULL(new_inputs[0]); | MS_EXCEPTION_IF_NULL(new_inputs[0]); | ||||
| @@ -524,9 +520,23 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||||
| << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); | << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); | ||||
| } | } | ||||
| if (func->isa<Parameter>() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) { | |||||
| auto wrapped_node = BuildSpecializedParameterNode(new_node); | |||||
| new_inputs[0] = wrapped_node; | |||||
| if (!func->isa<ValueNode>()) { | |||||
| MS_LOG(DEBUG) << func->abstract()->type_name() << " | " << func->abstract()->ToString(); | |||||
| if (func->abstract()->isa<AbstractFunction>() && !func->abstract()->isa<AbstractFuncUnion>()) { | |||||
| auto func_abs = func->abstract()->cast<AbstractFunctionPtr>(); | |||||
| EvaluatorPtr eval = engine_->GetEvaluatorFor(func_abs); | |||||
| std::pair<AbstractBasePtrList, AbstractBasePtr> result; | |||||
| AbstractBasePtrList empty_args; | |||||
| auto status = FindUniqueArgvals(func_abs, eval, empty_args, &result); | |||||
| MS_LOG(DEBUG) << "FindUniqueArgvals return status: " << status; | |||||
| // if a node is a poly node, or an input parameter is a PartialAbstractClosure, expand it early | |||||
| if (status == kSpecializeFindUniqueArgvalPoly || | |||||
| (func->isa<Parameter>() && (func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER) || | |||||
| func->abstract()->isa<PartialAbstractClosure>()))) { | |||||
| auto wrapped_node = BuildSpecializedParameterNode(new_node); | |||||
| new_inputs[0] = wrapped_node; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| if (CanSpecializeNode(func)) { | if (CanSpecializeNode(func)) { | ||||
| @@ -14,9 +14,12 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test nn ops """ | """ test nn ops """ | ||||
| import numpy as np | import numpy as np | ||||
| from numpy.random import normal | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore.ops.composite import core | |||||
| from mindspore.common.api import ms_function | |||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| @@ -59,10 +62,39 @@ def test_conv2d_same_primitive(): | |||||
| net(t1, t2) | net(t1, t2) | ||||
| # test free variable function list as parameter | |||||
| def test_remove_and_fv_2(): | |||||
| @core(loop_can_uroll=True) | |||||
| def inner_loop(x, input_data, fv_func_list): | |||||
| ret = () | |||||
| for fv_fn in fv_func_list: | |||||
| ele = fv_fn(input_data) | |||||
| ret += (ele,) | |||||
| return ret | |||||
| @ms_function | |||||
| def out_loop(input1, input_data): | |||||
| ret = () | |||||
| def fv_func1(y): | |||||
| return input1 * y | |||||
| def fv_func2(y): | |||||
| return input1 - y | |||||
| fv_func_list = [fv_func1, fv_func2] | |||||
| ele0 = inner_loop(input1, input_data[0], fv_func_list) | |||||
| ele1 = inner_loop(input1, input_data[1], fv_func_list) | |||||
| ret = (ele0, ele1) | |||||
| return ret | |||||
| input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1)))) | |||||
| input1 = Tensor(normal(0, 0.1, (3, 3))) | |||||
| out_loop(input1, input_data) | |||||
| # test cell as high order argument | # test cell as high order argument | ||||
| # The graph with free variables used as argument is not supported yet | # The graph with free variables used as argument is not supported yet | ||||
| # because of the limit of inference specialize system | # because of the limit of inference specialize system | ||||
| def Xtest_conv2d_op_with_arg(): | |||||
| def test_conv2d_op_with_argi_1(): | |||||
| class Conv2dNet(nn.Cell): | class Conv2dNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Conv2dNet, self).__init__() | super(Conv2dNet, self).__init__() | ||||
| @@ -279,7 +311,7 @@ def test_op_with_arg_as_input(): | |||||
| # The partial application used as argument is not supported yet | # The partial application used as argument is not supported yet | ||||
| # because of the limit of inference specialize system | # because of the limit of inference specialize system | ||||
| def Xtest_partial_as_arg(): | |||||
| def test_partial_as_arg(): | |||||
| class PartialArgNet(nn.Cell): | class PartialArgNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(PartialArgNet, self).__init__() | super(PartialArgNet, self).__init__() | ||||