From f279cd92ec994731f0a5cc87ac2c049db2e9df4b Mon Sep 17 00:00:00 2001 From: lvliang Date: Mon, 16 Nov 2020 16:43:21 +0800 Subject: [PATCH] optimize-time-of-getting-dynamic-info-in-single-op --- .../ccsrc/backend/session/ascend_session.cc | 7 +++++++ .../pipeline/pynative/pynative_execute.cc | 19 +++++++++++-------- .../pipeline/pynative/pynative_execute.h | 1 + .../ccsrc/runtime/device/kernel_runtime.cc | 7 ++++--- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 18a6102d9f..c61f521eea 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -565,6 +565,13 @@ void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const { MS_LOG(INFO) << "Start!"; MS_EXCEPTION_IF_NULL(kernel_graph); + const auto &kernels = kernel_graph->execution_order(); + auto iter = std::find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) { + return AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL && AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape); + }); + if (iter == kernels.end()) { + return; + } auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); if (!runtime_instance->GenDynamicKernel(kernel_graph.get())) { diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 7c122b77bb..8f5c0335d1 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -624,6 +624,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { if (abs_list.find(args_spec_list) != abs_list.end()) { MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name; 
op_exec_info->abstract = abs_list[args_spec_list].abs; + op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape; prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); is_find = true; } @@ -634,19 +635,20 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); } + // get output dynamic shape info + auto abstract = op_exec_info->abstract; + MS_EXCEPTION_IF_NULL(abstract); + auto shape = abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(shape); + auto shape_info = shape->ToString(); + if (shape_info.find("-1") != string::npos) { + op_exec_info->is_dynamic_shape = true; + } } - if (cnode != nullptr) { cnode->set_abstract(op_exec_info->abstract); } - // get output dynamic shape info - MS_EXCEPTION_IF_NULL(op_exec_info->abstract); - auto abstract_info = op_exec_info->abstract->ToString(); - if (abstract_info.find("-1") != string::npos) { - op_exec_info->is_dynamic_shape = true; - } - op_exec_info->inputs_mask = op_masks; MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->abstract != nullptr) { @@ -668,6 +670,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { // const_value need infer every step auto &out = prim_abs_list_[prim->id()]; out[args_spec_list].abs = op_exec_info->abstract; + out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape; out[args_spec_list].attrs = prim->evaluate_added_attrs(); MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 7f1d6705db..a9886b55b9 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -46,6 +46,7 @@ using GradOperationPtr = 
std::shared_ptr<prim::GradOperation>; struct PrimAbsInfo { abstract::AbstractBasePtr abs; + bool is_dynamic_shape = false; std::unordered_map<std::string, ValuePtr> attrs; }; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index f2cb051ea1..5e5a24d201 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -831,8 +831,9 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { << " should be equal to the size of kernels " << kernels.size(); } for (size_t i = 0; i < kernels.size(); ++i) { + auto &kernel = kernels[i]; if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr && - dynamic_kernel_list[i]->is_dynamic_shape()) { + dynamic_kernel_list[i]->is_dynamic_shape() && AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL) { dynamic_kernel_list[i]->InferShape(); dynamic_kernel_list[i]->UpdateArgs(); dynamic_kernel_list[i]->Execute(); @@ -842,12 +843,12 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { } dynamic_kernel_list[i]->PostExecute(); } else { - auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i]); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); AddressPtrList kernel_inputs; AddressPtrList kernel_workspaces; AddressPtrList kernel_outputs; - GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { MS_LOG(ERROR) << "Launch kernel failed.";