@@ -23,14 +23,16 @@ void GpuDynamicKernel::UpdateArgs() {
     return;
   }
-  MS_LOG(INFO) << "Update Args: " << cnode_ptr_->fullname_with_scope();
-  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Update Args: " << cnode->fullname_with_scope();
+  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   auto gpu_kernel_mod = dynamic_cast<GpuKernel *>(kernel_mod);
   MS_EXCEPTION_IF_NULL(gpu_kernel_mod);
   gpu_kernel_mod->DestroyResource();
   gpu_kernel_mod->ResetResource();
-  gpu_kernel_mod->Init(cnode_ptr_);
+  gpu_kernel_mod->Init(cnode);
 }
 }  // namespace kernel
 }  // namespace mindspore
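This hunk shows the pattern the whole patch repeats: `cnode_ptr_` becomes a weak reference, so every entry point promotes it with `lock()` and fails fast if the node has already been released. A minimal self-contained sketch of the idiom, with stand-in types rather than the real MindSpore classes:

```cpp
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

struct Node {
  std::string name() const { return "Default/Conv2D-op1"; }
};

class Kernel {
 public:
  explicit Kernel(const std::shared_ptr<Node> &node) : node_(node) {}

  void UpdateArgs() {
    // lock() returns an empty shared_ptr once the graph node is destroyed.
    auto node = node_.lock();
    if (node == nullptr) {  // stands in for MS_EXCEPTION_IF_NULL(cnode)
      throw std::runtime_error("node expired");
    }
    std::cout << "Update Args: " << node->name() << '\n';
    // `node` keeps the Node alive until the end of this scope.
  }

 private:
  const std::weak_ptr<Node> node_;  // non-owning, like cnode_ptr_ after this patch
};

int main() {
  auto node = std::make_shared<Node>();
  Kernel kernel(node);
  kernel.UpdateArgs();  // succeeds while the node is alive
  node.reset();         // graph releases the node
  try {
    kernel.UpdateArgs();  // now throws instead of touching a dangling pointer
  } catch (const std::exception &e) {
    std::cout << e.what() << '\n';
  }
  return 0;
}
```

The payoff of the non-owning handle: destroying the graph actually frees its nodes, and any later use by a cached kernel surfaces as an exception rather than a dereference of dangling memory.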
@@ -21,12 +21,14 @@ namespace mindspore {
 namespace kernel {
 void DynamicShapeKernel::Execute() {
   MS_LOG(INFO) << "Execute DynamicShapeKernel Start";
-  auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_num = AnfAlgo::GetInputTensorNum(cnode);
   if (input_num != 1) {
     MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num;
   }
-  auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
+  auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
   std::vector<int64_t> output_shape = {SizeToLong(prev_output_shape.size())};
   auto output_type = TypeId::kNumberTypeInt64;
@@ -38,7 +40,7 @@ void DynamicShapeKernel::Execute() {
     *(data_ptr + i) = prev_output_shape[i];
   }
-  auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, 0);
+  auto output_addr = AnfAlgo::GetOutputAddr(cnode, 0);
   MS_EXCEPTION_IF_NULL(output_addr);
   output_addr->SyncHostToDevice(output_shape, LongToSize(output_tensor_for_sync->data().nbytes()),
                                 output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c());
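For context: this kernel materializes the runtime shape of its single input as a 1-D int64 tensor and syncs it to the device, so `output_shape` is simply `{rank}`. A standalone sketch of the host-side step the loop above performs (names are illustrative, not the MindSpore API):

```cpp
#include <cstdint>
#include <vector>

// Build the payload a Shape-style kernel emits: one int64 entry per input
// dimension of the producing node.
std::vector<int64_t> BuildShapePayload(const std::vector<size_t> &input_shape) {
  std::vector<int64_t> payload;
  payload.reserve(input_shape.size());
  for (size_t dim : input_shape) {
    payload.push_back(static_cast<int64_t>(dim));  // mirrors *(data_ptr + i) = prev_output_shape[i]
  }
  return payload;
}
```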
@@ -43,29 +43,33 @@ void AiCoreDynamicKernel::Execute() {
   if (stream_ == nullptr) {
     MS_LOG(EXCEPTION) << "stream_ptr should not be nullptr.";
   }
-  MS_LOG(INFO) << "Start Execute node:" << cnode_ptr_->fullname_with_scope();
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Start Execute node:" << cnode->fullname_with_scope();
   rtL2Ctrl_t *l2ctrl = nullptr;
   auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size());
   if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
     MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
   }
-  MS_LOG(INFO) << "End Execute node:" << cnode_ptr_->fullname_with_scope();
+  MS_LOG(INFO) << "End Execute node:" << cnode->fullname_with_scope();
 }

 void AiCoreDynamicKernel::ParseCompileJson() {
-  if (!AnfAlgo::IsDynamicShape(cnode_ptr_)) {
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  if (!AnfAlgo::IsDynamicShape(cnode)) {
     return;
   }
-  if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, cnode_ptr_)) {
+  if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, cnode)) {
     MS_LOG(EXCEPTION) << "Get compile_info failed";
   }
-  auto compile_info_attr = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrCompileInfo);
+  auto compile_info_attr = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrCompileInfo);
   MS_LOG(INFO) << "Get compile_info:" << compile_info_attr;
   op_compile_info_.str = compile_info_attr;
   op_compile_info_.key = "";
-  if (AnfAlgo::HasNodeAttr(kAttrFusionType, cnode_ptr_)) {
-    auto fusion_type = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrFusionType);
+  if (AnfAlgo::HasNodeAttr(kAttrFusionType, cnode)) {
+    auto fusion_type = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFusionType);
     MS_LOG(INFO) << "Get fusion_type:" << fusion_type;
     (*compile_info_json_)["_pattern"] = fusion_type;
     op_compile_info_.key = std::hash<std::string>{}(fusion_type);
@@ -85,14 +89,15 @@ void AiCoreDynamicKernel::UpdateArgs() {
   }

   AllocateWorkspace();

-  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
   MS_EXCEPTION_IF_NULL(kernel_mod);

   AddressPtrList kernel_inputs;
   AddressPtrList kernel_workspaces;
   AddressPtrList kernel_outputs;
-  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);

   runtime_args_.clear();
@@ -112,11 +117,12 @@ void AiCoreDynamicKernel::UpdateArgs() {
 }

 void AiCoreDynamicKernel::ComputeTiling() {
-  MS_EXCEPTION_IF_NULL(cnode_ptr_);
-  MS_LOG(INFO) << "Start compute tiling of:" << cnode_ptr_->fullname_with_scope();
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Start compute tiling of:" << cnode->fullname_with_scope();
   optiling::OpRunInfo op_run_info;
-  OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode_ptr_), op_compile_info_, depend_tensor_map_,
+  OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode), op_compile_info_, depend_tensor_map_,
                                                     NOT_NULL(&op_run_info));
   block_dim_ = op_run_info.block_dim;
   workspaces_size_ = op_run_info.workspaces;
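`NOT_NULL(...)` can wrap the freshly locked `cnode` directly because the local `shared_ptr` keeps the node alive for the whole call. A minimal sketch of a not-null wrapper in that spirit (illustrative only; MindSpore's real `NotNull` helper may differ):

```cpp
#include <cassert>
#include <utility>

// Check once at the call boundary so callees can assume a valid pointer.
template <typename T>
class NotNull {
 public:
  explicit NotNull(T ptr) : ptr_(std::move(ptr)) { assert(ptr_ != nullptr); }
  decltype(auto) operator*() const { return *ptr_; }
  T get() const { return ptr_; }

 private:
  T ptr_;
};

// Wraps any pointer-like value, e.g. NOT_NULL(cnode) or NOT_NULL(&op_run_info).
#define NOT_NULL(x) NotNull<decltype(x)>(x)
```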
@@ -62,23 +62,24 @@ void AiCpuDynamicKernel::Execute() {
 void AiCpuDynamicKernel::Initialize() {
   // is dynamic
-  MS_LOG(INFO) << "Initialize node:" << cnode_ptr_->fullname_with_scope();
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Initialize node:" << cnode->fullname_with_scope();
   DynamicKernel::Initialize();
-  input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
-  output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
+  input_num_ = AnfAlgo::GetInputTensorNum(cnode);
+  output_num_ = AnfAlgo::GetOutputTensorNum(cnode);

   UnknowShapeOpType shape_type = UnknowShapeOpType::DEPEND_IN_SHAPE;
-  auto op_name = AnfAlgo::GetCNodeName(cnode_ptr_);
+  auto op_name = AnfAlgo::GetCNodeName(cnode);
   if (kComputeDepend.find(op_name) != kComputeDepend.end()) {
     shape_type = UnknowShapeOpType::DEPEND_COMPUTE;
   }
   unknow_type_ = shape_type;

   // Parse aicpu ext info
   if (is_dynamic_shape_) {
-    MS_EXCEPTION_IF_NULL(cnode_ptr_);
+    MS_EXCEPTION_IF_NULL(cnode);
     ext_info_handler_ =
-      std::make_shared<AicpuExtInfoHandler>(cnode_ptr_->fullname_with_scope(), input_num_, output_num_, shape_type);
+      std::make_shared<AicpuExtInfoHandler>(cnode->fullname_with_scope(), input_num_, output_num_, shape_type);
     ext_info_handler_->Parse(ext_info_data_);
   }
@@ -108,14 +109,14 @@ void AiCpuDynamicKernel::Initialize() {
 bool AiCpuDynamicKernel::UpdateInputOutputAddr() {
   std::vector<uint64_t> io_addrs;
   io_addrs.reserve(input_num_ + output_num_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
   for (size_t i = 0; i < input_num_; ++i) {
-    auto input_addr = AnfAlgo::GetPrevNodeOutputAddr(cnode_ptr_, i);
+    auto input_addr = AnfAlgo::GetPrevNodeOutputAddr(cnode, i);
     io_addrs.emplace_back(reinterpret_cast<uintptr_t>(input_addr->GetMutablePtr()));
   }
   for (size_t i = 0; i < output_num_; ++i) {
-    auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, i);
+    auto output_addr = AnfAlgo::GetOutputAddr(cnode, i);
     io_addrs.emplace_back(reinterpret_cast<uintptr_t>(output_addr->GetMutablePtr()));
   }
@@ -135,19 +136,21 @@ bool AiCpuDynamicKernel::UpdateInputOutputAddr() {
 }

 bool AiCpuDynamicKernel::UpdateExtInfo() {
-  MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " start";
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "UpdateExtInfo of " << cnode->fullname_with_scope() << " start";
   if (input_num_ == 0 && output_num_ == 0) {
-    MS_LOG(INFO) << "Node:" << cnode_ptr_->fullname_with_scope() << " no need to update output shape";
+    MS_LOG(INFO) << "Node:" << cnode->fullname_with_scope() << " no need to update output shape";
     return true;
   }

   for (size_t i = 0; i < input_num_; ++i) {
-    ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_));
+    ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode));
   }

-  if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ != DEPEND_COMPUTE) {
+  if (AnfAlgo::IsDynamicShape(cnode) && unknow_type_ != DEPEND_COMPUTE) {
     for (size_t i = 0; i < output_num_; ++i) {
-      ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_));
+      ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode));
     }
   }
@@ -158,7 +161,7 @@ bool AiCpuDynamicKernel::UpdateExtInfo() {
     return false;
   }

-  MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " end";
+  MS_LOG(INFO) << "UpdateExtInfo of " << cnode->fullname_with_scope() << " end";
   return true;
 }
@@ -196,12 +199,14 @@ bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
     shapes.emplace_back(size_t_shape);
   }

-  AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode_ptr_.get());
+  AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode_ptr_.lock().get());
   return true;
 }

 void AiCpuDynamicKernel::PostExecute() {
-  MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute";
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Aicpu " << cnode->fullname_with_scope() << " PostExecute";
   if (unknow_type_ != DEPEND_COMPUTE) {
     return;
   }
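A note on `cnode_ptr_.lock().get()` above: the temporary `shared_ptr` returned by `lock()` lives until the end of the full expression, so the raw pointer stays valid for the call; but if the node has expired, `get()` silently yields `nullptr` rather than raising the way the `MS_EXCEPTION_IF_NULL(cnode)` guards elsewhere in this patch do. A small sketch of both behaviours:

```cpp
#include <cassert>
#include <memory>

struct Node { int id = 0; };

void Consume(Node *raw) { (void)raw; /* callee must tolerate or check nullptr */ }

int main() {
  auto owner = std::make_shared<Node>();
  std::weak_ptr<Node> weak = owner;

  // The temporary shared_ptr from lock() is destroyed at the semicolon, so
  // `raw` remains valid for the duration of the call but no longer.
  Consume(weak.lock().get());

  owner.reset();                         // node released
  assert(weak.lock().get() == nullptr);  // expired: get() is silently null
  return 0;
}
```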
@@ -209,7 +214,7 @@ void AiCpuDynamicKernel::PostExecute() {
     MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
     return;
   }
-  if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ == DEPEND_COMPUTE) {
+  if (AnfAlgo::IsDynamicShape(cnode) && unknow_type_ == DEPEND_COMPUTE) {
     MS_LOG(INFO) << "Update aicpu kernel output shape from ext_info";
     UpdateOutputShapeFromExtInfo();
   }
@@ -38,13 +38,15 @@ void HcclDynamicKernel::UpdateArgs() {
     return;
   }
   MS_LOG(INFO) << "Start to UpdateArgs";
-  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   // Update input, output, count
   AddressPtrList kernel_inputs;
   AddressPtrList kernel_workspaces;
   AddressPtrList kernel_outputs;
-  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
   if (kernel_inputs.empty() || kernel_outputs.empty()) {
     MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
   }
@@ -58,30 +60,31 @@ void HcclDynamicKernel::UpdateArgs() {
   output_ptr_ = output0->addr;

   std::vector<std::vector<size_t>> hccl_kernel_input_shape_list;
-  if (!HcomUtil::GetKernelInputShape(cnode_ptr_, &hccl_kernel_input_shape_list)) {
+  if (!HcomUtil::GetKernelInputShape(cnode, &hccl_kernel_input_shape_list)) {
     MS_LOG(EXCEPTION) << "GetKernelInputShape fail!";
   }

   std::vector<HcclDataType> hccl_data_type_list;
-  if (!HcomUtil::GetHcomDataType(cnode_ptr_, &hccl_data_type_list)) {
+  if (!HcomUtil::GetHcomDataType(cnode, &hccl_data_type_list)) {
     MS_LOG(EXCEPTION) << "GetHcomDataType fail!";
   }

   // Update Hccl count
-  if (!HcomUtil::GetHcomCount(cnode_ptr_, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
+  if (!HcomUtil::GetHcomCount(cnode, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
     MS_LOG(EXCEPTION) << "GetHcomCount fail!";
   }
   MS_LOG(INFO) << "Update Hccl count:" << count_;
 }

 void HcclDynamicKernel::StaticShapeExecute() {
-  MS_EXCEPTION_IF_NULL(cnode_ptr_);
-  auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
   MS_EXCEPTION_IF_NULL(kernel_mod);
   AddressPtrList kernel_inputs;
   AddressPtrList kernel_workspaces;
   AddressPtrList kernel_outputs;
-  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
   kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
 }
@@ -30,16 +30,18 @@ namespace mindspore {
 namespace device {
 void DynamicKernel::Initialize() {
   MS_LOG(INFO) << "Init Start";
-  is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode);
   if (!is_dynamic_shape_) {
-    MS_LOG(DEBUG) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope();
+    MS_LOG(DEBUG) << "cnode is not dynamic shape:" << cnode->fullname_with_scope();
     return;
   }
-  is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape);
-  is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape);
+  is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode, kAttrInputIsDynamicShape);
+  is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode, kAttrOutputIsDynamicShape);

-  auto ret = abstract::GetDependsFormMap(cnode_ptr_);
+  auto ret = abstract::GetDependsFormMap(cnode);
   if (ret.empty()) {
     MS_LOG(DEBUG) << "No dynamic_shape_depends found";
     return;
@@ -50,13 +52,15 @@ void DynamicKernel::Initialize() {
   MS_LOG(INFO) << "Init End";
 }

-int DynamicKernel::GetKernelType() { return AnfAlgo::GetKernelType(cnode_ptr_); }
+int DynamicKernel::GetKernelType() { return AnfAlgo::GetKernelType(cnode_ptr_.lock()); }

 void DynamicKernel::RebuildDependTensor() {
   depend_tensor_map_.clear();
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
   for (auto depend : depend_list_) {
-    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend);
-    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend);
+    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, depend);
+    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode, depend);
     std::vector<int64_t> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
     auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
     auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
@@ -72,11 +76,12 @@ void DynamicKernel::InferShape() {
   if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
     return;
   }
-  MS_EXCEPTION_IF_NULL(cnode_ptr_);
-  MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope();
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "InferShape start, node:" << cnode->fullname_with_scope();
   InferShapeRecursive();
-  auto inputs = cnode_ptr_->inputs();
+  auto inputs = cnode->inputs();
   if (inputs.empty()) {
     MS_LOG(EXCEPTION) << "Invalid inputs";
   }
@@ -86,9 +91,9 @@ void DynamicKernel::InferShape() {
   // rebuild depend tensor map for gpu dynamic memory allocation.
   RebuildDependTensor();

-  auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  auto input_size = AnfAlgo::GetInputTensorNum(cnode);
   for (size_t i = 0; i < input_size; ++i) {
-    auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
+    auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
     auto real_input = input_with_index.first;
     MS_EXCEPTION_IF_NULL(real_input);
@@ -101,12 +106,12 @@ void DynamicKernel::InferShape() {
       real_input->abstract()->set_value(tensor_ptr);
     }

-    auto cnode_input = cnode_ptr_->input(i + 1);
+    auto cnode_input = cnode->input(i + 1);
     MS_EXCEPTION_IF_NULL(cnode_input);
     if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimTupleGetItem)) {
       auto base_shape = real_input->Shape();
       if (!base_shape->isa<abstract::TupleShape>()) {
-        MS_LOG(EXCEPTION) << "Node:" << cnode_ptr_->fullname_with_scope()
+        MS_LOG(EXCEPTION) << "Node:" << cnode->fullname_with_scope()
                           << " input is a tuple_get_item but real input node shape is not a TupleShape";
       }
       auto tuple_ptr = base_shape->cast<abstract::TupleShapePtr>();
@@ -124,13 +129,15 @@ void DynamicKernel::InferShape() {
   }

   auto eval_result = opt::CppInferShape(primitive, args_spec_list);
-  cnode_ptr_->set_abstract(eval_result);
+  cnode->set_abstract(eval_result);
 }

 void DynamicKernel::InferShapeRecursive() {
-  auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  auto cnode = cnode_ptr_.lock();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto input_size = AnfAlgo::GetInputTensorNum(cnode);
   for (size_t i = 0; i < input_size; i++) {
-    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
+    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode, i);
     auto input_node = input_node_with_index.first;
     MS_EXCEPTION_IF_NULL(input_node);
     InferShapeForNopNode(&input_node);
@@ -46,7 +46,7 @@ class DynamicKernel {
   bool is_output_dynamic_shape() const { return is_output_dynamic_shape_; }
   bool have_depends() const { return !depend_list_.empty(); }
   virtual void Initialize();
-  std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }
+  std::string GetKernelName() { return cnode_ptr_.lock()->fullname_with_scope(); }
   int GetKernelType();

 protected:
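Worth flagging in review: `GetKernelName()` here, and `GetKernelType()` in dynamic_kernel.cc, use the result of `lock()` without a null check; `lock()->` on an expired weak pointer dereferences an empty `shared_ptr`, which is undefined behaviour. A guarded variant in the spirit of the rest of the patch (stand-in type, illustrative only):

```cpp
#include <memory>
#include <stdexcept>
#include <string>

struct Node {
  std::string fullname_with_scope() const { return "Default/Conv2D-op1"; }
};

// Promote, check, then dereference, mirroring the MS_EXCEPTION_IF_NULL
// pattern used in the .cc changes, instead of the unchecked lock()->.
std::string GetKernelName(const std::weak_ptr<Node> &cnode_ref) {
  auto cnode = cnode_ref.lock();
  if (cnode == nullptr) {
    throw std::runtime_error("cnode has expired");
  }
  return cnode->fullname_with_scope();
}
```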
@@ -55,7 +55,7 @@ class DynamicKernel {
   void InferShapeForNopNode(AnfNodePtr *input_node);

   void *stream_;
-  const CNodePtr cnode_ptr_;
+  const CNodeWeakPtr cnode_ptr_;
   bool is_dynamic_shape_;
   bool is_input_dynamic_shape_;
   bool is_output_dynamic_shape_;
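This member change is the root of the whole diff: the kernel now holds a non-owning handle, so cached `DynamicKernel` objects no longer keep graph nodes alive (or form reference cycles) after the graph is released. A sketch of the two aliases as commonly defined (illustrative; the real definitions live in MindSpore's ir headers):

```cpp
#include <memory>

class CNode;  // graph node, owned by the graph

// Presumed shape of the aliases this header relies on:
using CNodePtr = std::shared_ptr<CNode>;    // owning: extends the node's lifetime
using CNodeWeakPtr = std::weak_ptr<CNode>;  // non-owning: must be lock()ed before use

// With CNodeWeakPtr, destroying the graph actually frees its CNodes even while
// kernels are still cached; every use re-validates the node via lock().
```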