From: @joylvliang  (tags/v1.1.0)
@@ -565,6 +565,13 @@ void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph
 void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
   MS_LOG(INFO) << "Start!";
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  const auto &kernels = kernel_graph->execution_order();
+  auto iter = std::find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
+    return AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL && AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape);
+  });
+  if (iter == kernels.end()) {
+    return;
+  }
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
   if (!runtime_instance->GenDynamicKernel(kernel_graph.get())) {
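For reference, a minimal standalone sketch (illustrative types only, not MindSpore code) of the early-return pattern added above: scan the execution order with std::find_if and bail out when no kernel matches the predicate.

#include <algorithm>
#include <iostream>
#include <vector>

struct Kernel {
  bool is_aicpu = false;
  bool output_is_dynamic = false;
};

void BuildDynamicKernels(const std::vector<Kernel> &kernels) {
  // Mirror of the check above: look for any AICPU kernel whose output shape is dynamic.
  auto iter = std::find_if(kernels.begin(), kernels.end(), [](const Kernel &k) {
    return k.is_aicpu && k.output_is_dynamic;
  });
  if (iter == kernels.end()) {
    return;  // nothing to build: no AICPU kernel with a dynamic-shape output
  }
  std::cout << "dynamic AICPU kernel found, would generate dynamic kernels here\n";
}

int main() {
  BuildDynamicKernels({{true, false}, {false, true}});  // no dynamic AICPU kernel
  BuildDynamicKernels({{true, true}});                  // triggers the build path
  return 0;
}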
@@ -624,6 +624,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
   if (abs_list.find(args_spec_list) != abs_list.end()) {
     MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name;
     op_exec_info->abstract = abs_list[args_spec_list].abs;
+    op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape;
     prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
     is_find = true;
   }
@@ -634,19 +635,20 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
     if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
       PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
     }
-    // get output dynamic shape info
-    auto abstract = op_exec_info->abstract;
-    MS_EXCEPTION_IF_NULL(abstract);
-    auto shape = abstract->BuildShape();
-    MS_EXCEPTION_IF_NULL(shape);
-    auto shape_info = shape->ToString();
-    if (shape_info.find("-1") != string::npos) {
-      op_exec_info->is_dynamic_shape = true;
-    }
   }
   if (cnode != nullptr) {
     cnode->set_abstract(op_exec_info->abstract);
   }
+  // get output dynamic shape info
+  MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
+  auto abstract_info = op_exec_info->abstract->ToString();
+  if (abstract_info.find("-1") != string::npos) {
+    op_exec_info->is_dynamic_shape = true;
+  }
   op_exec_info->inputs_mask = op_masks;
   MS_EXCEPTION_IF_NULL(op_exec_info);
   if (op_exec_info->abstract != nullptr) {
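The dynamic-shape check above relies on a "-1" dimension showing up in the abstract's string form. A minimal standalone sketch of that detection (function name and inputs are hypothetical):

#include <iostream>
#include <string>

// Returns true if the textual shape contains a dynamic (-1) dimension.
bool HasDynamicDim(const std::string &shape_info) {
  return shape_info.find("-1") != std::string::npos;
}

int main() {
  std::cout << std::boolalpha;
  std::cout << HasDynamicDim("(32, -1, 128)") << std::endl;  // true
  std::cout << HasDynamicDim("(32, 64, 128)") << std::endl;  // false
  return 0;
}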
@@ -668,6 +670,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
     // const_value need infer every step
     auto &out = prim_abs_list_[prim->id()];
     out[args_spec_list].abs = op_exec_info->abstract;
+    out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape;
     out[args_spec_list].attrs = prim->evaluate_added_attrs();
     MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
   }
@@ -46,6 +46,7 @@ using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
 struct PrimAbsInfo {
   abstract::AbstractBasePtr abs;
+  bool is_dynamic_shape = false;
   std::unordered_map<std::string, ValuePtr> attrs;
 };
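A minimal standalone sketch of the caching idea behind PrimAbsInfo (illustrative types; the real cache is keyed by an abstract argument-spec list, not a string): a cache hit now returns the dynamic-shape flag together with the cached abstract.

#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for the cached entry: the abstract plus the new dynamic-shape flag.
struct CachedAbs {
  std::string abs;  // placeholder for abstract::AbstractBasePtr
  bool is_dynamic_shape = false;
};

int main() {
  std::unordered_map<std::string, CachedAbs> prim_abs_cache;
  prim_abs_cache["Reshape:(f32[32,-1])"] = CachedAbs{"Tensor(f32, (32, -1))", true};

  auto it = prim_abs_cache.find("Reshape:(f32[32,-1])");
  if (it != prim_abs_cache.end()) {
    // Cache hit: reuse both the abstract and the dynamic-shape flag.
    std::cout << it->second.abs << " dynamic=" << std::boolalpha << it->second.is_dynamic_shape << std::endl;
  }
  return 0;
}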
@@ -830,8 +830,9 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
                   << " should be equal to the size of kernels " << kernels.size();
   }
   for (size_t i = 0; i < kernels.size(); ++i) {
+    auto &kernel = kernels[i];
     if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
-        dynamic_kernel_list[i]->is_dynamic_shape()) {
+        dynamic_kernel_list[i]->is_dynamic_shape() && AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL) {
       dynamic_kernel_list[i]->InferShape();
       dynamic_kernel_list[i]->UpdateArgs();
       dynamic_kernel_list[i]->Execute();
@@ -841,12 +842,12 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
       }
       dynamic_kernel_list[i]->PostExecute();
     } else {
-      auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i]);
+      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
       MS_EXCEPTION_IF_NULL(kernel_mod);
       AddressPtrList kernel_inputs;
       AddressPtrList kernel_workspaces;
       AddressPtrList kernel_outputs;
-      GenLaunchArgs(*kernel_mod, kernels[i], &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+      GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
       auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
       if (!ret) {
         MS_LOG(ERROR) << "Launch kernel failed.";
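A standalone sketch (illustrative types, not the real runtime API) of the dispatch these two hunks tighten: only kernels flagged dynamic and of AICPU type take the InferShape/UpdateArgs/Execute path; everything else goes through the normal kernel-mod launch.

#include <iostream>
#include <string>
#include <vector>

// Illustrative per-kernel flags; the real runtime consults DynamicKernel objects and AnfAlgo.
struct KernelInfo {
  std::string name;
  bool has_dynamic_kernel = false;
  bool is_dynamic_shape = false;
  bool is_aicpu = false;
};

bool LaunchAll(const std::vector<KernelInfo> &kernels) {
  for (const auto &kernel : kernels) {
    if (kernel.has_dynamic_kernel && kernel.is_dynamic_shape && kernel.is_aicpu) {
      // Dynamic path: re-infer shapes and refresh arguments before executing.
      std::cout << kernel.name << ": InferShape -> UpdateArgs -> Execute -> PostExecute\n";
    } else {
      // Static path: launch the pre-generated kernel mod directly.
      std::cout << kernel.name << ": kernel_mod->Launch(...)\n";
    }
  }
  return true;
}

int main() {
  LaunchAll({{"Unique", true, true, true}, {"MatMul", false, false, false}});
  return 0;
}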