Browse Source

enable async run

tags/v1.1.0
kswang 5 years ago
parent
commit
ece27f313e
6 changed files with 37 additions and 9 deletions
  1. +23
    -8
      mindspore/ccsrc/backend/session/executor.cc
  2. +3
    -1
      mindspore/ccsrc/backend/session/session_basic.cc
  3. +4
    -0
      mindspore/ccsrc/pybind_api/ir/tensor_py.cc
  4. +1
    -0
      mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
  5. +2
    -0
      mindspore/ccsrc/utils/utils.h
  6. +4
    -0
      mindspore/core/ir/tensor.h

+ 23
- 8
mindspore/ccsrc/backend/session/executor.cc View File

@@ -15,8 +15,8 @@
*/
#include "backend/session/executor.h"
#include <exception>
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"

@@ -52,6 +52,19 @@ void UpdateOutputTensors(const VectorRef *outputs,
tensor->set_device_address(nullptr);
tensor->set_sync_status(kNeedSyncHostToDevice);
}
}
}
}

void NotifyOutputTensors(const VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
NotifyOutputTensors(&vector_ref);
} else if (utils::isa<tensor::TensorPtr>(item)) {
auto tensor = utils::cast<tensor::TensorPtr>(item);
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetNeedWait(false);
}
}
@@ -92,10 +105,12 @@ void RunGraphTask::Run() {
MS_EXCEPTION_IF_NULL(session_);
try {
session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
UpdateOutputTensors(&outputs_, tensor_to_node_);
} catch (const std::exception &e) {
MsException::GetInstance().SetException();
}
UpdateOutputTensors(&outputs_, tensor_to_node_);

NotifyOutputTensors(&outputs_);
for (auto &tensor : input_need_lock_tensors_) {
tensor->SetNeedWait(false);
}
@@ -252,19 +267,19 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(session);
MS_EXCEPTION_IF_NULL(outputs);
if (session != nullptr) {
RunGraph(session, graph_id, inputs, outputs);
return;
}
auto task = std::make_shared<RunGraphTask>();
task->session_ = session;
task->graph_id_ = graph_id;
task->input_tensors_ = inputs;
task->input_need_lock_tensors_ = session->GetNeedLockInputTensors(graph_id, inputs);
// lock inputs
for (auto &tensor : inputs) {
if (tensor->NeedWait()) {
task->input_need_wait_tensors_.emplace_back(tensor);
if (tensor->IsGraphOutput()) {
task->input_need_wait_tensors_.emplace_back(tensor);
} else {
mindspore::ScopedLongRunning long_running;
tensor->Wait();
}
}
}
for (auto &tensor : task->input_need_lock_tensors_) {


+ 3
- 1
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -78,6 +78,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
tensor->set_sync_status(kNoNeedSync);
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
return tensor;
}

@@ -102,6 +103,7 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
tensor->set_sync_status(kNeedSyncDeviceToHost);
}
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
return tensor;
}

@@ -1041,7 +1043,7 @@ std::vector<tensor::TensorPtr> SessionBasic::GetNeedLockInputTensors(const Graph
}
std::vector<tensor::TensorPtr> result;
for (auto &tensor : inputs) {
if (!tensor->NeedWait()) {
if (!tensor->IsGraphOutput()) {
result.emplace_back(tensor);
}
}


+ 4
- 0
mindspore/ccsrc/pybind_api/ir/tensor_py.cc View File

@@ -258,6 +258,10 @@ py::tuple TensorPy::GetPyTupleShape(const Tensor &tensor) {
}

// Convert `tensor` to a NumPy array, blocking first until the tensor's data
// is available. Introduced by this commit so Python-side reads cooperate with
// async graph execution (tensors may be returned before their values exist).
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
// If the tensor is still pending (NeedWait), block until it is notified.
// The GIL is released for the duration of the wait so other Python threads
// can run — presumably also so the async-run completion path is not blocked
// on the GIL (NOTE(review): confirm against the executor's notify path).
if (tensor.NeedWait()) {
py::gil_scoped_release gil_release;
// gil_scoped_release's destructor reacquires the GIL when this scope ends,
// so data_sync()/AsNumpy() below run with the GIL held again.
tensor.Wait();
}
// Ensure host-side data is up to date before exposing it to Python
// (presumably copies device memory to host when needed — see Tensor::data_sync).
tensor.data_sync();
return AsNumpy(tensor);
}


+ 1
- 0
mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc View File

@@ -176,6 +176,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k
(void)bound_addresses_.insert(address);
}
tensor->SetNeedWait(true);
tensor->SetIsGraphOutput();
return tensor;
}



+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -174,6 +174,7 @@ constexpr auto kStridedReadOpName = "StridedRead";
constexpr auto kStridedWriteOpName = "StridedWrite";
constexpr auto kFusedAdamWeightDecayName = "FusedAdamWeightDecay";
constexpr auto kFusedAdamName = "FusedAdam";
constexpr auto kFusedSparseAdamName = "FusedSparseAdam";
constexpr auto kApplyAdagradV2OpName = "ApplyAdagradV2";
constexpr auto kSparseApplyAdagradV2OpName = "SparseApplyAdagradV2";
constexpr auto kSparseApplyFtrlOpName = "SparseApplyFtrl";
@@ -385,6 +386,7 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName,
kFusedAdamWeightDecayName,
kFusedAdamName,
kFusedSparseAdamName,
kFusedWeightScaleApplyMomentum,
kFusedScaleApplyMomentum,
kPullOpName,


+ 4
- 0
mindspore/core/ir/tensor.h View File

@@ -306,12 +306,16 @@ class Tensor : public MetaTensor {

bool NeedSyncHostToDevice() const { return sync_status_ == kNeedSyncHostToDevice; }

bool IsGraphOutput() { return graph_output_; }
void SetIsGraphOutput() { graph_output_ = true; }

private:
bool init_flag_{false};
TensorDataPtr data_{nullptr};
std::string id_{""};
mutable std::shared_ptr<WaitEvent> event_{nullptr};
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
bool graph_output_{false};
DeviceSyncPtr device_sync_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};


Loading…
Cancel
Save