@@ -700,20 +700,21 @@ void AscendSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &g
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
 }
 
-void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+void AscendSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
                               std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                               const std::vector<int64_t> &tensors_mask) {
   MS_EXCEPTION_IF_NULL(input_tensors);
-  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
+  MS_EXCEPTION_IF_NULL(op_run_info);
+  BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, input_tensors);
   // Run op
   auto graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(graph);
-  MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
+  MS_LOG(INFO) << "Run op " << op_run_info->op_name << " start!";
   // malloc mem
   RunOpMemoryAlloc(*input_tensors, graph.get());
   // Build dynamic kernel
-  if (op_run_info.is_dynamic_shape) {
+  if (op_run_info->is_dynamic_shape) {
     BuildDynamicKernel(graph);
   }
   // load input data to device
| @@ -722,8 +723,12 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra | |||||
| Execute(graph, false); | Execute(graph, false); | ||||
| // get output | // get output | ||||
| UpdateOutputs(graph, outputs, *input_tensors); | UpdateOutputs(graph, outputs, *input_tensors); | ||||
| // update output abstract of dynamic op to op_run_info | |||||
| if (op_run_info->is_dynamic_shape) { | |||||
| UpdateOutputAbstract(graph, op_run_info); | |||||
| } | |||||
| RunOpMemoryClear(graph.get()); | RunOpMemoryClear(graph.get()); | ||||
| MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; | |||||
| MS_LOG(INFO) << "Run op " << op_run_info->op_name << " finish!"; | |||||
| } | } | ||||
| void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| @@ -750,7 +755,7 @@ void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector | |||||
| // Build and run current single op | // Build and run current single op | ||||
| VectorRef op_outputs; | VectorRef op_outputs; | ||||
| RunOpImpl(run_info, graph_info, &input_tensor_info.input_tensors, &op_outputs, | |||||
| RunOpImpl(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs, | |||||
| input_tensor_info.input_tensors_mask); | input_tensor_info.input_tensors_mask); | ||||
| // Handle inputs and outputs of current op | // Handle inputs and outputs of current op | ||||
| @@ -59,9 +59,8 @@ class AscendSession : public SessionBasic { | |||||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) override; | const std::vector<int64_t> &tensors_mask) override; | ||||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||||
| const std::vector<int64_t> &tensors_mask) override; | |||||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||||
| void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| VectorRef *outputs) override; | VectorRef *outputs) override; | ||||
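
Note: across all three backends the change is the same: op_run_info moves from a const reference to a mutable pointer, because after executing a dynamic-shape op the session now writes the freshly inferred output abstract back into it (see UpdateOutputAbstract below). Reordering the parameters presumably keeps the read-only graph_info ahead of the mutable arguments. A minimal stand-alone sketch of the new calling convention, with stand-in types rather than the real MindSpore classes:

    #include <cassert>
    #include <string>

    struct Abstract {
      std::string shape;
    };

    struct OpRunInfo {
      std::string op_name;
      bool is_dynamic_shape = false;
      const Abstract *abstract = nullptr;  // written back by the session after launch
    };

    // Before the change RunOpImpl took `const OpRunInfo &` and could only read it;
    // taking `OpRunInfo *` lets the session report the inferred output abstract.
    void RunOpImpl(const std::string &graph_info, OpRunInfo *op_run_info) {
      assert(op_run_info != nullptr);  // stands in for MS_EXCEPTION_IF_NULL
      static const Abstract inferred{"(2, -1)"};
      if (op_run_info->is_dynamic_shape) {
        op_run_info->abstract = &inferred;  // write-back, impossible through const&
      }
    }

    int main() {
      OpRunInfo info{"Reshape", true};
      RunOpImpl("g0", &info);
      return info.abstract == nullptr ? 1 : 0;
    }
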
| @@ -170,11 +170,12 @@ void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::T | |||||
| } | } | ||||
| } | } | ||||
| void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | ||||
| const std::vector<int64_t> &tensors_mask) { | const std::vector<int64_t> &tensors_mask) { | ||||
| MS_EXCEPTION_IF_NULL(input_tensors); | MS_EXCEPTION_IF_NULL(input_tensors); | ||||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||||
| EraseValueNodeTensor(tensors_mask, input_tensors); | EraseValueNodeTensor(tensors_mask, input_tensors); | ||||
| auto kernel_graph = run_op_graphs_[graph_info]; | auto kernel_graph = run_op_graphs_[graph_info]; | ||||
| @@ -41,9 +41,8 @@ class CPUSession : public SessionBasic { | |||||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) override; | const std::vector<int64_t> &tensors_mask) override; | ||||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||||
| const std::vector<int64_t> &tensors_mask) override; | |||||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||||
| private: | private: | ||||
| void SetKernelInfo(const KernelGraph *kernel_graph); | void SetKernelInfo(const KernelGraph *kernel_graph); | ||||
| @@ -130,7 +130,7 @@ void RunGraphTask::Run() { | |||||
| void RunOpTask::Run() { | void RunOpTask::Run() { | ||||
| MS_EXCEPTION_IF_NULL(session_); | MS_EXCEPTION_IF_NULL(session_); | ||||
| session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_); | |||||
| session_->RunOpImpl(graph_info_, op_run_info_, input_tensors_, &outputs_, tensors_mask_); | |||||
| } | } | ||||
| void RunOpsInGraphTask::Run() { | void RunOpsInGraphTask::Run() { | ||||
| @@ -403,13 +403,14 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap | |||||
| run_op_graphs_[graph_info] = kernel_graph; | run_op_graphs_[graph_info] = kernel_graph; | ||||
| } | } | ||||
| void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| void GPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | ||||
| const std::vector<int64_t> &tensors_mask) { | const std::vector<int64_t> &tensors_mask) { | ||||
| MS_EXCEPTION_IF_NULL(input_tensors); | MS_EXCEPTION_IF_NULL(input_tensors); | ||||
| BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask); | |||||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||||
| BuildOpImpl(*op_run_info, graph_info, *input_tensors, tensors_mask); | |||||
| EraseValueNodeTensor(tensors_mask, input_tensors); | EraseValueNodeTensor(tensors_mask, input_tensors); | ||||
| // run op | |||||
| auto kernel_graph = run_op_graphs_[graph_info]; | auto kernel_graph = run_op_graphs_[graph_info]; | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // Remove NopOp from execution graph | // Remove NopOp from execution graph | ||||
| @@ -420,6 +421,10 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_ | |||||
| Execute(kernel_graph); | Execute(kernel_graph); | ||||
| // Fetch outputs | // Fetch outputs | ||||
| UpdateOutputs(kernel_graph, outputs, *input_tensors); | UpdateOutputs(kernel_graph, outputs, *input_tensors); | ||||
| // update output abstract of dynamic op to op_run_info | |||||
| if (op_run_info->is_dynamic_shape) { | |||||
| UpdateOutputAbstract(kernel_graph, op_run_info); | |||||
| } | |||||
| RunOpClearMemory(kernel_graph.get()); | RunOpClearMemory(kernel_graph.get()); | ||||
| } | } | ||||
| @@ -39,9 +39,8 @@ class GPUSession : public SessionBasic { | |||||
| void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) override; | const std::vector<int64_t> &tensors_mask) override; | ||||
| void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | |||||
| const std::vector<int64_t> &tensors_mask) override; | |||||
| void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, | |||||
| VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; | |||||
| private: | private: | ||||
| void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; | ||||
| @@ -1142,6 +1142,19 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| } | } | ||||
| } | } | ||||
| void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| OpRunInfo *op_run_info) const { | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| MS_EXCEPTION_IF_NULL(op_run_info); | |||||
| const auto &kernels = kernel_graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) { | |||||
| op_run_info->abstract = kernel->abstract(); | |||||
| } | |||||
| } | |||||
| } | |||||
| std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id, | std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id, | ||||
| const std::vector<tensor::TensorPtr> &inputs) { | const std::vector<tensor::TensorPtr> &inputs) { | ||||
| auto graph = GetGraph(graph_id); | auto graph = GetGraph(graph_id); | ||||
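
Note: UpdateOutputAbstract is a linear scan over the graph's execution order: it matches kernels by CNode name and copies the matching kernel's abstract into op_run_info. The loop has no break, so if several kernels share the op name the last match wins. A stand-alone sketch of the same pattern, with stand-in types rather than the real KernelGraph/AnfAlgo API:

    #include <memory>
    #include <string>
    #include <vector>

    struct Abstract { std::string repr; };

    struct Kernel {
      std::string cnode_name;              // stands in for AnfAlgo::GetCNodeName(kernel)
      std::shared_ptr<Abstract> abstract;  // stands in for kernel->abstract()
    };

    struct OpRunInfo {
      std::string op_name;
      std::shared_ptr<Abstract> abstract;
    };

    void UpdateOutputAbstract(const std::vector<Kernel> &execution_order, OpRunInfo *op_run_info) {
      if (op_run_info == nullptr) return;  // the real code throws via MS_EXCEPTION_IF_NULL
      for (const auto &kernel : execution_order) {
        if (kernel.cnode_name == op_run_info->op_name) {
          op_run_info->abstract = kernel.abstract;  // no break: the last match wins
        }
      }
    }
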
| @@ -153,7 +153,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | virtual void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors, | const std::vector<tensor::TensorPtr> &input_tensors, | ||||
| const std::vector<int64_t> &tensors_mask) {} | const std::vector<int64_t> &tensors_mask) {} | ||||
| virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, | |||||
| virtual void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, | |||||
| std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, | ||||
| const std::vector<int64_t> &tensors_mask) {} | const std::vector<int64_t> &tensors_mask) {} | ||||
| virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| @@ -167,6 +167,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | void EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors); | ||||
| void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, | void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, | ||||
| const std::vector<tensor::TensorPtr> &input_tensors) const; | const std::vector<tensor::TensorPtr> &input_tensors) const; | ||||
| void UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph, OpRunInfo *op_run_info) const; | |||||
| void Reorder(std::vector<CNodePtr> *node_list); | void Reorder(std::vector<CNodePtr> *node_list); | ||||
| void Summary(KernelGraph *graph); | void Summary(KernelGraph *graph); | ||||
| // create graph output for RunOp | // create graph output for RunOp | ||||
| @@ -70,10 +70,12 @@ const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_ca | |||||
| const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"}; | const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"}; | ||||
| const std::set<std::string> ignore_judge_dynamic_cell = { | const std::set<std::string> ignore_judge_dynamic_cell = { | ||||
| "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", | "Cell mindspore.nn.layer.basic.Dense", "Cell mindspore.nn.probability.distribution.normal.Normal", | ||||
| "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask"}; | |||||
| "Cell src.transformer.create_attn_mask.CreateAttentionMaskFromInputMask", "Cell mindspore.nn.layer.math.MatMul"}; | |||||
| const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, | const std::set<std::string> unchanged_named_primitive = {parse::NAMED_PRIMITIVE_ATTRIBUTE, | ||||
| parse::NAMED_PRIMITIVE_NAMECONSTANT, | parse::NAMED_PRIMITIVE_NAMECONSTANT, | ||||
| parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; | parse::NAMED_PRIMITIVE_NUM, parse::NAMED_PRIMITIVE_STR}; | ||||
| const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape", "EmbeddingLookup", | |||||
| "Transpose"}; | |||||
| } // namespace pynative | } // namespace pynative | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
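
Note: the new dynamic_shape_const_input_to_attr set acts as a whitelist: when an op is dynamic shape, the const-input-to-attr rewrite in the next hunk is suppressed unless the op is one of these five names. A sketch of that lookup, wrapped in a hypothetical helper that is not part of the diff:

    #include <set>
    #include <string>

    const std::set<std::string> dynamic_shape_const_input_to_attr = {"Cast", "ExpandDims", "Reshape",
                                                                     "EmbeddingLookup", "Transpose"};

    // Hypothetical helper mirroring the check added to ConstructInputTensor below.
    bool AllowConstInputToAttr(const std::string &op_name, bool is_dynamic_shape, bool reg_exist) {
      if (is_dynamic_shape &&
          dynamic_shape_const_input_to_attr.find(op_name) == dynamic_shape_const_input_to_attr.end()) {
        return false;  // dynamic-shape op outside the whitelist: keep const inputs as real inputs
      }
      return reg_exist;
    }
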
| @@ -467,6 +467,11 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int64_t> | |||||
| opt::ConstInputToAttrInfoRegister reg; | opt::ConstInputToAttrInfoRegister reg; | ||||
| bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); | bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info->op_name, ®); | ||||
| if (op_run_info->is_dynamic_shape && | |||||
| dynamic_shape_const_input_to_attr.find(op_run_info->op_name) == dynamic_shape_const_input_to_attr.end()) { | |||||
| MS_LOG(INFO) << "current node is dynamic shape: " << op_run_info->op_name; | |||||
| reg_exist = false; | |||||
| } | |||||
| if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { | if (op_run_info->op_name == prim::kPrimEmbeddingLookup->name()) { | ||||
| reg_exist = false; | reg_exist = false; | ||||
| } | } | ||||
| @@ -594,6 +599,7 @@ py::tuple RunOp(const py::args &args) { | |||||
| } | } | ||||
| py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | ||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||||
| if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { | if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { | ||||
| return RunOpWithInitBackendPolicy(op_exec_info); | return RunOpWithInitBackendPolicy(op_exec_info); | ||||
| } | } | ||||
| @@ -604,58 +610,27 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| op_exec_info->inputs_mask = op_masks; | op_exec_info->inputs_mask = op_masks; | ||||
| // get output abstract info | // get output abstract info | ||||
| bool is_find = false; | bool is_find = false; | ||||
| GetOpOutputAbstract(op_exec_info, args_spec_list, &is_find); | |||||
| MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); | |||||
| // infer output value for const prim | |||||
| auto prim = op_exec_info->py_primitive; | auto prim = op_exec_info->py_primitive; | ||||
| if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { | |||||
| auto abs_list = prim_abs_list_[prim->id()]; | |||||
| MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | |||||
| if (abs_list.find(args_spec_list) != abs_list.end()) { | |||||
| MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name; | |||||
| op_exec_info->abstract = abs_list[args_spec_list].abs; | |||||
| op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape; | |||||
| prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); | |||||
| is_find = true; | |||||
| } | |||||
| } | |||||
| if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { | |||||
| // use python infer method | |||||
| if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | |||||
| PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); | |||||
| } | |||||
| // get output dynamic shape info | |||||
| auto abstract = op_exec_info->abstract; | |||||
| MS_EXCEPTION_IF_NULL(abstract); | |||||
| auto shape = abstract->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(shape); | |||||
| auto shape_info = shape->ToString(); | |||||
| if (shape_info.find("-1") != string::npos) { | |||||
| op_exec_info->is_dynamic_shape = true; | |||||
| } | |||||
| } | |||||
| if (cnode != nullptr) { | |||||
| cnode->set_abstract(op_exec_info->abstract); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | |||||
| if (!output["value"].is_none()) { | |||||
| py::tuple value_ret(1); | |||||
| value_ret[0] = output["value"]; | |||||
| return value_ret; | |||||
| } | } | ||||
| // infer output value for const prim | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||||
| if (op_exec_info->abstract != nullptr) { | |||||
| MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); | |||||
| py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | |||||
| if (!output["value"].is_none()) { | |||||
| py::tuple value_ret(1); | |||||
| value_ret[0] = output["value"]; | |||||
| return value_ret; | |||||
| } | |||||
| if (op_exec_info->py_primitive->is_const_prim()) { | |||||
| py::tuple value_ret(1); | |||||
| value_ret[0] = ""; | |||||
| return value_ret; | |||||
| } | |||||
| if (prim->is_const_prim()) { | |||||
| py::tuple value_ret(1); | |||||
| value_ret[0] = ""; | |||||
| return value_ret; | |||||
| } | } | ||||
| // add output abstract info into cache | // add output abstract info into cache | ||||
| if (!is_find) { | |||||
| if (!is_find && !op_exec_info->is_dynamic_shape) { | |||||
| // const_value need infer every step | // const_value need infer every step | ||||
| auto &out = prim_abs_list_[prim->id()]; | auto &out = prim_abs_list_[prim->id()]; | ||||
| out[args_spec_list].abs = op_exec_info->abstract; | out[args_spec_list].abs = op_exec_info->abstract; | ||||
| out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape; | |||||
| out[args_spec_list].attrs = prim->evaluate_added_attrs(); | out[args_spec_list].attrs = prim->evaluate_added_attrs(); | ||||
| MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | ||||
| } | } | ||||
| @@ -666,8 +641,13 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| MS_LOG(DEBUG) << "Output size is 1"; | MS_LOG(DEBUG) << "Output size is 1"; | ||||
| out_real = result[0]; | out_real = result[0]; | ||||
| } | } | ||||
| // update output abstract for cnode | |||||
| if (cnode != nullptr) { | |||||
| cnode->set_abstract(op_exec_info->abstract); | |||||
| } | |||||
| std::string obj_id = GetId(out_real); | std::string obj_id = GetId(out_real); | ||||
| node_abs_map_[obj_id] = op_exec_info->abstract; | node_abs_map_[obj_id] = op_exec_info->abstract; | ||||
| // save info for building grad graph | |||||
| SaveOutputNodeMap(obj_id, out_real, cnode); | SaveOutputNodeMap(obj_id, out_real, cnode); | ||||
| SaveAllResult(op_exec_info, cnode, out_real); | SaveAllResult(op_exec_info, cnode, out_real); | ||||
| // Update the abstract and device address of value node with tensor in grad graph | // Update the abstract and device address of value node with tensor in grad graph | ||||
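
Note: besides moving cnode->set_abstract after execution, the refactor changes the caching rule: prim_abs_list_ used to store an is_dynamic_shape flag alongside the abstract, whereas now dynamic-shape ops are simply never cached, since a cached abstract pins the output shape and a dynamic op must be re-inferred on every step. A sketch of the guarded memoization, with stand-in types:

    #include <map>
    #include <string>

    struct AbsCacheEntry {
      std::string abstract;  // stands in for the cached AbstractBasePtr and attrs
    };

    std::map<std::string, AbsCacheEntry> prim_abs_cache;

    // mirrors: if (!is_find && !op_exec_info->is_dynamic_shape) { ...cache... }
    void MaybeCacheAbstract(const std::string &key, const std::string &abstract,
                            bool is_find, bool is_dynamic_shape) {
      if (is_find || is_dynamic_shape) {
        return;  // already cached, or the shape changes from step to step
      }
      prim_abs_cache[key] = AbsCacheEntry{abstract};
    }
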
| @@ -784,6 +764,49 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v | |||||
| return cnode; | return cnode; | ||||
| } | } | ||||
| void PynativeExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, | |||||
| const abstract::AbstractBasePtrList &args_spec_list, bool *is_find) { | |||||
| MS_EXCEPTION_IF_NULL(is_find); | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||||
| *is_find = false; | |||||
| auto op_name = op_exec_info->op_name; | |||||
| auto prim = op_exec_info->py_primitive; | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { | |||||
| auto abs_list = prim_abs_list_[prim->id()]; | |||||
| MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list); | |||||
| if (abs_list.find(args_spec_list) != abs_list.end()) { | |||||
| MS_LOG(DEBUG) << "Match prim ok " << op_name; | |||||
| op_exec_info->abstract = abs_list[args_spec_list].abs; | |||||
| prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); | |||||
| *is_find = true; | |||||
| } | |||||
| } | |||||
| if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) { | |||||
| // use python infer method | |||||
| if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) { | |||||
| PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list); | |||||
| } | |||||
| } | |||||
| // get output dynamic shape info | |||||
| auto py_abstract = op_exec_info->abstract; | |||||
| MS_EXCEPTION_IF_NULL(py_abstract); | |||||
| auto py_shape = py_abstract->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(py_shape); | |||||
| auto py_shape_info = py_shape->ToString(); | |||||
| if (py_shape_info.find("-1") != string::npos) { | |||||
| auto c_abstract = abstract::CppInferShape(prim, args_spec_list); | |||||
| MS_EXCEPTION_IF_NULL(c_abstract); | |||||
| auto c_shape = c_abstract->BuildShape(); | |||||
| MS_EXCEPTION_IF_NULL(c_shape); | |||||
| auto c_shape_info = c_shape->ToString(); | |||||
| MS_LOG(DEBUG) << "Final infer output shape: " << c_shape_info; | |||||
| if (c_shape_info.find("-1") != string::npos) { | |||||
| op_exec_info->is_dynamic_shape = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, | py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, | ||||
| size_t index) { | size_t index) { | ||||
| py::tuple cast_args(3); | py::tuple cast_args(3); | ||||
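
Note: GetOpOutputAbstract decides dynamic-shape status in two stages: if the Python-side inferred shape prints a "-1" placeholder, the C++ inference (CppInferShape) is run as a second opinion, and only if that result still contains "-1" is the op flagged dynamic. A sketch of the decision; here the C++-side shape string is passed in directly rather than computed lazily as in the diff:

    #include <string>

    // Returns true only if both the Python-side and the C++-side inferred shape
    // strings contain the "-1" unknown-dimension placeholder.
    bool IsDynamicShape(const std::string &py_shape_info, const std::string &cpp_shape_info) {
      if (py_shape_info.find("-1") == std::string::npos) {
        return false;  // static shape from Python infer: the C++ pass is skipped entirely
      }
      return cpp_shape_info.find("-1") != std::string::npos;
    }
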
| @@ -1326,6 +1349,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati | |||||
| op_exec_info->next_input_index}; | op_exec_info->next_input_index}; | ||||
| VectorRef outputs; | VectorRef outputs; | ||||
| session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask); | session->RunOp(&op_run_info, graph_info, &input_tensors, &outputs, tensors_mask); | ||||
| if (op_exec_info->is_dynamic_shape) { | |||||
| op_exec_info->abstract = op_run_info.abstract; | |||||
| } | |||||
| auto result = BaseRefToPyData(outputs); | auto result = BaseRefToPyData(outputs); | ||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); | ||||
| *status = PYNATIVE_SUCCESS; | *status = PYNATIVE_SUCCESS; | ||||
| @@ -129,6 +129,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); | AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id); | ||||
| AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | ||||
| abstract::AbstractBasePtrList *args_spec_list); | abstract::AbstractBasePtrList *args_spec_list); | ||||
| void GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, const abstract::AbstractBasePtrList &args_spec_list, | |||||
| bool *is_find); | |||||
| void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); | void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode); | ||||
| // replace for grad graph | // replace for grad graph | ||||
| @@ -577,6 +577,7 @@ void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph *grap | |||||
| } | } | ||||
| bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| bool ret = false; | bool ret = false; | ||||
| #if defined(_WIN32) || defined(_WIN64) | #if defined(_WIN32) || defined(_WIN64) | ||||
| auto start_time = std::chrono::steady_clock::now(); | auto start_time = std::chrono::steady_clock::now(); | ||||
| @@ -336,6 +336,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||||
| } | } | ||||
| bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| (void)gettimeofday(&start_time, nullptr); | (void)gettimeofday(&start_time, nullptr); | ||||
| bool ret = true; | bool ret = true; | ||||
| @@ -360,7 +361,12 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink) { | |||||
| ret = RunOneStep(graph); | ret = RunOneStep(graph); | ||||
| } else { | } else { | ||||
| ret = LaunchKernel(graph); | |||||
| if (graph->is_dynamic_shape()) { | |||||
| // run dynamic shape graph in pynative | |||||
| ret = RunOpLaunchKernelDynamic(graph); | |||||
| } else { | |||||
| ret = LaunchKernel(graph); | |||||
| } | |||||
| } | } | ||||
| (void)gettimeofday(&end_time, nullptr); | (void)gettimeofday(&end_time, nullptr); | ||||
| const uint64_t kUSecondInSecond = 1000000; | const uint64_t kUSecondInSecond = 1000000; | ||||
| @@ -674,6 +680,42 @@ bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph, bo | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool GPUKernelRuntime::RunOpLaunchKernelDynamic(const session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| const auto &kernels = graph->execution_order(); | |||||
| for (const auto &kernel : kernels) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||||
| // akg kernel do not support dynamic shape by now. | |||||
| device::DynamicKernelPtr dynamic_kernel = nullptr; | |||||
| kernel::GpuKernel *gpu_kernel = nullptr; | |||||
| if (session::AnfRuntimeAlgorithm::GetKernelType(kernel) != KernelType::AKG_KERNEL) { | |||||
| gpu_kernel = dynamic_cast<kernel::GpuKernel *>(kernel_mod); | |||||
| dynamic_kernel = gpu_kernel->DynamicKernel(); | |||||
| } | |||||
| // pre-processing for dynamic shape kernel | |||||
| if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { | |||||
| dynamic_kernel->InferShape(); | |||||
| dynamic_kernel->UpdateArgs(); | |||||
| } | |||||
| // alloc kernel res | |||||
| AddressPtrList kernel_inputs; | |||||
| AddressPtrList kernel_workspaces; | |||||
| AddressPtrList kernel_outputs; | |||||
| GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs); | |||||
| auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "Launch kernel failed."; | |||||
| return false; | |||||
| } | |||||
| if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) { | |||||
| gpu_kernel->PostExecute(); | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | void GPUKernelRuntime::LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | ||||
| const AddressPtrList &workspace, const AddressPtrList &outputs) { | const AddressPtrList &workspace, const AddressPtrList &outputs) { | ||||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | auto kernel_mod = AnfAlgo::GetKernelMod(kernel); | ||||
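
Note: RunOpLaunchKernelDynamic walks the execution order and applies the per-kernel dynamic-shape protocol: InferShape then UpdateArgs before launch, and PostExecute afterwards to sync the real output shape back. One caveat worth a look: the diff calls gpu_kernel->DynamicKernel() without checking the dynamic_cast result, which would dereference null if a non-AKG kernel mod is not a GpuKernel. A stand-alone sketch of the loop with that cast guarded, using stand-in interfaces rather than the real kernel classes:

    #include <memory>
    #include <vector>

    struct DynKernel {
      bool is_dynamic_shape() const { return true; }
      void InferShape() {}  // re-derive output shapes from this step's inputs
      void UpdateArgs() {}  // refresh launch arguments to match the new shapes
    };

    struct KernelMod {
      virtual ~KernelMod() = default;
      virtual bool Launch() = 0;
    };

    struct GpuKernel : public KernelMod {
      bool Launch() override { return true; }
      std::shared_ptr<DynKernel> DynamicKernel() { return dyn_; }
      void PostExecute() {}  // e.g. read the actual output shape back from device
      std::shared_ptr<DynKernel> dyn_ = std::make_shared<DynKernel>();
    };

    bool RunOpLaunchKernelDynamic(const std::vector<std::shared_ptr<KernelMod>> &kernels) {
      for (const auto &kernel_mod : kernels) {
        auto *gpu_kernel = dynamic_cast<GpuKernel *>(kernel_mod.get());
        auto dynamic_kernel = gpu_kernel ? gpu_kernel->DynamicKernel() : nullptr;  // guarded cast
        if (dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
          dynamic_kernel->InferShape();
          dynamic_kernel->UpdateArgs();
        }
        if (!kernel_mod->Launch()) {
          return false;  // mirrors MS_LOG(ERROR) << "Launch kernel failed."
        }
        if (gpu_kernel && dynamic_kernel && dynamic_kernel->is_dynamic_shape()) {
          gpu_kernel->PostExecute();
        }
      }
      return true;
    }
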
| @@ -73,6 +73,7 @@ class GPUKernelRuntime : public KernelRuntime { | |||||
| bool SearchMemSwapScheme(const session::KernelGraph *graph); | bool SearchMemSwapScheme(const session::KernelGraph *graph); | ||||
| bool RefineMemSwapScheme(const session::KernelGraph *graph); | bool RefineMemSwapScheme(const session::KernelGraph *graph); | ||||
| bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); | bool LaunchKernelDynamic(const session::KernelGraph *graph, bool mock = false, bool profiling = false); | ||||
| bool RunOpLaunchKernelDynamic(const session::KernelGraph *graph); | |||||
| void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | void LaunchKernelWithTimeProfiling(const AnfNodePtr &kernel, const AddressPtrList &inputs, | ||||
| const AddressPtrList &workspace, const AddressPtrList &outputs); | const AddressPtrList &workspace, const AddressPtrList &outputs); | ||||
| bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); | bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size, bool mock); | ||||
| @@ -875,7 +875,7 @@ void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList | |||||
| } | } | ||||
| bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) { | ||||
| auto &kernels = graph.execution_order(); | |||||
| const auto &kernels = graph.execution_order(); | |||||
| std::vector<DynamicKernelPtr> dynamic_kernel_list; | std::vector<DynamicKernelPtr> dynamic_kernel_list; | ||||
| auto iter = graph_dynamic_kernel_map_.find(graph.graph_id()); | auto iter = graph_dynamic_kernel_map_.find(graph.graph_id()); | ||||
| if (iter != graph_dynamic_kernel_map_.end()) { | if (iter != graph_dynamic_kernel_map_.end()) { | ||||