Browse Source

backend executor reset throw exception

tags/v1.0.0
kswang 5 years ago
parent
commit
7b3ef5741a
3 changed files with 16 additions and 12 deletions
  1. +4
    -2
      mindspore/ccsrc/backend/session/executor.cc
  2. +7
    -5
      mindspore/core/ir/tensor.cc
  3. +5
    -5
      mindspore/core/ir/tensor.h

+ 4
- 2
mindspore/ccsrc/backend/session/executor.cc View File

@@ -40,7 +40,7 @@ void UpdateOutputTensors(VectorRef *outputs,
}
if (tensor->NeedSyncDeviceToHostImmediately()) {
tensor->data_sync();
tensor->set_sync_status(kNoNeedSync);
tensor->set_device_address(nullptr);
}
}
}
@@ -112,7 +112,9 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) {

void Executor::CheckException() {
if (exception_ptr_ != nullptr) {
std::rethrow_exception(exception_ptr_);
auto exception_ptr = exception_ptr_;
exception_ptr_ = nullptr;
std::rethrow_exception(exception_ptr);
}
}



+ 7
- 5
mindspore/core/ir/tensor.cc View File

@@ -550,12 +550,14 @@ std::string Tensor::ToStringRepr() const {
}

void Tensor::data_sync() const {
const_cast<Tensor *>(this)->Wait();
if (device_sync_ != nullptr) {
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
Wait();
if (device_sync_ == nullptr) {
return;
}
if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) {
MS_LOG(EXCEPTION) << "SyncDeviceToHost failed.";
}
sync_status_ = kNeedSyncHostToDevice;
}

TypeId Tensor::set_data_type(const TypeId data_type) {


+ 5
- 5
mindspore/core/ir/tensor.h View File

@@ -79,10 +79,10 @@ using TensorDataPtr = std::shared_ptr<TensorData>;

struct WaitEvent {
bool need_wait_{false};
std::mutex mutex_;
std::condition_variable cond_var_;
mutable std::mutex mutex_;
mutable std::condition_variable cond_var_;

void Wait() {
void Wait() const {
std::unique_lock<std::mutex> lock(mutex_);
if (!need_wait_) {
return;
@@ -285,7 +285,7 @@ class Tensor : public MetaTensor {
return false;
}

void Wait() {
void Wait() const {
if (event_ != nullptr) {
event_->Wait();
}
@@ -307,7 +307,7 @@ class Tensor : public MetaTensor {
TensorDataPtr data_{nullptr};
std::string id_{""};
std::shared_ptr<WaitEvent> event_{nullptr};
TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
};


Loading…
Cancel
Save