diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 957f1f7ca6..1c704a60d9 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -375,7 +375,7 @@ std::vector AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodeP MS_LOG(EXCEPTION) << "Not real kernel:" << "#node [" << node->DebugString() << "]"; } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -389,7 +389,7 @@ std::vector AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePt MS_LOG(EXCEPTION) << "Not real kernel:" << "#node [" << node->DebugString() << "]"; } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -403,7 +403,7 @@ std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) { MS_LOG(EXCEPTION) << "Not real kernel:" << "#node [" << node->DebugString() << "]"; } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -421,7 +421,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t if (!AnfAlgo::IsRealKernel(node)) { return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -443,7 +443,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i if (!IsRealKernel(node)) { return GetPrevNodeOutputFormat(node, input_idx); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -549,7 +549,7 @@ std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod if (!IsRealKernel(node)) { return GetPrevNodeOutputReshapeType(node, input_idx); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -568,7 +568,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &no if (!IsRealKernel(node)) { return GetPrevNodeOutputReshapeType(node, output_idx); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -624,7 +624,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size if (!IsRealKernel(node)) { return GetPrevNodeOutputDeviceDataType(node, output_idx); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -645,7 +645,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ if (!IsRealKernel(node)) { return GetPrevNodeOutputDeviceDataType(node, 0); } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -675,7 +675,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; } } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetOutputAddr(output_idx); if (addr == nullptr) { @@ -697,7 +697,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; } } - auto kernel_info = dynamic_cast(node->kernel_info()); + // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetMutableOutputAddr(output_idx); if (addr == nullptr) { @@ -710,11 +711,8 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod // get output device addr of anf_node bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); - if (output_idx > GetOutputTensorNum(node)) { - MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " - << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; - } - auto kernel_info = dynamic_cast(node->kernel_info()); + // Critical path performance optimization: `KernelInfo` is unique subclass of `KernelInfoDevice` + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->OutputAddrExist(output_idx); } @@ -734,7 +732,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode // set output device addr of anf_node void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); if (!kernel_info->SetOutputAddr(addr, output_idx)) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; @@ -744,7 +742,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out // set workspace device addr of anf_node void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; @@ -754,7 +752,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t // get workspace device addr of anf_node DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetWorkspaceAddr(output_idx); if (addr == nullptr) { @@ -767,7 +765,7 @@ DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, siz // get workspace device mutable addr of anf_node DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableWorkspaceAddr(const AnfNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetMutableWorkspaceAddr(index); if (addr == nullptr) { @@ -810,7 +808,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); // select_kernel_build_info() has checked whether return pointer is null auto build_info = kernel_info->select_kernel_build_info(); @@ -821,7 +819,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { // get KernelBuildType of node, such as ATT,RT,FWK and so on KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); // select_kernel_build_info() has checked whether return pointer is null auto build_info = kernel_info->select_kernel_build_info(); @@ -831,7 +829,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -840,7 +838,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -850,7 +848,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { // set select kernel_build_info void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->set_select_kernel_build_info(select_kernel_build_info); } @@ -858,7 +856,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel // get select kernel_build_info KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->GetMutableSelectKernelBuildInfo(); } @@ -866,7 +864,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt // get kernelMode KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->MutableKernelMod(); } @@ -874,7 +872,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { // set kernel mod void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_kernel_mod(kernel_mod); } @@ -940,42 +938,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_stream_id(stream_id); } uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->stream_id(); } void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_stream_distinction_label(stream_label); } uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->stream_distinction_label(); } void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_graph_id(graph_id); } uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->graph_id(); } @@ -1003,7 +1001,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { if (node->isa()) { return false; } - auto kernel_info = dynamic_cast(node->kernel_info()); + auto kernel_info = static_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->is_feature_map(); } diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 6c3fe80155..9cfffd2a3c 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -281,6 +281,11 @@ void GPUKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std::v } void GPUKernelRuntime::AllocInplaceNodeMemory(const session::KernelGraph *graph) { + if (is_alloc_inplace_res_[graph->graph_id()]) { + return; + } + is_alloc_inplace_res_[graph->graph_id()] = true; + std::map> inplace_groups; auto kernel_cnodes = graph->execution_order(); for (auto &kernel : kernel_cnodes) { @@ -921,6 +926,11 @@ bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::K } void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) { + if (is_alloc_communication_res_[graph->graph_id()]) { + return; + } + is_alloc_communication_res_[graph->graph_id()] = true; + MS_EXCEPTION_IF_NULL(graph); auto &kernels = graph->execution_order(); for (auto &kernel : kernels) { @@ -1031,6 +1041,11 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) } } + auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, i); + if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) { + continue; + } + auto kernel_ref_count_ptr = mem_reuse_util_->GetKernelInputRef(cnode, i); if (kernel_ref_count_ptr == nullptr) { continue; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index 55c977cd31..a4e299586d 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -101,6 +101,8 @@ class GPUKernelRuntime : public KernelRuntime { std::unordered_map mem_swap_map_; std::unordered_map is_first_step_map_; std::unordered_map> graph_output_map_; + std::unordered_map is_alloc_communication_res_; + std::unordered_map is_alloc_inplace_res_; MemReuseUtilPtr mem_reuse_util_{nullptr}; MemSwapManagerPtr mem_swap_manager_{nullptr};