Browse Source

!9932 throw exception immediately

From: @kisnwang
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
c7b801a8d5
2 changed files with 32 additions and 6 deletions
  1. +10
    -4
      mindspore/core/ir/tensor.h
  2. +22
    -2
      mindspore/core/utils/ms_exception.h

+ 10
- 4
mindspore/core/ir/tensor.h View File

@@ -78,18 +78,19 @@ class TensorData {

using TensorDataPtr = std::shared_ptr<TensorData>;

struct WaitEvent {
bool need_wait_{false};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
class WaitEvent : public ExceptionListener {
public:
void OnException() override { set_need_wait(false); }

void Wait() const {
std::unique_lock<std::mutex> lock(mutex_);
if (!need_wait_) {
return;
}
MsException::GetInstance().AddExceptionListener(const_cast<WaitEvent *>(this));
cond_var_.wait(lock, [this] { return !need_wait_; });
MsException::GetInstance().CheckException();
MsException::GetInstance().RemoveExceptionListener(const_cast<WaitEvent *>(this));
}

void set_need_wait(bool need_wait) {
@@ -101,6 +102,11 @@ struct WaitEvent {
}

bool need_wait() const { return need_wait_; }

private:
bool need_wait_{false};
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;
};

// Tensor entity class


+ 22
- 2
mindspore/core/utils/ms_exception.h View File

@@ -17,8 +17,14 @@
#ifndef MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
#define MINDSPORE_CORE_UTILS_MS_EXCEPTION_H_
#include <exception>
#include <set>
#include "utils/ms_utils.h"
namespace mindspore {
class ExceptionListener {
public:
virtual void OnException() = 0;
};

class MsException {
public:
static MsException &GetInstance() {
@@ -26,7 +32,17 @@ class MsException {
return instance;
}

void SetException() { exception_ptr_ = std::current_exception(); }
void SetException() {
exception_ptr_ = std::current_exception();
if (exception_ptr_ != nullptr) {
for (auto &listener : listeners_) {
if (listener == nullptr) {
continue;
}
listener->OnException();
}
}
}

void CheckException() {
if (exception_ptr_ != nullptr) {
@@ -36,11 +52,15 @@ class MsException {
}
}

void AddExceptionListener(ExceptionListener *listener) { (void)listeners_.insert(listener); }

void RemoveExceptionListener(ExceptionListener *listener) { (void)listeners_.erase(listener); }

private:
MsException() = default;
~MsException() = default;
DISABLE_COPY_AND_ASSIGN(MsException)
std::set<ExceptionListener *> listeners_;
std::exception_ptr exception_ptr_{nullptr};
};
} // namespace mindspore


Loading…
Cancel
Save