Browse Source

!9932 throw exception immediately

From: @kisnwang
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c7b801a8d5
2 changed files with 32 additions and 6 deletions
  1. +10
    -4
      mindspore/core/ir/tensor.h
  2. +22
    -2
      mindspore/core/utils/ms_exception.h

+ 10
- 4
mindspore/core/ir/tensor.h View File

@@ -78,18 +78,19 @@ class TensorData {


using TensorDataPtr = std::shared_ptr<TensorData>; using TensorDataPtr = std::shared_ptr<TensorData>;


struct WaitEvent {
bool need_wait_{false};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
class WaitEvent : public ExceptionListener {
public:
void OnException() override { set_need_wait(false); }


void Wait() const { void Wait() const {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (!need_wait_) { if (!need_wait_) {
return; return;
} }
MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this));
cond_var_.wait(lock, [this] { return !need_wait_; }); cond_var_.wait(lock, [this] { return !need_wait_; });
MsException::GetInstance().CheckException(); MsException::GetInstance().CheckException();
MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
} }


void set_need_wait(bool need_wait) { void set_need_wait(bool need_wait) {
@@ -101,6 +102,11 @@ struct WaitEvent {
} }


bool need_wait() const { return need_wait_; } bool need_wait() const { return need_wait_; }

private:
bool need_wait_{false};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
}; };


// Tensor entity class // Tensor entity class


+ 22
- 2
mindspore/core/utils/ms_exception.h View File

@@ -17,8 +17,14 @@
#ifndef MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_ #ifndef MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
#define MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_ #define MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
#include <exception> #include <exception>
#include <set>
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
namespace mindspore { namespace mindspore {
class ExceptionListener {
public:
virtual void OnException() = 0;
};

class MsException { class MsException {
public: public:
static MsException &GetInstance() { static MsException &GetInstance() {
@@ -26,7 +32,17 @@ class MsException {
return instance; return instance;
} }


void SetException() { exception_ptr_ = std::current_exception(); }
void SetException() {
exception_ptr_ = std::current_exception();
if (exception_ptr_ != nullptr) {
for (auto &listener : listeners_) {
if (listener == nullptr) {
continue;
}
listener->OnException();
}
}
}


void CheckException() { void CheckException() {
if (exception_ptr_ != nullptr) { if (exception_ptr_ != nullptr) {
@@ -36,11 +52,15 @@ class MsException {
} }
} }


void AddExceptionListener(ExceptionListener *listener) { (void)listeners_.insert(listener); }

void RemoveExceptionListener(ExceptionListener *listener) { (void)listeners_.erase(listener); }

private: private:
MsException() = default; MsException() = default;
~MsException() = default; ~MsException() = default;
DISABLE_COPY_AND_ASSIGN(MsException) DISABLE_COPY_AND_ASSIGN(MsException)
std::set<ExceptionListener *> listeners_;
std::exception_ptr exception_ptr_{nullptr}; std::exception_ptr exception_ptr_{nullptr};
}; };
} // namespace mindspore } // namespace mindspore


Loading…
Cancel
Save