| @@ -18,7 +18,6 @@ | |||
| #include <unistd.h> | |||
| #include <map> | |||
| #include "pybind11/pybind11.h" | |||
| #include "debug/trace.h" | |||
| // namespace to support utils module definition | |||
| @@ -219,16 +218,10 @@ void LogWriter::operator^(const LogStream &stream) const { | |||
| trace::TraceGraphEval(); | |||
| trace::GetEvalStackInfo(oss); | |||
| if (exception_type_ == IndexError) { | |||
| throw pybind11::index_error(oss.str()); | |||
| if (exception_handler_ != nullptr) { | |||
| exception_handler_(exception_type_, oss.str()); | |||
| } | |||
| if (exception_type_ == ValueError) { | |||
| throw pybind11::value_error(oss.str()); | |||
| } | |||
| if (exception_type_ == TypeError) { | |||
| throw pybind11::type_error(oss.str()); | |||
| } | |||
| pybind11::pybind11_fail(oss.str()); | |||
| throw std::runtime_error(oss.str()); | |||
| } | |||
| static std::string GetEnv(const std::string &envvar) { | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include <sstream> | |||
| #include <memory> | |||
| #include <functional> | |||
| #include "./overload.h" | |||
| #include "./securec.h" | |||
| #ifdef USE_GLOG | |||
| @@ -133,6 +134,8 @@ extern int g_ms_submodule_log_levels[] __attribute__((visibility("default"))); | |||
| class LogWriter { | |||
| public: | |||
| using ExceptionHandler = std::function<void(ExceptionType, const std::string &msg)>; | |||
| LogWriter(const LocationInfo &location, MsLogLevel log_level, SubModuleId submodule, | |||
| ExceptionType excp_type = NoExceptionType) | |||
| : location_(location), log_level_(log_level), submodule_(submodule), exception_type_(excp_type) {} | |||
| @@ -141,6 +144,8 @@ class LogWriter { | |||
| void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))); | |||
| void operator^(const LogStream &stream) const __attribute__((noreturn, visibility("default"))); | |||
| static void set_exception_handler(ExceptionHandler exception_handler) { exception_handler_ = exception_handler; } | |||
| private: | |||
| void OutputLog(const std::ostringstream &msg) const; | |||
| @@ -148,6 +153,8 @@ class LogWriter { | |||
| MsLogLevel log_level_; | |||
| SubModuleId submodule_; | |||
| ExceptionType exception_type_; | |||
| inline static ExceptionHandler exception_handler_ = nullptr; | |||
| }; | |||
| #define MSLOG_IF(level, condition, excp_type) \ | |||
| @@ -0,0 +1,46 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "utils/log_adapter.h" | |||
| #include <string> | |||
| #include "pybind11/pybind11.h" | |||
| namespace py = pybind11; | |||
| namespace mindspore { | |||
| class PyExceptionInitializer { | |||
| public: | |||
| PyExceptionInitializer() { mindspore::LogWriter::set_exception_handler(HandleExceptionPy); } | |||
| ~PyExceptionInitializer() = default; | |||
| private: | |||
| static void HandleExceptionPy(ExceptionType exception_type, const std::string &str) { | |||
| if (exception_type == IndexError) { | |||
| throw py::index_error(str); | |||
| } | |||
| if (exception_type == ValueError) { | |||
| throw py::value_error(str); | |||
| } | |||
| if (exception_type == TypeError) { | |||
| throw py::type_error(str); | |||
| } | |||
| py::pybind11_fail(str); | |||
| } | |||
| }; | |||
| static PyExceptionInitializer py_exception_initializer; | |||
| } // namespace mindspore | |||
| @@ -127,11 +127,17 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { | |||
| try { | |||
| trace::ClearTraceStack(); | |||
| engine_->Run(tupleSliceGraphPtr, args_spec_list); | |||
| FAIL() << "Excepted exception :Args type is wrong"; | |||
| FAIL() << "Excepted exception: Args type is wrong"; | |||
| } catch (pybind11::type_error const &err) { | |||
| ASSERT_TRUE(true); | |||
| } catch (std::runtime_error const &err) { | |||
| if (std::strstr(err.what(), "TypeError") != nullptr) { | |||
| ASSERT_TRUE(true); | |||
| } else { | |||
| FAIL() << "Excepted exception: Args type is wrong, message: " << err.what(); | |||
| } | |||
| } catch (...) { | |||
| FAIL() << "Excepted exception :Args type is wrong"; | |||
| FAIL() << "Excepted exception: Args type is wrong"; | |||
| } | |||
| } | |||