
!4911 add func type check for switch_layer

Merge pull request !4911 from riemann_penn/add_func_type_check_for_switch_layer
Tag: v0.7.0-beta
Author: mindspore-ci-bot (Gitee), 5 years ago
Commit: c92f7c9170
3 changed files with 30 additions and 14 deletions:

  1. mindspore/ccsrc/frontend/operator/composite/composite.cc (+6, -12)
  2. mindspore/core/abstract/prim_statement.cc (+2, -2)
  3. tests/ut/python/ops/test_control_ops.py (+22, -0)

mindspore/ccsrc/frontend/operator/composite/composite.cc (+6, -12)

@@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar
   // args: tuple of items, index
   const std::string op_name = std::string("TupleGetItemTensor");
   abstract::CheckArgsSize(op_name, args_spec_list, 2);
-  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
-  AbstractBasePtrList branches = branches_abs->elements();
-  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
-    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
-    ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-    AnfNodePtr functions = ret_graph->add_parameter();
-    auto index = ret_graph->add_parameter();
+  auto ret_graph = std::make_shared<FuncGraph>();
+  ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  auto functions = ret_graph->add_parameter();
+  auto index = ret_graph->add_parameter();
 
-    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
-    return ret_graph;
-  }
-
-  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
+  ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+  return ret_graph;
 }
 
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
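For context, TupleGetItemTensor is the composite that handles indexing a tuple with a Tensor ("select indexed item" per the comment above). After this hunk it always lowers to kPrimSwitchLayer and leaves argument validation to the switch_layer inference rule in the next file. Below is a minimal sketch, not part of the patch, of the user-level pattern that takes this code path; the imports, class name, and shapes are ours and assume the public mindspore API of this release:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

class SwitchNet(nn.Cell):
    def __init__(self):
        super().__init__()
        # A tuple of sub-networks; indexing it with a Tensor index in
        # construct is compiled through TupleGetItemTensor into a
        # switch_layer call.
        self.funcs = (nn.ReLU(), nn.Softmax())

    def construct(self, i, x):
        return self.funcs[i](x)

net = SwitchNet()
x = Tensor(np.random.randn(2, 3).astype(np.float32))
i = Tensor(0, mstype.int32)
out = net(i, x)  # selects the nn.ReLU() branch at runtime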


mindspore/core/abstract/prim_statement.cc (+2, -2)

@@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
   AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
-  if (branches.size() < 0 || branches.size() > maximum_layer_num) {
+  if (branches.size() < 1 || branches.size() > maximum_layer_num) {
     MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
                              << branches.size() << " branches.";
   }
 
   for (size_t i = 0; i < branches.size(); i++) {
     MS_EXCEPTION_IF_NULL(branches[i]);
-    if (!branches[i]->isa<AbstractFunction>()) {
+    if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
       MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                                << branches[i]->ToString() << " as the " << i << "th element.";
     }
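Two fixes land in this hunk. First, branches.size() returns an unsigned size_t, so the old lower bound `branches.size() < 0` could never be true; `< 1` actually rejects an empty branch tuple. Second, narrowing the accepted type from AbstractFunction to FuncGraphAbstractClosure restricts branches to graph functions (for example, the construct of an nn.Cell); a bare primitive such as P.ReLU() no longer qualifies, which is exactly what the new unit test below asserts. A sketch of the wrapper pattern that satisfies the tightened check; the class name is ours:

import mindspore.nn as nn
from mindspore.ops import operations as P

class ReLUBranch(nn.Cell):
    # Wrapping the primitive in a Cell gives switch_layer a graph
    # function (seen as a FuncGraphAbstractClosure during inference)
    # instead of a raw primitive, so the new type check passes.
    def __init__(self):
        super().__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)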


tests/ut/python/ops/test_control_ops.py (+22, -0)

@@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch():
     net = NetConditionLackBranch()
     with pytest.raises(Exception):
         net(input_tensor_1, input_tensor_2)
+
+
+def test_parser_switch_layer_func_primitive():
+    class FinalNet(nn.Cell):
+        def __init__(self, funcs):
+            super().__init__()
+            self.funcs = funcs
+
+        def construct(self, i, input1):
+            x = self.funcs[i](input1)
+            return x
+
+    func1 = P.ReLU()
+    func2 = P.Softmax()
+    funcs = (func1, func2)
+    net = FinalNet(funcs)
+
+    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
+    i = Tensor(1, mstype.int32)
+
+    with pytest.raises(ValueError):
+        net(i, input1)
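The test pins the new behavior: primitives passed as switch_layer branches now fail fast with a ValueError during compilation. A compliant variant, sketched here rather than taken from the patch, reuses FinalNet with Cell branches and assumes graph mode:

funcs_ok = (nn.ReLU(), nn.Softmax())  # Cells compile to graph functions
net_ok = FinalNet(funcs_ok)
out = net_ok(i, input1)               # passes the new type check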
