| @@ -40,7 +40,7 @@ void UpdateOutputTensors(VectorRef *outputs, | |||||
| } | } | ||||
| if (tensor->NeedSyncDeviceToHostImmediately()) { | if (tensor->NeedSyncDeviceToHostImmediately()) { | ||||
| tensor->data_sync(); | tensor->data_sync(); | ||||
| tensor->set_sync_status(kNoNeedSync); | |||||
| tensor->set_device_address(nullptr); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -112,7 +112,9 @@ Executor::Executor(const std::string &device_name, uint32_t device_id) { | |||||
| void Executor::CheckException() { | void Executor::CheckException() { | ||||
| if (exception_ptr_ != nullptr) { | if (exception_ptr_ != nullptr) { | ||||
| std::rethrow_exception(exception_ptr_); | |||||
| auto exception_ptr = exception_ptr_; | |||||
| exception_ptr_ = nullptr; | |||||
| std::rethrow_exception(exception_ptr); | |||||
| } | } | ||||
| } | } | ||||
| @@ -550,12 +550,14 @@ std::string Tensor::ToStringRepr() const { | |||||
| } | } | ||||
| void Tensor::data_sync() const { | void Tensor::data_sync() const { | ||||
| const_cast<Tensor *>(this)->Wait(); | |||||
| if (device_sync_ != nullptr) { | |||||
| if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | |||||
| MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; | |||||
| } | |||||
| Wait(); | |||||
| if (device_sync_ == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | |||||
| MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; | |||||
| } | } | ||||
| sync_status_ = kNeedSyncHostToDevice; | |||||
| } | } | ||||
| TypeId Tensor::set_data_type(const TypeId data_type) { | TypeId Tensor::set_data_type(const TypeId data_type) { | ||||
| @@ -79,10 +79,10 @@ using TensorDataPtr = std::shared_ptr<TensorData>; | |||||
| struct WaitEvent { | struct WaitEvent { | ||||
| bool need_wait_{false}; | bool need_wait_{false}; | ||||
| std::mutex mutex_; | |||||
| std::condition_variable cond_var_; | |||||
| mutable std::mutex mutex_; | |||||
| mutable std::condition_variable cond_var_; | |||||
| void Wait() { | |||||
| void Wait() const { | |||||
| std::unique_lock<std::mutex> lock(mutex_); | std::unique_lock<std::mutex> lock(mutex_); | ||||
| if (!need_wait_) { | if (!need_wait_) { | ||||
| return; | return; | ||||
| @@ -285,7 +285,7 @@ class Tensor : public MetaTensor { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void Wait() { | |||||
| void Wait() const { | |||||
| if (event_ != nullptr) { | if (event_ != nullptr) { | ||||
| event_->Wait(); | event_->Wait(); | ||||
| } | } | ||||
| @@ -307,7 +307,7 @@ class Tensor : public MetaTensor { | |||||
| TensorDataPtr data_{nullptr}; | TensorDataPtr data_{nullptr}; | ||||
| std::string id_{""}; | std::string id_{""}; | ||||
| std::shared_ptr<WaitEvent> event_{nullptr}; | std::shared_ptr<WaitEvent> event_{nullptr}; | ||||
| TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | |||||
| mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | |||||
| DeviceSyncPtr device_sync_{nullptr}; | DeviceSyncPtr device_sync_{nullptr}; | ||||
| std::vector<Axis> padding_type_; | std::vector<Axis> padding_type_; | ||||
| }; | }; | ||||