| @@ -700,20 +700,21 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g | |||
| MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !"; | |||
| } | |||
| void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) { | |||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| // Run op | |||
| auto graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; | |||
| MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!"; | |||
| // malloc mem | |||
| RunOpMemoryAlloc(*input_tensors, graph.get()); | |||
| // Build dynamic kernel | |||
| if (op_run_info.is_dynamic_shape) { | |||
| if (op_run_info->is_dynamic_shape) { | |||
| BuildDynamicKernel(graph); | |||
| } | |||
| // load input data to device | |||
| @@ -722,8 +723,12 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra | |||
| Execute(graph, false); | |||
| // get output | |||
| UpdateOutputs(graph, outputs, *input_tensors); | |||
| // update output abstract of dynamic op to op_run_info | |||
| if (op_run_info->is_dynamic_shape) { | |||
| UpdateOutputAbstract(graph, op_run_info); | |||
| } | |||
| RunOpMemoryClear(graph.get()); | |||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; | |||
| MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!"; | |||
| } | |||
| void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| @@ -750,7 +755,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector | |||
| // Build and run current single op | |||
| VectorRef op_outputs; | |||
| RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs, | |||
| RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs, | |||
| input_tensor_info.input_tensors_mask); | |||
| // Handle inputs and outputs of current op | |||
| @@ -59,9 +59,8 @@ class AscendSession : public SessionBasic { | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| VectorRef *outputs) override; | |||
| @@ -170,11 +170,12 @@ void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::T | |||
| } | |||
| } | |||
| void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) { | |||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| auto kernel_graph = run_op_graphs_[graph_info]; | |||
| @@ -41,9 +41,8 @@ class CPUSession : public SessionBasic { | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||
| private: | |||
| void SetKernelInfo(const KernelGraph *kernel_graph); | |||
| @@ -130,7 +130,7 @@ void RunGraphTask::Run() { | |||
| void RunOpTask::Run() { | |||
| MS_EXCEPTION_IF_NULL(session_); | |||
| session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_); | |||
| session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_); | |||
| } | |||
| void RunOpsInGraphTask::Run() { | |||
| @@ -403,13 +403,14 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap | |||
| run_op_graphs_[graph_info] = kernel_graph; | |||
| } | |||
| void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) { | |||
| MS_EXCEPTION_IF_NULL(input_tensors); | |||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||
| EraseValueNodeTensor(tensors_mask, input_tensors); | |||
| // run op | |||
| auto kernel_graph = run_op_graphs_[graph_info]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // Remove NopOp from execution graph | |||
| @@ -420,6 +421,10 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_ | |||
| Execute(kernel_graph); | |||
| // Fetch outputs | |||
| UpdateOutputs(kernel_graph, outputs, *input_tensors); | |||
| // update output abstract of dynamic op to op_run_info | |||
| if (op_run_info->is_dynamic_shape) { | |||
| UpdateOutputAbstract(kernel_graph, op_run_info); | |||
| } | |||
| RunOpClearMemory(kernel_graph.get()); | |||
| } | |||
| @@ -39,9 +39,8 @@ class GPUSession : public SessionBasic { | |||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) override; | |||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||
| private: | |||
| void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | |||
| @@ -1142,6 +1142,19 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| } | |||
| } | |||
| void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| OpRunInfo *op_run_info) const { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||
| const auto &kernels = kernel_graph->execution_order(); | |||
| for (const auto &kernel : kernels) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) { | |||
| op_run_info->abstract = kernel->abstract(); | |||
| } | |||
| } | |||
| } | |||
| std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id, | |||
| const std::vector<tensor::TensorPtr> &inputs) { | |||
| auto graph = GetGraph(graph_id); | |||
| @@ -153,7 +153,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||
| const std::vector<int64_t> &tensors_mask) {} | |||
| virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||
| virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||
| const std::vector<int64_t> &tensors_mask) {} | |||
| virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| @@ -167,6 +167,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | |||
| void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, | |||
| const std::vector<tensor::TensorPtr> &input_tensors) const; | |||
| void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const; | |||
| void Reorder(std::vector<CNodePtr> *node_list); | |||
| void Summary(KernelGraph *graph); | |||
| // create graph output for RunOp | |||
| @@ -70,10 +70,12 @@ const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_ca | |||
| const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"}; | |||
| const std::set<std::string> ignore_judge_dynamic_cell = { | |||
| "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", | |||
| "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"}; | |||
| "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; | |||
| const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, | |||
| parse::NAMED_PRIMITIVE_NAMECONSTANT, | |||
| parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; | |||
| const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup", | |||
| "Transpose"}; | |||
| } // namespace pynative | |||
| } // namespace mindspore | |||
| @@ -467,6 +467,11 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> | |||
| opt::ConstInputToAttrInfoRegister reg; | |||
| bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); | |||
| if (op_run_info->is_dynamic_shape && | |||
| dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) { | |||
| MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; | |||
| reg_exist = false; | |||
| } | |||
| if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { | |||
| reg_exist = false; | |||
| } | |||
| @@ -594,6 +599,7 @@ py::tuple RunOp(const py::args &args) { | |||
| } | |||
| py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { | |||
| return RunOpWithInitBackendPolicy(op_exec_info); | |||
| } | |||
| @@ -604,58 +610,27 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||
| op_exec_info->inputs_mask = op_masks; | |||
| // get output abstract info | |||
| bool is_find = false; | |||
| GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find); | |||
| MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); | |||
| // infer output value for const prim | |||
| auto prim = op_exec_info->py_primitive; | |||
| if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { | |||
| auto abs_list = prim_abs_list_[prim->id()]; | |||
| MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | |||
| if (abs_list.find(args_spec_list) != abs_list.end()) { | |||
| MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name; | |||
| op_exec_info->abstract = abs_list[args_spec_list].abs; | |||
| op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape; | |||
| prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); | |||
| is_find = true; | |||
| } | |||
| } | |||
| if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { | |||
| // use python infer method | |||
| if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | |||
| PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); | |||
| } | |||
| // get output dynamic shape info | |||
| auto abstract = op_exec_info->abstract; | |||
| MS_EXCEPTION_IF_NULL(abstract); | |||
| auto shape = abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(shape); | |||
| auto shape_info = shape->ToString(); | |||
| if (shape_info.find("-1") != string::npos) { | |||
| op_exec_info->is_dynamic_shape = true; | |||
| } | |||
| } | |||
| if (cnode != nullptr) { | |||
| cnode->set_abstract(op_exec_info->abstract); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | |||
| if (!output["value"].is_none()) { | |||
| py::tuple value_ret(1); | |||
| value_ret[0] = output["value"]; | |||
| return value_ret; | |||
| } | |||
| // infer output value for const prim | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| if (op_exec_info->abstract != nullptr) { | |||
| MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); | |||
| py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | |||
| if (!output["value"].is_none()) { | |||
| py::tuple value_ret(1); | |||
| value_ret[0] = output["value"]; | |||
| return value_ret; | |||
| } | |||
| if (op_exec_info->py_primitive->is_const_prim()) { | |||
| py::tuple value_ret(1); | |||
| value_ret[0] = ""; | |||
| return value_ret; | |||
| } | |||
| if (prim->is_const_prim()) { | |||
| py::tuple value_ret(1); | |||
| value_ret[0] = ""; | |||
| return value_ret; | |||
| } | |||
| // add output abstract info into cache | |||
| if (!is_find) { | |||
| if (!is_find && !op_exec_info->is_dynamic_shape) { | |||
| // const_value need infer every step | |||
| auto &out = prim_abs_list_[prim->id()]; | |||
| out[args_spec_list].abs = op_exec_info->abstract; | |||
| out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape; | |||
| out[args_spec_list].attrs = prim->evaluate_added_attrs(); | |||
| MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | |||
| } | |||
| @@ -666,8 +641,13 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||
| MS_LOG(DEBUG) << "Output size is 1"; | |||
| out_real = result[0]; | |||
| } | |||
| // update output abstract for cnode | |||
| if (cnode != nullptr) { | |||
| cnode->set_abstract(op_exec_info->abstract); | |||
| } | |||
| std::string obj_id = GetId(out_real); | |||
| node_abs_map_[obj_id] = op_exec_info->abstract; | |||
| // save info for building grad graph | |||
| SaveOutputNodeMap(obj_id, out_real, cnode); | |||
| SaveAllResult(op_exec_info, cnode, out_real); | |||
| // Update the abstract and device address of value node with tensor in grad graph | |||
| @@ -784,6 +764,49 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||
| return cnode; | |||
| } | |||
| void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, | |||
| const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) { | |||
| MS_EXCEPTION_IF_NULL(is_find); | |||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||
| *is_find = false; | |||
| auto op_name = op_exec_info->op_name; | |||
| auto prim = op_exec_info->py_primitive; | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { | |||
| auto abs_list = prim_abs_list_[prim->id()]; | |||
| MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list); | |||
| if (abs_list.find(args_spec_list) != abs_list.end()) { | |||
| MS_LOG(DEBUG) << "Match prim ok " << op_name; | |||
| op_exec_info->abstract = abs_list[args_spec_list].abs; | |||
| prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); | |||
| *is_find = true; | |||
| } | |||
| } | |||
| if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) { | |||
| // use python infer method | |||
| if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) { | |||
| PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); | |||
| } | |||
| } | |||
| // get output dynamic shape info | |||
| auto py_abstract = op_exec_info->abstract; | |||
| MS_EXCEPTION_IF_NULL(py_abstract); | |||
| auto py_shape = py_abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(py_shape); | |||
| auto py_shape_info = py_shape->ToString(); | |||
| if (py_shape_info.find("-1") != string::npos) { | |||
| auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(c_abstract); | |||
| auto c_shape = c_abstract->BuildShape(); | |||
| MS_EXCEPTION_IF_NULL(c_shape); | |||
| auto c_shape_info = c_shape->ToString(); | |||
| MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info; | |||
| if (c_shape_info.find("-1") != string::npos) { | |||
| op_exec_info->is_dynamic_shape = true; | |||
| } | |||
| } | |||
| } | |||
| py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, | |||
| size_t index) { | |||
| py::tuple cast_args(3); | |||
| @@ -1326,6 +1349,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati | |||
| op_exec_info->next_input_index}; | |||
| VectorRef outputs; | |||
| session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask); | |||
| if (op_exec_info->is_dynamic_shape) { | |||
| op_exec_info->abstract = op_run_info.abstract; | |||
| } | |||
| auto result = BaseRefToPyData(outputs); | |||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | |||
| *status = PYNATIVE_SUCCESS; | |||
| @@ -129,6 +129,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||
| AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); | |||
| AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | |||
| abstract::AbstractBasePtrList *args_spec_list); | |||
| void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, | |||
| bool *is_find); | |||
| void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); | |||
| // replace for grad graph | |||
| @@ -577,6 +577,7 @@ void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph *grap | |||
| } | |||
| bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| bool ret = false; | |||
| #if defined(_WIN32) || defined(_WIN64) | |||
| auto start_time = std::chrono::steady_clock::now(); | |||
| @@ -336,6 +336,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||
| } | |||
| bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| struct timeval start_time, end_time; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| bool ret = true; | |||
| @@ -360,7 +361,12 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||
| ret = RunOneStep(graph); | |||
| } else { | |||
| ret = LaunchKernel(graph); | |||
| if (graph->is_dynamic_shape()) { | |||
| // run dynamic shape graph in pynative | |||
| ret = RunOpLaunchKernelDynamic(graph); | |||
| } else { | |||
| ret = LaunchKernel(graph); | |||
| } | |||
| } | |||
| (void)gettimeofday(&end_time, nullptr); | |||
| const uint64_t kUSecondInSecond = 1000000; | |||
| @@ -674,6 +680,42 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo | |||
| return true; | |||
| } | |||
| bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| const auto &kernels = graph->execution_order(); | |||
| for (const auto &kernel : kernels) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| // akg kernel do not support dynamic shape by now. | |||
| device::DynamicKernelPtr dynamic_kernel = nullptr; | |||
| kernel::GpuKernel *gpu_kernel = nullptr; | |||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) != KernelType::AKG_KERNEL) { | |||
| gpu_kernel = dynamic_cast<kernel::GpuKernel *>(kernel_mod); | |||
| dynamic_kernel = gpu_kernel->DynamicKernel(); | |||
| } | |||
| // pre-processing for dynamic shape kernel | |||
| if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { | |||
| dynamic_kernel->InferShape(); | |||
| dynamic_kernel->UpdateArgs(); | |||
| } | |||
| // alloc kernel res | |||
| AddressPtrList kernel_inputs; | |||
| AddressPtrList kernel_workspaces; | |||
| AddressPtrList kernel_outputs; | |||
| GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||
| auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| if (!ret) { | |||
| MS_LOG(ERROR) << "Launch kernel failed."; | |||
| return false; | |||
| } | |||
| if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { | |||
| gpu_kernel->PostExecute(); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||
| const AddressPtrList &workspace, const AddressPtrList &outputs) { | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||
| @@ -73,6 +73,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||
| bool SearchMemSwapScheme(const session::KernelGraph *graph); | |||
| bool RefineMemSwapScheme(const session::KernelGraph *graph); | |||
| bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); | |||
| bool RunOpLaunchKernelDynamic(const session::KernelGraph *graph); | |||
| void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | |||
| const AddressPtrList &workspace, const AddressPtrList &outputs); | |||
| bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); | |||
| @@ -875,7 +875,7 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||
| } | |||
| bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | |||
| auto &kernels = graph.execution_order(); | |||
| const auto &kernels = graph.execution_order(); | |||
| std::vector<DynamicKernelPtr> dynamic_kernel_list; | |||
| auto iter = graph_dynamic_kernel_map_.find(graph.graph_id()); | |||
| if (iter != graph_dynamic_kernel_map_.end()) { | |||