| @@ -42,6 +42,7 @@ void UpdateOutputTensors(const VectorRef *outputs, | |||||
| if (tensor->NeedSyncDeviceToHostImmediately()) { | if (tensor->NeedSyncDeviceToHostImmediately()) { | ||||
| tensor->data_sync(); | tensor->data_sync(); | ||||
| tensor->set_device_address(nullptr); | tensor->set_device_address(nullptr); | ||||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -248,7 +248,6 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const | |||||
| tensor->data_c())) { | tensor->data_c())) { | ||||
| MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; | MS_LOG(EXCEPTION) << "Parameter node sync host to device failed!"; | ||||
| } | } | ||||
| tensor->set_sync_status(kNeedSyncHostToDevice); | |||||
| } | } | ||||
| address->ref_count_ = INIT_NODE_REF; | address->ref_count_ = INIT_NODE_REF; | ||||
| tensor->set_device_address(address); | tensor->set_device_address(address); | ||||
| @@ -557,7 +557,6 @@ void Tensor::data_sync() const { | |||||
| if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | if (!device_sync_->SyncDeviceToHost(shape(), static_cast<size_t>(data().nbytes()), data_type(), data_c())) { | ||||
| MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; | MS_LOG(EXCEPTION) << "SyncDeviceToHost failed."; | ||||
| } | } | ||||
| sync_status_ = kNeedSyncHostToDevice; | |||||
| } | } | ||||
| TypeId Tensor::set_data_type(const TypeId data_type) { | TypeId Tensor::set_data_type(const TypeId data_type) { | ||||
| @@ -289,7 +289,7 @@ class Tensor : public MetaTensor { | |||||
| if (event_ != nullptr) { | if (event_ != nullptr) { | ||||
| event_->Wait(); | event_->Wait(); | ||||
| } | } | ||||
| event_ == nullptr; | |||||
| event_ = nullptr; | |||||
| } | } | ||||
| void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; } | void set_sync_status(TensorSyncStatus sync_status) { sync_status_ = sync_status; } | ||||
| @@ -306,7 +306,7 @@ class Tensor : public MetaTensor { | |||||
| bool init_flag_{false}; | bool init_flag_{false}; | ||||
| TensorDataPtr data_{nullptr}; | TensorDataPtr data_{nullptr}; | ||||
| std::string id_{""}; | std::string id_{""}; | ||||
| std::shared_ptr<WaitEvent> event_{nullptr}; | |||||
| mutable std::shared_ptr<WaitEvent> event_{nullptr}; | |||||
| mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice}; | ||||
| DeviceSyncPtr device_sync_{nullptr}; | DeviceSyncPtr device_sync_{nullptr}; | ||||
| std::vector<Axis> padding_type_; | std::vector<Axis> padding_type_; | ||||