Merge pull request !4425 from fary86/fix_if_no_return_coredumptags/v0.7.0-beta
| @@ -22,7 +22,7 @@ from .parser import (Parser, create_obj_instance, generate_scope, | |||||
| get_dataclass_attributes, get_dataclass_methods, get_obj_id, | get_dataclass_attributes, get_dataclass_methods, get_obj_id, | ||||
| get_module_namespace, get_obj_type, get_object_key, | get_module_namespace, get_obj_type, get_object_key, | ||||
| get_parse_method_of_class, get_scope_name, | get_parse_method_of_class, get_scope_name, | ||||
| is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor) | |||||
| is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, get_object_description) | |||||
| from .serialize import * | from .serialize import * | ||||
| __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', | ||||
| @@ -30,4 +30,4 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', | |||||
| 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', | ||||
| 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', | ||||
| 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', | 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', | ||||
| 'create_slice_obj', 'convert_to_ms_tensor'] | |||||
| 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description'] | |||||
| @@ -322,6 +322,20 @@ def convert_to_ms_tensor(data): | |||||
| return MsTensor(data) | return MsTensor(data) | ||||
| def get_object_description(obj, fname, fline): | |||||
| """return method or funcition description for error report, include location, class name, etc.""" | |||||
| if isinstance(obj, types.MethodType): | |||||
| obj_cls = obj.__self__.__class__ | |||||
| class_name = f'{obj_cls.__module__}.{obj_cls.__qualname__}' | |||||
| cls_fname = inspect.getfile(obj_cls) | |||||
| _, cls_fline = inspect.getsourcelines(obj_cls) | |||||
| class_loc = f'{cls_fname}:{cls_fline}' | |||||
| return f"bound method '{obj.__name__}' at {fname}:{fline} of <{class_name} at {class_loc} object>" | |||||
| if isinstance(obj, (types.FunctionType, ast.FunctionDef)): | |||||
| return f"function '{obj.name}' at {fname}:{fline}" | |||||
| return str(obj) | |||||
| class Parser: | class Parser: | ||||
| """ | """ | ||||
| Parser python code to ast tree. | Parser python code to ast tree. | ||||
| @@ -154,6 +154,23 @@ FuncGraphPtr Parser::ParseFuncGraph() { | |||||
| RemoveUnnecessaryPhis(); | RemoveUnnecessaryPhis(); | ||||
| MS_EXCEPTION_IF_NULL(pFnBlock); | MS_EXCEPTION_IF_NULL(pFnBlock); | ||||
| // check whether the functions refered by this function and itself are missing 'return' statement | |||||
| auto mng = Manage(pFnBlock->func_graph(), false); | |||||
| for (auto func_graph : mng->func_graphs()) { | |||||
| if (func_graph->get_return() != nullptr) { | |||||
| continue; | |||||
| } | |||||
| py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||||
| py::str desc = | |||||
| python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); | |||||
| MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | |||||
| } | |||||
| // clear manager info after checking missing return | |||||
| for (auto fg : mng->func_graphs()) { | |||||
| fg->ClearAllManagerInfo(); | |||||
| } | |||||
| return pFnBlock->func_graph(); | return pFnBlock->func_graph(); | ||||
| } | } | ||||
| @@ -271,9 +288,9 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo | |||||
| (void)ParseStatements(pFunBlock, funcObj); | (void)ParseStatements(pFunBlock, funcObj); | ||||
| if (current_fg->get_return() == nullptr) { | if (current_fg->get_return() == nullptr) { | ||||
| MS_LOG(ERROR) << "Graph return node is null, loc:" << GetLocation(node)->ToString(); | |||||
| errcode_ = PARSE_NO_RETURN; | |||||
| return pFunBlock; | |||||
| py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||||
| py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]); | |||||
| MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | |||||
| } | } | ||||
| GenerateArgsDefaultValueForFunction(pFunBlock, node); | GenerateArgsDefaultValueForFunction(pFunBlock, node); | ||||
| return pFunBlock; | return pFunBlock; | ||||
| @@ -323,7 +340,11 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py: | |||||
| } | } | ||||
| auto filename = location[0].cast<std::string>(); | auto filename = location[0].cast<std::string>(); | ||||
| auto line_no = location[1].cast<int>(); | auto line_no = location[1].cast<int>(); | ||||
| MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; | |||||
| auto fn_loc = block->func_graph()->debug_info()->location(); | |||||
| py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), | |||||
| fn_loc->file_name(), fn_loc->line()); | |||||
| MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in " | |||||
| << desc.cast<std::string>() << "."; | |||||
| } | } | ||||
| } | } | ||||
| @@ -350,7 +371,11 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object | |||||
| py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | ||||
| auto filename = ret[0].cast<std::string>(); | auto filename = ret[0].cast<std::string>(); | ||||
| auto line_no = ret[1].cast<int>(); | auto line_no = ret[1].cast<int>(); | ||||
| MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no; | |||||
| auto fn_loc = block->func_graph()->debug_info()->location(); | |||||
| py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), | |||||
| fn_loc->file_name(), fn_loc->line()); | |||||
| MS_LOG(EXCEPTION) << "Unsupported syntax '" << node_name << "' at " << filename << ":" << line_no << " in " | |||||
| << desc.cast<std::string>() << "."; | |||||
| } | } | ||||
| } | } | ||||
| @@ -69,6 +69,7 @@ const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; | |||||
| const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; | const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; | ||||
| const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; | const char PYTHON_MOD_GET_PARSE_METHOD[] = "get_parse_method_of_class"; | ||||
| const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; | const char PYTHON_MOD_GET_BPROP_METHOD[] = "get_bprop_method_of_class"; | ||||
| const char PYTHON_MOD_GET_OBJECT_DESCRIPTION[] = "get_object_description"; | |||||
| const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; | const char PYTHON_MOD_CONVERT_TO_MS_TENSOR[] = "convert_to_ms_tensor"; | ||||
| const char PYTHON_PARSE_GET_ARGS[] = "get_args"; | const char PYTHON_PARSE_GET_ARGS[] = "get_args"; | ||||
| @@ -379,7 +379,11 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector<AnfNodePtr> & | |||||
| FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | FuncGraphSetPtr func_graphs_to_check = std::make_shared<FuncGraphSet>(); | ||||
| while (!nodes_ordered.empty()) { | while (!nodes_ordered.empty()) { | ||||
| AnfNodePtr node = nodes_ordered.pop(); | AnfNodePtr node = nodes_ordered.pop(); | ||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (node == nullptr) { | |||||
| // Here can not call 'MS_EXCEPTION_IF_NULL' to throw exception, this method may be triggered by desctuctor | |||||
| MS_LOG(WARNING) << "Node to be dropped is nullptr"; | |||||
| continue; | |||||
| } | |||||
| if (!all_nodes_.contains(node)) { | if (!all_nodes_.contains(node)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -96,21 +96,6 @@ TEST_F(TestParser, TestParseGraphSuccess) { | |||||
| ASSERT_TRUE(nullptr != func_graph); | ASSERT_TRUE(nullptr != func_graph); | ||||
| } | } | ||||
| TEST_F(TestParser, TestParseGraphFailure) { | |||||
| GetPythonFunction("get_no_return_fn"); | |||||
| // create parser | |||||
| std::shared_ptr<ParseAst> ast = std::make_shared<ParseAst>(fn); | |||||
| bool succ = ast->InitParseAstInfo(); | |||||
| ASSERT_TRUE(succ = true); | |||||
| std::shared_ptr<Parser> parser = std::make_shared<Parser>(ast); | |||||
| // parse ast to graph | |||||
| FuncGraphPtr func_graph = parser->ParseFuncGraph(); | |||||
| ASSERT_EQ(PARSE_NO_RETURN, parser->errcode()); | |||||
| ASSERT_TRUE(nullptr == func_graph); | |||||
| } | |||||
| TEST_F(TestParser, TestParseGraphIf) { | TEST_F(TestParser, TestParseGraphIf) { | ||||
| GetPythonFunction("test_if"); | GetPythonFunction("test_if"); | ||||
| @@ -689,3 +689,26 @@ def test_while_concat(): | |||||
| x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) | x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32)) | ||||
| net = Net(x) | net = Net(x) | ||||
| net(x) | net(x) | ||||
| def test_tensor_all_construct_lack_branch(): | |||||
| class NetConditionLackBranch(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetConditionLackBranch, self).__init__() | |||||
| self.logicaland = P.LogicalAnd() | |||||
| self.logicalor = P.LogicalOr() | |||||
| def construct(self, input1, input2): | |||||
| if input1.all(): | |||||
| return self.logicaland(input1, input2) | |||||
| while input1.any(): | |||||
| return self.logicalor(input1, input2) | |||||
| # NOTICE: here missing return statement, default return None | |||||
| input_np_1 = np.random.choice([True], size=(2, 3, 4, 5)) | |||||
| input_tensor_1 = Tensor(input_np_1) | |||||
| input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5)) | |||||
| input_tensor_2 = Tensor(input_np_2) | |||||
| net = NetConditionLackBranch() | |||||
| with pytest.raises(Exception): | |||||
| net(input_tensor_1, input_tensor_2) | |||||
| @@ -16,6 +16,7 @@ | |||||
| import functools | import functools | ||||
| import logging | import logging | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.context as context | import mindspore.context as context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -62,13 +63,9 @@ def test_net_without_construct(): | |||||
| """ test_net_without_construct """ | """ test_net_without_construct """ | ||||
| net = NetMissConstruct() | net = NetMissConstruct() | ||||
| inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | ||||
| try: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| _executor.compile(net, inp) | _executor.compile(net, inp) | ||||
| except RuntimeError as err: | |||||
| if str(err).find("Unsupported syntax 'Raise' at ") >= 0: | |||||
| print(str(err)) | |||||
| else: | |||||
| raise err | |||||
| assert "Unsupported syntax 'Raise' at " in str(err.value) | |||||
| class NetWithRaise(nn.Cell): | class NetWithRaise(nn.Cell): | ||||
| @@ -87,13 +84,9 @@ def test_net_with_raise(): | |||||
| """ test_net_with_raise """ | """ test_net_with_raise """ | ||||
| net = NetWithRaise() | net = NetWithRaise() | ||||
| inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | inp = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)) | ||||
| try: | |||||
| with pytest.raises(RuntimeError) as err: | |||||
| _executor.compile(net, inp) | _executor.compile(net, inp) | ||||
| except RuntimeError as err: | |||||
| if str(err).find("Unsupported syntax 'Raise' at ") >= 0: | |||||
| print(str(err)) | |||||
| else: | |||||
| raise err | |||||
| assert "Unsupported syntax 'Raise' at " in str(err.value) | |||||
| class NetAddN(nn.Cell): | class NetAddN(nn.Cell): | ||||
| @@ -0,0 +1,201 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| test mindspore grammar constraints | |||||
| 1. funtion must have return statement | |||||
| 2. raise statement can not be used | |||||
| """ | |||||
| # pylint: disable=R1705, R1710, W0223 | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore import dtype as mstype | |||||
| context.set_context(mode=context.GRAPH_MODE) | |||||
| def test_missing_return(): | |||||
| class NetMissReturn(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetMissReturn, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| if x == 1: | |||||
| return 10 | |||||
| elif x == 20: | |||||
| if y == 1: | |||||
| return 3 | |||||
| elif y == 2: | |||||
| for i in range(z): | |||||
| return i + z | |||||
| i = 0 | |||||
| while i < z: | |||||
| return i + z | |||||
| def g(u): | |||||
| return x + u | |||||
| # here method 'construct' misses a return statement | |||||
| g(y) | |||||
| else: | |||||
| return 7 | |||||
| else: | |||||
| return 5 | |||||
| net = NetMissReturn() | |||||
| x = Tensor(0, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| z = Tensor(2, mstype.int32) | |||||
| with pytest.raises(TypeError) as er: | |||||
| net(x, y, z) | |||||
| assert "Missing return statement in bound method 'construct'" in str(er.value) | |||||
| def test_nest_function_missing_return(): | |||||
| class NetNestFuncMissReturn(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetNestFuncMissReturn, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| if x == 1: | |||||
| return 10 | |||||
| elif x == 20: | |||||
| if y == 1: | |||||
| return 3 | |||||
| elif y == 2: | |||||
| for i in range(z): | |||||
| return i + z | |||||
| i = 0 | |||||
| while i < z: | |||||
| return i + z | |||||
| def g(u): | |||||
| x += u | |||||
| # nested function 'g' misses a return a statement | |||||
| return g(y) | |||||
| else: | |||||
| return 7 | |||||
| else: | |||||
| return 5 | |||||
| net = NetNestFuncMissReturn() | |||||
| x = Tensor(0, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| z = Tensor(2, mstype.int32) | |||||
| with pytest.raises(TypeError) as er: | |||||
| net(x, y, z) | |||||
| assert "Missing return statement in function 'g'" in str(er.value) | |||||
| def test_raise_in_method(): | |||||
| class NetRaiseInMethod(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetRaiseInMethod, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| if x == 1: | |||||
| return 10 | |||||
| elif x == 20: | |||||
| # add not support grammar 'raise' here | |||||
| raise ValueError('Illegal case') | |||||
| else: | |||||
| return y + z | |||||
| net = NetRaiseInMethod() | |||||
| x = Tensor(0, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| z = Tensor(2, mstype.int32) | |||||
| with pytest.raises(RuntimeError) as er: | |||||
| net(x, y, z) | |||||
| assert "Unsupported syntax 'Raise' at" in str(er.value) | |||||
| def test_raise_in_nested_function(): | |||||
| class NetNestRaise(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetNestRaise, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| if x == 1: | |||||
| return 10 | |||||
| elif x == 20: | |||||
| def nest_fn(u): | |||||
| if u > 0: | |||||
| # add not support grammar 'raise' here | |||||
| raise ValueError('Illegal case') | |||||
| return u + z + 1 | |||||
| return nest_fn(y) | |||||
| else: | |||||
| return y + z | |||||
| net = NetNestRaise() | |||||
| x = Tensor(0, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| z = Tensor(2, mstype.int32) | |||||
| with pytest.raises(RuntimeError) as er: | |||||
| net(x, y, z) | |||||
| assert "Unsupported syntax 'Raise' at " in str(er.value) | |||||
| def test_nest_branch_with_return(): | |||||
| class NetBranchWithReturn(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetBranchWithReturn, self).__init__() | |||||
| def construct(self, x, y, z): | |||||
| if x == 1: | |||||
| return 10 | |||||
| else: | |||||
| return 5 | |||||
| context.set_context(save_graphs=True) | |||||
| net = NetBranchWithReturn() | |||||
| x = Tensor(0, mstype.int32) | |||||
| y = Tensor(5, mstype.int32) | |||||
| z = Tensor(2, mstype.int32) | |||||
| net(x, y, z) | |||||
| def test_any_with_no_return(): | |||||
| class NetAnyNoReturn(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetAnyNoReturn, self).__init__() | |||||
| def construct(self, inp): | |||||
| result = inp.any() | |||||
| if result: | |||||
| return 6 | |||||
| np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) | |||||
| tensor = Tensor(np_input) | |||||
| net = NetAnyNoReturn() | |||||
| with pytest.raises(TypeError) as er: | |||||
| net(tensor) | |||||
| assert "Missing return statement in bound method 'construct'" in str(er.value) | |||||
| def test_missing_construct(): | |||||
| class NetMissConstruct(nn.Cell): | |||||
| def __init__(self): | |||||
| super(NetMissConstruct, self).__init__() | |||||
| def construct1(self, inp): | |||||
| return 5 | |||||
| np_input = np.arange(2 * 3 * 4).reshape((2, 3, 4)).astype(np.bool_) | |||||
| tensor = Tensor(np_input) | |||||
| net = NetMissConstruct() | |||||
| with pytest.raises(RuntimeError) as er: | |||||
| net(tensor) | |||||
| assert "Unsupported syntax 'Raise' at " in str(er.value) | |||||