diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc
index b434b4dd08..06dbd6c0e2 100644
--- a/mindspore/ccsrc/backend/session/executor.cc
+++ b/mindspore/ccsrc/backend/session/executor.cc
@@ -15,8 +15,8 @@
  */
 
 #include "backend/session/executor.h"
 #include <exception>
-#include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/executor_manager.h"
+#include "runtime/device/kernel_runtime_manager.h"
 #include "utils/comm_manager.h"
 #include "utils/scoped_long_running.h"
@@ -52,6 +52,19 @@ void UpdateOutputTensors(const VectorRef *outputs,
         tensor->set_device_address(nullptr);
         tensor->set_sync_status(kNeedSyncHostToDevice);
       }
+    }
+  }
+}
+
+void NotifyOutputTensors(const VectorRef *outputs) {
+  MS_EXCEPTION_IF_NULL(outputs);
+  for (auto item : *outputs) {
+    if (utils::isa<VectorRefPtr>(item)) {
+      auto vector_ref = utils::cast<VectorRef>(item);
+      NotifyOutputTensors(&vector_ref);
+    } else if (utils::isa<tensor::TensorPtr>(item)) {
+      auto tensor = utils::cast<tensor::TensorPtr>(item);
+      MS_EXCEPTION_IF_NULL(tensor);
       tensor->SetNeedWait(false);
     }
   }
@@ -92,10 +105,12 @@ void RunGraphTask::Run() {
   MS_EXCEPTION_IF_NULL(session_);
   try {
     session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
+    UpdateOutputTensors(&outputs_, tensor_to_node_);
   } catch (const std::exception &e) {
     MsException::GetInstance().SetException();
   }
-  UpdateOutputTensors(&outputs_, tensor_to_node_);
+
+  NotifyOutputTensors(&outputs_);
   for (auto &tensor : input_need_lock_tensors_) {
     tensor->SetNeedWait(false);
   }
@@ -252,19 +267,19 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                              const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(session);
   MS_EXCEPTION_IF_NULL(outputs);
-  if (session != nullptr) {
-    RunGraph(session, graph_id, inputs, outputs);
-    return;
-  }
   auto task = std::make_shared<RunGraphTask>();
   task->session_ = session;
   task->graph_id_ = graph_id;
   task->input_tensors_ = inputs;
   task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
-  // lock inputs
   for (auto &tensor : inputs) {
     if (tensor->NeedWait()) {
-      task->input_need_wait_tensors_.emplace_back(tensor);
+      if (tensor->IsGraphOutput()) {
+        task->input_need_wait_tensors_.emplace_back(tensor);
+      } else {
+        mindspore::ScopedLongRunning long_running;
+        tensor->Wait();
+      }
     }
   }
   for (auto &tensor : task->input_need_lock_tensors_) {
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index bebe128a18..7a9b44ca66 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -78,6 +78,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
   tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
   tensor->set_sync_status(kNoNeedSync);
   tensor->SetNeedWait(true);
+  tensor->SetIsGraphOutput();
   return tensor;
 }
 
@@ -102,6 +103,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
     tensor->set_sync_status(kNeedSyncDeviceToHost);
   }
   tensor->SetNeedWait(true);
+  tensor->SetIsGraphOutput();
   return tensor;
 }
 
@@ -1041,7 +1043,7 @@ std::vector<tensor::TensorPtr> SessionBasic::GetNeedLockInputTensors(const GraphId &graph_id,
   }
   std::vector<tensor::TensorPtr> result;
   for (auto &tensor : inputs) {
-    if (!tensor->NeedWait()) {
+    if (!tensor->IsGraphOutput()) {
       result.emplace_back(tensor);
     }
   }
diff --git a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
index 7db7e22162..2d22a4424f 100644
--- a/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
+++ b/mindspore/ccsrc/pybind_api/ir/tensor_py.cc
@@ -258,6 +258,10 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
 }
 
 py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
+  if (tensor.NeedWait()) {
+    py::gil_scoped_release gil_release;
+    tensor.Wait();
+  }
   tensor.data_sync();
   return AsNumpy(tensor);
 }
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
index e9ab2628e8..bb3e798861 100644
--- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -176,6 +176,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph,
     (void)bound_addresses_.insert(address);
   }
   tensor->SetNeedWait(true);
+  tensor->SetIsGraphOutput();
   return tensor;
 }
 
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 3d50e07624..26f39625a2 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -174,6 +174,7 @@ constexpr auto kStridedReadOpName = "StridedRead";
 constexpr auto kStridedWriteOpName = "StridedWrite";
 constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
 constexpr auto kFusedAdamName = "FusedAdam";
+constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
 constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
 constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
 constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
@@ -385,6 +386,7 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
   kFusedAdamWeightDecayName,
   kFusedAdamName,
+  kFusedSparseAdamName,
   kFusedWeightScaleApplyMomentum,
   kFusedScaleApplyMomentum,
   kPullOpName,
diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h
index dc7c0ec89c..25c89e7365 100644
--- a/mindspore/core/ir/tensor.h
+++ b/mindspore/core/ir/tensor.h
@@ -306,12 +306,16 @@ class Tensor : public MetaTensor {
   bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }
 
+  bool IsGraphOutput() { return graph_output_; }
+  void SetIsGraphOutput() { graph_output_ = true; }
+
  private:
   bool init_flag_{false};
   TensorDataPtr data_{nullptr};
   std::string id_{""};
   mutable std::shared_ptr<WaitEvent> event_{nullptr};
   mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
+  bool graph_output_{false};
   DeviceSyncPtr device_sync_{nullptr};
   std::vector<Axis> padding_type_;
   TypePtr cast_dtype_{nullptr};