From 13e53484ea8417ef13df56debd40ad64028f8836 Mon Sep 17 00:00:00 2001 From: kswang Date: Mon, 14 Dec 2020 20:18:47 +0800 Subject: [PATCH] add exception listener --- mindspore/core/ir/tensor.h | 14 ++++++++++---- mindspore/core/utils/ms_exception.h | 24 ++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 672c0f3bb3..8c2bad940e 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -78,18 +78,19 @@ class TensorData { using TensorDataPtr = std::shared_ptr; -struct WaitEvent { - bool need_wait_{false}; - mutable std::mutex mutex_; - mutable std::condition_variable cond_var_; +class WaitEvent : public ExceptionListener { + public: + void OnException() override { set_need_wait(false); } void Wait() const { std::unique_lock lock(mutex_); if (!need_wait_) { return; } + MsException::GetInstance().AddExceptionListener(const_cast(this)); cond_var_.wait(lock, [this] { return !need_wait_; }); MsException::GetInstance().CheckException(); + MsException::GetInstance().RemoveExceptionListener(const_cast(this)); } void set_need_wait(bool need_wait) { @@ -101,6 +102,11 @@ struct WaitEvent { } bool need_wait() const { return need_wait_; } + + private: + bool need_wait_{false}; + mutable std::mutex mutex_; + mutable std::condition_variable cond_var_; }; // Tensor entity class diff --git a/mindspore/core/utils/ms_exception.h b/mindspore/core/utils/ms_exception.h index e4454878c3..fab3310a5b 100644 --- a/mindspore/core/utils/ms_exception.h +++ b/mindspore/core/utils/ms_exception.h @@ -17,8 +17,14 @@ #ifndef MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_ #define MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_ #include +#include #include "utils/ms_utils.h" namespace mindspore { +class ExceptionListener { + public: + virtual void OnException() = 0; +}; + class MsException { public: static MsException &GetInstance() { @@ -26,7 +32,17 @@ class MsException { return instance; } - void SetException() { exception_ptr_ = std::current_exception(); } + void SetException() { + exception_ptr_ = std::current_exception(); + if (exception_ptr_ != nullptr) { + for (auto &listener : listeners_) { + if (listener == nullptr) { + continue; + } + listener->OnException(); + } + } + } void CheckException() { if (exception_ptr_ != nullptr) { @@ -36,11 +52,15 @@ class MsException { } } + void AddExceptionListener(ExceptionListener *listener) { (void)listeners_.insert(listener); } + + void RemoveExceptionListener(ExceptionListener *listener) { (void)listeners_.erase(listener); } + private: MsException() = default; ~MsException() = default; DISABLE_COPY_AND_ASSIGN(MsException) - + std::set listeners_; std::exception_ptr exception_ptr_{nullptr}; }; } // namespace mindspore