From abab21ed57cfc4984b75067e06ddd70a50435551 Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Fri, 21 Aug 2020 15:24:47 +0800
Subject: [PATCH] add func type check for switch layer

---
 .../frontend/operator/composite/composite.cc | 18 ++++++------------
 mindspore/core/abstract/prim_statement.cc    |  4 ++--
 tests/ut/python/ops/test_control_ops.py      | 22 +++++++++++++++++++
 3 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 4f0b42952a..e4c0a0723f 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
   // args: tuple of items, index
   const std::string op_name = std::string("TupleGetItemTensor");
   abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
-  AbstractBasePtrList branches = branches_abs->elements();
-  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
-    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
-    ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-    AnfNodePtr functions = ret_graph->add_parameter();
-    auto index = ret_graph->add_parameter();
+  auto ret_graph = std::make_shared<FuncGraph>();
+  ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  auto functions = ret_graph->add_parameter();
+  auto index = ret_graph->add_parameter();
 
-    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
-    return ret_graph;
-  }
-
-  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
+  ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+  return ret_graph;
 }
 
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
diff --git a/mindspore/core/abstract/prim_statement.cc b/mindspore/core/abstract/prim_statement.cc
index 1cc077d300..fe796ec06d 100644
--- a/mindspore/core/abstract/prim_statement.cc
+++ b/mindspore/core/abstract/prim_statement.cc
@@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
-  if (branches.size() < 0 || branches.size() > maximum_layer_num) {
+  if (branches.size() < 1 || branches.size() > maximum_layer_num) {
     MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
                              << branches.size() << " branches.";
   }
 
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
-    if (!branches[i]->isa<AbstractFunction>()) {
+    if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
       MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                                << branches[i]->ToString() << " as the " << i << "th element.";
     }
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index 6e8467b700..26132165b5 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch():
     net = NetConditionLackBranch()
     with pytest.raises(Exception):
         net(input_tensor_1, input_tensor_2)
+
+
+def test_parser_switch_layer_func_primitive():
+    class FinalNet(nn.Cell):
+        def __init__(self, funcs):
+            super().__init__()
+            self.funcs = funcs
+
+        def construct(self, i, input1):
+            x = self.funcs[i](input1)
+            return x
+
+    func1 = P.ReLU()
+    func2 = P.Softmax()
+    funcs = (func1, func2)
+    net = FinalNet(funcs)
+
+    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+
+    with pytest.raises(ValueError):
+        net(i, input1)
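
Note: the snippet below is an illustrative sketch, not part of the patch. It shows the intended usage after this change, assuming the same MindSpore Python APIs the unit test above imports (nn, Tensor, mstype, numpy); the net name CellBranchNet is hypothetical. Branches built from nn.Cell instances compile to func graphs and therefore pass the stricter switch_layer type check, while a tuple of raw primitives such as P.ReLU() is rejected at inference time with the ValueError the new test expects.

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor

class CellBranchNet(nn.Cell):
    # Hypothetical example net: each branch is an nn.Cell, which compiles
    # to a func graph and so satisfies the func type check added above.
    def __init__(self):
        super().__init__()
        self.funcs = (nn.ReLU(), nn.Softmax())

    def construct(self, i, x):
        # Indexing a tuple of functions with a tensor index lowers to
        # switch_layer via TupleGetItemTensor::GenerateFuncGraph.
        return self.funcs[i](x)

net = CellBranchNet()
x = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(0, mstype.int32)
out = net(i, x)  # selects the ReLU branch; no ValueError is raised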