| @@ -18,7 +18,6 @@ | |||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <map> | #include <map> | ||||
| #include "pybind11/pybind11.h" | |||||
| #include "debug/trace.h" | #include "debug/trace.h" | ||||
| // namespace to support utils module definition | // namespace to support utils module definition | ||||
| @@ -219,16 +218,10 @@ void LogWriter::operator^(const LogStream &stream) const { | |||||
| trace::TraceGraphEval(); | trace::TraceGraphEval(); | ||||
| trace::GetEvalStackInfo(oss); | trace::GetEvalStackInfo(oss); | ||||
| if (exception_type_ == IndexError) { | |||||
| throw pybind11::index_error(oss.str()); | |||||
| if (exception_handler_ != nullptr) { | |||||
| exception_handler_(exception_type_, oss.str()); | |||||
| } | } | ||||
| if (exception_type_ == ValueError) { | |||||
| throw pybind11::value_error(oss.str()); | |||||
| } | |||||
| if (exception_type_ == TypeError) { | |||||
| throw pybind11::type_error(oss.str()); | |||||
| } | |||||
| pybind11::pybind11_fail(oss.str()); | |||||
| throw std::runtime_error(oss.str()); | |||||
| } | } | ||||
| static std::string GetEnv(const std::string &envvar) { | static std::string GetEnv(const std::string &envvar) { | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <functional> | |||||
| #include "./overload.h" | #include "./overload.h" | ||||
| #include "./securec.h" | #include "./securec.h" | ||||
| #ifdef USE_GLOG | #ifdef USE_GLOG | ||||
| @@ -133,6 +134,8 @@ extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); | |||||
| class LogWriter { | class LogWriter { | ||||
| public: | public: | ||||
| using ExceptionHandler = std::function<void(ExceptionType, const std::string &msg)>; | |||||
| LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, | LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, | ||||
| ExceptionType excp_type = NoExceptionType) | ExceptionType excp_type = NoExceptionType) | ||||
| : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} | : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} | ||||
| @@ -141,6 +144,8 @@ class LogWriter { | |||||
| void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); | void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); | ||||
| void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); | void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); | ||||
| static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } | |||||
| private: | private: | ||||
| void OutputLog(const std::ostringstream &msg) const; | void OutputLog(const std::ostringstream &msg) const; | ||||
| @@ -148,6 +153,8 @@ class LogWriter { | |||||
| MsLogLevel log_level_; | MsLogLevel log_level_; | ||||
| SubModuleId submodule_; | SubModuleId submodule_; | ||||
| ExceptionType exception_type_; | ExceptionType exception_type_; | ||||
| inline static ExceptionHandler exception_handler_ = nullptr; | |||||
| }; | }; | ||||
| #define MSLOG_IF(level, condition, excp_type) \ | #define MSLOG_IF(level, condition, excp_type) \ | ||||
| @@ -0,0 +1,46 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "utils/log_adapter.h" | |||||
| #include <string> | |||||
| #include "pybind11/pybind11.h" | |||||
| namespace py = pybind11; | |||||
| namespace mindspore { | |||||
| class PyExceptionInitializer { | |||||
| public: | |||||
| PyExceptionInitializer() { mindspore::LogWriter::set_exception_handler(HandleExceptionPy); } | |||||
| ~PyExceptionInitializer() = default; | |||||
| private: | |||||
| static void HandleExceptionPy(ExceptionType exception_type, const std::string &str) { | |||||
| if (exception_type == IndexError) { | |||||
| throw py::index_error(str); | |||||
| } | |||||
| if (exception_type == ValueError) { | |||||
| throw py::value_error(str); | |||||
| } | |||||
| if (exception_type == TypeError) { | |||||
| throw py::type_error(str); | |||||
| } | |||||
| py::pybind11_fail(str); | |||||
| } | |||||
| }; | |||||
| static PyExceptionInitializer py_exception_initializer; | |||||
| } // namespace mindspore | |||||
| @@ -127,11 +127,17 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { | |||||
| try { | try { | ||||
| trace::ClearTraceStack(); | trace::ClearTraceStack(); | ||||
| engine_->Run(tupleSliceGraphPtr, args_spec_list); | engine_->Run(tupleSliceGraphPtr, args_spec_list); | ||||
| FAIL() << "Excepted exception :Args type is wrong"; | |||||
| FAIL() << "Excepted exception: Args type is wrong"; | |||||
| } catch (pybind11::type_error const &err) { | } catch (pybind11::type_error const &err) { | ||||
| ASSERT_TRUE(true); | ASSERT_TRUE(true); | ||||
| } catch (std::runtime_error const &err) { | |||||
| if (std::strstr(err.what(), "TypeError") != nullptr) { | |||||
| ASSERT_TRUE(true); | |||||
| } else { | |||||
| FAIL() << "Excepted exception: Args type is wrong, message: " << err.what(); | |||||
| } | |||||
| } catch (...) { | } catch (...) { | ||||
| FAIL() << "Excepted exception :Args type is wrong"; | |||||
| FAIL() << "Excepted exception: Args type is wrong"; | |||||
| } | } | ||||
| } | } | ||||