diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index cdee3ba0f5..f7458ff024 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -700,20 +700,21 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; } -void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, +void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) { MS_EXCEPTION_IF_NULL(input_tensors); - BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(op_run_info); + BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, input_tensors); - + // Run op auto graph = run_op_graphs_[graph_info]; MS_EXCEPTION_IF_NULL(graph); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; + MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!"; // malloc mem RunOpMemoryAlloc(*input_tensors, graph.get()); // Build dynamic kernel - if (op_run_info.is_dynamic_shape) { + if (op_run_info->is_dynamic_shape) { BuildDynamicKernel(graph); } // load input data to device @@ -722,8 +723,12 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra Execute(graph, false); // get output UpdateOutputs(graph, outputs, *input_tensors); + // update output abstract of dynamic op to op_run_info + if (op_run_info->is_dynamic_shape) { + UpdateOutputAbstract(graph, op_run_info); + } RunOpMemoryClear(graph.get()); - MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; + MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!"; } void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, @@ -750,7 +755,7 @@ 
void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector // Build and run current single op VectorRef op_outputs; - RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs, + RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs, input_tensor_info.input_tensors_mask); // Handle inputs and outputs of current op diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 4c94beba98..c4e815c8f7 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -59,9 +59,8 @@ class AscendSession : public SessionBasic { void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; - void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - std::vector *input_tensors, VectorRef *outputs, - const std::vector &tensors_mask) override; + void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, + VectorRef *outputs, const std::vector &tensors_mask) override; void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 9b789e2c25..2787116251 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -170,11 +170,12 @@ void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) { MS_EXCEPTION_IF_NULL(input_tensors); - BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(op_run_info); + BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, input_tensors); auto kernel_graph = 
run_op_graphs_[graph_info]; diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 0412a0c672..376b6e6a20 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -41,9 +41,8 @@ class CPUSession : public SessionBasic { void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; - void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - std::vector *input_tensors, VectorRef *outputs, - const std::vector &tensors_mask) override; + void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, + VectorRef *outputs, const std::vector &tensors_mask) override; private: void SetKernelInfo(const KernelGraph *kernel_graph); diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 36ce68895e..96493d069a 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -130,7 +130,7 @@ void RunGraphTask::Run() { void RunOpTask::Run() { MS_EXCEPTION_IF_NULL(session_); - session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_); + session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_); } void RunOpsInGraphTask::Run() { diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 599345dcf8..1cc8d86454 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -403,13 +403,14 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap run_op_graphs_[graph_info] = kernel_graph; } -void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, +void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo 
*op_run_info, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) { MS_EXCEPTION_IF_NULL(input_tensors); - BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); + MS_EXCEPTION_IF_NULL(op_run_info); + BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, input_tensors); - + // run op auto kernel_graph = run_op_graphs_[graph_info]; MS_EXCEPTION_IF_NULL(kernel_graph); // Remove NopOp from execution graph @@ -420,6 +421,10 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_ Execute(kernel_graph); // Fetch outputs UpdateOutputs(kernel_graph, outputs, *input_tensors); + // update output abstract of dynamic op to op_run_info + if (op_run_info->is_dynamic_shape) { + UpdateOutputAbstract(kernel_graph, op_run_info); + } RunOpClearMemory(kernel_graph.get()); } diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 67033f00b9..e3dc825c21 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -39,9 +39,8 @@ class GPUSession : public SessionBasic { void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) override; - void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, - std::vector *input_tensors, VectorRef *outputs, - const std::vector &tensors_mask) override; + void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, + VectorRef *outputs, const std::vector &tensors_mask) override; private: void SelectKernel(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ab0db60055..490e4379c8 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ 
b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1142,6 +1142,19 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap } } +void SessionBasic::UpdateOutputAbstract(const std::shared_ptr &kernel_graph, + OpRunInfo *op_run_info) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(op_run_info); + const auto &kernels = kernel_graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) { + op_run_info->abstract = kernel->abstract(); + } + } +} + std::vector SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id, const std::vector &inputs) { auto graph = GetGraph(graph_id); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 620016d611..9c95464bc9 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -153,7 +153,7 @@ class SessionBasic : public std::enable_shared_from_this { virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, const std::vector &tensors_mask) {} - virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, + virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) {} virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, @@ -167,6 +167,7 @@ class SessionBasic : public std::enable_shared_from_this { void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors); void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs, const std::vector &input_tensors) const; + void UpdateOutputAbstract(const std::shared_ptr &kernel_graph, OpRunInfo *op_run_info) const; void Reorder(std::vector *node_list); void Summary(KernelGraph *graph); // 
create graph output for RunOp diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index a89f06a75d..ed218e2193 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -70,10 +70,12 @@ const std::set ignore_infer_prim = {"make_ref", "mixed_precision_ca const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; const std::set ignore_judge_dynamic_cell = { "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", - "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"}; + "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; const std::set unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, parse::NAMED_PRIMITIVE_NAMECONSTANT, parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; +const std::set dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup", + "Transpose"}; } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 9a19a69fa8..28ccb484fe 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -467,6 +467,11 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector opt::ConstInputToAttrInfoRegister reg; bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, &reg); + if (op_run_info->is_dynamic_shape && + dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) { + MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; + reg_exist = false; + } if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { reg_exist = false; } @@ -594,6 +599,7 @@ py::tuple RunOp(const py::args
&args) { } py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { + MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { return RunOpWithInitBackendPolicy(op_exec_info); } @@ -604,58 +610,27 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { op_exec_info->inputs_mask = op_masks; // get output abstract info bool is_find = false; + GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find); + MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); + // infer output value for const prim auto prim = op_exec_info->py_primitive; - if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { - auto abs_list = prim_abs_list_[prim->id()]; - MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); - if (abs_list.find(args_spec_list) != abs_list.end()) { - MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name; - op_exec_info->abstract = abs_list[args_spec_list].abs; - op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape; - prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); - is_find = true; - } - } - if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { - // use python infer method - if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); - } - // get output dynamic shape info - auto abstract = op_exec_info->abstract; - MS_EXCEPTION_IF_NULL(abstract); - auto shape = abstract->BuildShape(); - MS_EXCEPTION_IF_NULL(shape); - auto shape_info = shape->ToString(); - if (shape_info.find("-1") != string::npos) { - op_exec_info->is_dynamic_shape = true; - } - } - if (cnode != nullptr) { - cnode->set_abstract(op_exec_info->abstract); + MS_EXCEPTION_IF_NULL(prim); + py::dict output 
= abstract::ConvertAbstractToPython(op_exec_info->abstract); + if (!output["value"].is_none()) { + py::tuple value_ret(1); + value_ret[0] = output["value"]; + return value_ret; } - // infer output value for const prim - MS_EXCEPTION_IF_NULL(op_exec_info); - if (op_exec_info->abstract != nullptr) { - MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); - py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); - if (!output["value"].is_none()) { - py::tuple value_ret(1); - value_ret[0] = output["value"]; - return value_ret; - } - if (op_exec_info->py_primitive->is_const_prim()) { - py::tuple value_ret(1); - value_ret[0] = ""; - return value_ret; - } + if (prim->is_const_prim()) { + py::tuple value_ret(1); + value_ret[0] = ""; + return value_ret; } // add output abstract info into cache - if (!is_find) { + if (!is_find && !op_exec_info->is_dynamic_shape) { // const_value need infer every step auto &out = prim_abs_list_[prim->id()]; out[args_spec_list].abs = op_exec_info->abstract; - out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape; out[args_spec_list].attrs = prim->evaluate_added_attrs(); MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); } @@ -666,8 +641,13 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { MS_LOG(DEBUG) << "Output size is 1"; out_real = result[0]; } + // update output abstract for cnode + if (cnode != nullptr) { + cnode->set_abstract(op_exec_info->abstract); + } std::string obj_id = GetId(out_real); node_abs_map_[obj_id] = op_exec_info->abstract; + // save info for building grad graph SaveOutputNodeMap(obj_id, out_real, cnode); SaveAllResult(op_exec_info, cnode, out_real); // Update the abstract and device address of value node with tensor in grad graph @@ -784,6 +764,49 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v return cnode; } +void 
PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, + const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) { + MS_EXCEPTION_IF_NULL(is_find); + MS_EXCEPTION_IF_NULL(op_exec_info); + *is_find = false; + auto op_name = op_exec_info->op_name; + auto prim = op_exec_info->py_primitive; + MS_EXCEPTION_IF_NULL(prim); + if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { + auto abs_list = prim_abs_list_[prim->id()]; + MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list); + if (abs_list.find(args_spec_list) != abs_list.end()) { + MS_LOG(DEBUG) << "Match prim ok " << op_name; + op_exec_info->abstract = abs_list[args_spec_list].abs; + prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); + *is_find = true; + } + } + if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) { + // use python infer method + if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) { + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); + } + } + // get output dynamic shape info + auto py_abstract = op_exec_info->abstract; + MS_EXCEPTION_IF_NULL(py_abstract); + auto py_shape = py_abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(py_shape); + auto py_shape_info = py_shape->ToString(); + if (py_shape_info.find("-1") != string::npos) { + auto c_abstract = abstract::CppInferShape(prim, args_spec_list); + MS_EXCEPTION_IF_NULL(c_abstract); + auto c_shape = c_abstract->BuildShape(); + MS_EXCEPTION_IF_NULL(c_shape); + auto c_shape_info = c_shape->ToString(); + MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info; + if (c_shape_info.find("-1") != string::npos) { + op_exec_info->is_dynamic_shape = true; + } + } +} + py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index) { py::tuple cast_args(3); @@ -1326,6 +1349,9 @@ py::object 
PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati op_exec_info->next_input_index}; VectorRef outputs; session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask); + if (op_exec_info->is_dynamic_shape) { + op_exec_info->abstract = op_run_info.abstract; + } auto result = BaseRefToPyData(outputs); ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); *status = PYNATIVE_SUCCESS; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index bb9c7e5ffc..eea958b5e5 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -129,6 +129,8 @@ class PynativeExecutor : public std::enable_shared_from_this { AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list); + void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, + bool *is_find); void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); // replace for grad graph diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 88ed305b5b..c1a80a78e7 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -577,6 +577,7 @@ void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph *grap } bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { + MS_EXCEPTION_IF_NULL(graph); bool ret = false; #if defined(_WIN32) || defined(_WIN64) auto start_time = std::chrono::steady_clock::now(); diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc 
b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 1e741e0196..9fbe40a7c9 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -336,6 +336,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { } bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { + MS_EXCEPTION_IF_NULL(graph); struct timeval start_time, end_time; (void)gettimeofday(&start_time, nullptr); bool ret = true; @@ -360,7 +361,12 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { ret = RunOneStep(graph); } else { - ret = LaunchKernel(graph); + if (graph->is_dynamic_shape()) { + // run dynamic shape graph in pynative + ret = RunOpLaunchKernelDynamic(graph); + } else { + ret = LaunchKernel(graph); + } } (void)gettimeofday(&end_time, nullptr); const uint64_t kUSecondInSecond = 1000000; @@ -674,6 +680,42 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo return true; } +bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + const auto &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + MS_EXCEPTION_IF_NULL(kernel); + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + // akg kernel do not support dynamic shape by now. 
+ device::DynamicKernelPtr dynamic_kernel = nullptr; + kernel::GpuKernel *gpu_kernel = nullptr; + if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) != KernelType::AKG_KERNEL) { + gpu_kernel = dynamic_cast(kernel_mod); + dynamic_kernel = gpu_kernel->DynamicKernel(); + } + // pre-processing for dynamic shape kernel + if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { + dynamic_kernel->InferShape(); + dynamic_kernel->UpdateArgs(); + } + // alloc kernel res + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); + if (!ret) { + MS_LOG(ERROR) << "Launch kernel failed."; + return false; + } + if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { + gpu_kernel->PostExecute(); + } + } + return true; +} + void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, const AddressPtrList &workspace, const AddressPtrList &outputs) { auto kernel_mod = AnfAlgo::GetKernelMod(kernel); diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index b5c80e68f1..83b80ce0b0 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -73,6 +73,7 @@ class GPUKernelRuntime : public KernelRuntime { bool SearchMemSwapScheme(const session::KernelGraph *graph); bool RefineMemSwapScheme(const session::KernelGraph *graph); bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); + bool RunOpLaunchKernelDynamic(const session::KernelGraph *graph); void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, const AddressPtrList &workspace, const AddressPtrList &outputs); bool 
AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index d0cb8d6938..057dc2bb3a 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -875,7 +875,7 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList } bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { - auto &kernels = graph.execution_order(); + const auto &kernels = graph.execution_order(); std::vector dynamic_kernel_list; auto iter = graph_dynamic_kernel_map_.find(graph.graph_id()); if (iter != graph_dynamic_kernel_map_.end()) {