| @@ -40,6 +40,7 @@ | |||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| #include "pipeline/pynative/pynative_execute.h" | #include "pipeline/pynative/pynative_execute.h" | ||||
| #include "frontend/optimizer/py_pass_manager.h" | #include "frontend/optimizer/py_pass_manager.h" | ||||
| #include "pybind_api/pybind_patch.h" | |||||
| #if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES) | #if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES) | ||||
| #include "frontend/parallel/ps/common.h" | #include "frontend/parallel/ps/common.h" | ||||
| @@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py: | |||||
| } catch (const py::index_error &ex) { | } catch (const py::index_error &ex) { | ||||
| ReleaseResource(phase); | ReleaseResource(phase); | ||||
| throw py::index_error(ex); | throw py::index_error(ex); | ||||
| } catch (const py::attribute_error &ex) { | |||||
| ReleaseResource(phase); | |||||
| throw py::attribute_error(ex); | |||||
| } catch (const std::exception &ex) { | } catch (const std::exception &ex) { | ||||
| ReleaseResource(phase); | ReleaseResource(phase); | ||||
| // re-throw this exception to Python interpreter to handle it | // re-throw this exception to Python interpreter to handle it | ||||
| @@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng | |||||
| ValuePtr method = cls->GetMethod(item_name); | ValuePtr method = cls->GetMethod(item_name); | ||||
| if (method->isa<AnyValue>()) { | if (method->isa<AnyValue>()) { | ||||
| MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() | |||||
| << ", item value: " << item_v->ToString(); | |||||
| MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString() | |||||
| << ", item value: " << item_v->ToString(); | |||||
| } | } | ||||
| // Infer class method | // Infer class method | ||||
| @@ -0,0 +1,24 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PYBIND_API_PYBIND_PATCH_H_ | |||||
| #define PYBIND_API_PYBIND_PATCH_H_ | |||||
| namespace pybind11 { | |||||
| PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError) | |||||
| } | |||||
| #endif // PYBIND_API_PYBIND_PATCH_H_ | |||||
| @@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) { | |||||
| _TO_STRING(IndexError), | _TO_STRING(IndexError), | ||||
| _TO_STRING(ValueError), | _TO_STRING(ValueError), | ||||
| _TO_STRING(TypeError), | _TO_STRING(TypeError), | ||||
| _TO_STRING(AttributeError), | |||||
| }; | }; | ||||
| // clang-format on | // clang-format on | ||||
| #undef _TO_STRING | #undef _TO_STRING | ||||
| if (type < UnknownError || type > TypeError) { | |||||
| if (type < UnknownError || type > AttributeError) { | |||||
| type = UnknownError; | type = UnknownError; | ||||
| } | } | ||||
| return std::string(type_names[type]); | return std::string(type_names[type]); | ||||
| @@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const { | |||||
| std::ostringstream oss; | std::ostringstream oss; | ||||
| oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; | oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; | ||||
| if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && | if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError && | ||||
| exception_type_ != ValueError) { | |||||
| exception_type_ != ValueError && exception_type_ != AttributeError) { | |||||
| oss << ExceptionTypeToString(exception_type_) << " "; | oss << ExceptionTypeToString(exception_type_) << " "; | ||||
| } | } | ||||
| oss << msg.str(); | oss << msg.str(); | ||||
| @@ -58,6 +58,7 @@ enum ExceptionType { | |||||
| IndexError, | IndexError, | ||||
| ValueError, | ValueError, | ||||
| TypeError, | TypeError, | ||||
| AttributeError, | |||||
| }; | }; | ||||
| struct LocationInfo { | struct LocationInfo { | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "pybind_api/pybind_patch.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -38,6 +39,9 @@ class PyExceptionInitializer { | |||||
| if (exception_type == TypeError) { | if (exception_type == TypeError) { | ||||
| throw py::type_error(str); | throw py::type_error(str); | ||||
| } | } | ||||
| if (exception_type == AttributeError) { | |||||
| throw py::attribute_error(str); | |||||
| } | |||||
| py::pybind11_fail(str); | py::pybind11_fail(str); | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -304,6 +304,29 @@ def test_access(): | |||||
| """ test_access """ | """ test_access """ | ||||
| invoke_dataclass(1, 2) | invoke_dataclass(1, 2) | ||||
| @dataclass | |||||
| class Access2: | |||||
| a: int | |||||
| b: int | |||||
| def max(self): | |||||
| if self.a > self.b: | |||||
| return self.c | |||||
| return self.b | |||||
| @ms_function | |||||
| def invoke_dataclass2(x, y): | |||||
| """ invoke_dataclass """ | |||||
| acs = Access2(x, y) | |||||
| return acs.max() | |||||
| def test_access_attr_error(): | |||||
| """ test_access """ | |||||
| with pytest.raises(AttributeError): | |||||
| invoke_dataclass2(1, 2) | |||||
| def myfunc(x): | def myfunc(x): | ||||
| """ myfunc """ | """ myfunc """ | ||||