Merge pull request !4911 from riemann_penn/add_func_type_check_for_switch_layertags/v0.7.0-beta
| @@ -989,19 +989,13 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar | |||||
| // args: tuple of items, index | // args: tuple of items, index | ||||
| const std::string op_name = std::string("TupleGetItemTensor"); | const std::string op_name = std::string("TupleGetItemTensor"); | ||||
| abstract::CheckArgsSize(op_name, args_spec_list, 2); | abstract::CheckArgsSize(op_name, args_spec_list, 2); | ||||
| AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0); | |||||
| AbstractBasePtrList branches = branches_abs->elements(); | |||||
| if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) { | |||||
| FuncGraphPtr ret_graph = std::make_shared<FuncGraph>(); | |||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| AnfNodePtr functions = ret_graph->add_parameter(); | |||||
| auto index = ret_graph->add_parameter(); | |||||
| auto ret_graph = std::make_shared<FuncGraph>(); | |||||
| ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true); | |||||
| auto functions = ret_graph->add_parameter(); | |||||
| auto index = ret_graph->add_parameter(); | |||||
| ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); | |||||
| return ret_graph; | |||||
| } | |||||
| MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << "."; | |||||
| ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); | |||||
| return ret_graph; | |||||
| } | } | ||||
| REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { | ||||
| @@ -114,14 +114,14 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP | |||||
| AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1); | ||||
| AbstractBasePtrList branches = branches_abs->elements(); | AbstractBasePtrList branches = branches_abs->elements(); | ||||
| const size_t maximum_layer_num = 1000; | const size_t maximum_layer_num = 1000; | ||||
| if (branches.size() < 0 || branches.size() > maximum_layer_num) { | |||||
| if (branches.size() < 1 || branches.size() > maximum_layer_num) { | |||||
| MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " | MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " | ||||
| << branches.size() << " branches."; | << branches.size() << " branches."; | ||||
| } | } | ||||
| for (size_t i = 0; i < branches.size(); i++) { | for (size_t i = 0; i < branches.size(); i++) { | ||||
| MS_EXCEPTION_IF_NULL(branches[i]); | MS_EXCEPTION_IF_NULL(branches[i]); | ||||
| if (!branches[i]->isa<AbstractFunction>()) { | |||||
| if (!branches[i]->isa<FuncGraphAbstractClosure>()) { | |||||
| MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got " | MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got " | ||||
| << branches[i]->ToString() << " as the " << i << "th element."; | << branches[i]->ToString() << " as the " << i << "th element."; | ||||
| } | } | ||||
| @@ -851,3 +851,25 @@ def test_tensor_all_construct_lack_branch(): | |||||
| net = NetConditionLackBranch() | net = NetConditionLackBranch() | ||||
| with pytest.raises(Exception): | with pytest.raises(Exception): | ||||
| net(input_tensor_1, input_tensor_2) | net(input_tensor_1, input_tensor_2) | ||||
| def test_parser_switch_layer_func_primitive(): | |||||
| class FinalNet(nn.Cell): | |||||
| def __init__(self, funcs): | |||||
| super().__init__() | |||||
| self.funcs = funcs | |||||
| def construct(self, i, input1): | |||||
| x = self.funcs[i](input1) | |||||
| return x | |||||
| func1 = P.ReLU() | |||||
| func2 = P.Softmax() | |||||
| funcs = (func1, func2) | |||||
| net = FinalNet(funcs) | |||||
| input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32)) | |||||
| i = Tensor(1, mstype.int32) | |||||
| with pytest.raises(ValueError): | |||||
| net(i, input1) | |||||