/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/kernel_query.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <sstream>
#include <vector>
#include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h"
#include "backend/kernel_compiler/rts/rt_kernel_info.h"
#include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include "backend/kernel_compiler/akg/akg_kernel_metadata.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace kernel {
namespace {
// Keeps only the candidates whose input and output counts both equal those of `kernel_node`.
//
// If at least one candidate matches, `*kernel_info_list` is replaced by the matching subset
// (original order preserved). If none matches, the reason each candidate was rejected is logged
// at INFO level and `*kernel_info_list` is left empty.
//
// Raises through MS_EXCEPTION_IF_NULL when `kernel_info_list` or any of its entries is null.
// `kernel_node` itself is not checked here; every caller in this file checks it first.
void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
                             std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  // Read the node's arity once instead of once per candidate.
  // NOTE(review): assumes these AnfAlgo getters are pure queries - confirm.
  const size_t node_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  const size_t node_input_num = AnfAlgo::GetInputTensorNum(kernel_node);

  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_list;
  (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list),
                     [node_output_num, node_input_num](const std::shared_ptr<kernel::KernelBuildInfo> &kernel_build_info) {
                       // Same null guard as the logging loop below, so a null entry raises instead of crashing.
                       MS_EXCEPTION_IF_NULL(kernel_build_info);
                       return node_output_num == kernel_build_info->GetOutputNum() &&
                              node_input_num == kernel_build_info->GetInputNum();
                     });
  if (!filtered_list.empty()) {
    kernel_info_list->swap(filtered_list);
    return;
  }

  MS_LOG(INFO) << "All kernel Info list does not match any kernel info ";
  for (size_t index = 0; index < kernel_info_list->size(); ++index) {
    const auto &kernel_info = (*kernel_info_list)[index];
    MS_EXCEPTION_IF_NULL(kernel_info);
    std::ostringstream buffer;
    if (node_output_num != kernel_info->GetOutputNum()) {
      buffer << "Kernel node's output size [" << node_output_num << "]"
             << " cannot match the kernel's output size [" << kernel_info->GetOutputNum() << "]";
    } else {
      // Nothing passed the filter and the output counts agree, so the input count is what differs.
      buffer << "Kernel node's input size [" << node_input_num << "]"
             << " cannot match the kernel's input size [" << kernel_info->GetInputNum() << "]";
    }
    MS_LOG(INFO) << "kernel [ " << index << " ] :" << kernel_info->ToString() << buffer.str();
  }
  kernel_info_list->clear();
  MS_LOG(INFO) << "node" << kernel_node->DebugString() << "'s output size : [" << node_output_num << "]"
               << "input size : [" << node_input_num << "] cannot match any kernelInfo !";
}

// Returns true when some entry of `kernel_info_list` compares equal (by value, via operator==
// on the pointees) to `*target`. Raises through MS_EXCEPTION_IF_NULL on a null entry.
// `target` must be non-null; both callers check it before calling.
bool ContainsKernelBuildInfo(const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
                             const KernelBuildInfoPtr &target) {
  return std::any_of(kernel_info_list.begin(), kernel_info_list.end(),
                     [&target](const std::shared_ptr<kernel::KernelBuildInfo> &item) {
                       MS_EXCEPTION_IF_NULL(item);
                       return *item == *target;
                     });
}
}  // namespace

// Fills `*kernel_info_list` by trying the metadata providers in a fixed order - TBE, AICPU,
// RT kernel, HCCL - and stopping at the first one that leaves the list non-empty.
//
// When the AICPU provider is the one that supplies the candidates, the node is tagged with the
// kAttrIsAICPUKernel attribute set to true.
//
// Raises when either argument is null, or (MS_LOG(EXCEPTION)) when every provider leaves the
// list empty.
void KernelQueryAll(const CNodePtr &kernel_node,
                    std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);

  TbeMetadataInfo(kernel_node, kernel_info_list);

  if (kernel_info_list->empty()) {
    AicpuMetadataInfo(kernel_node, kernel_info_list);
    if (!kernel_info_list->empty()) {
      MS_LOG(INFO) << "The node [" << kernel_node->DebugString()
                   << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
      AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
    }
  }

  if (kernel_info_list->empty()) {
    GetRtKelInfo(kernel_node, kernel_info_list);
  }

  if (kernel_info_list->empty()) {
    HcclMetadataInfo(kernel_node, kernel_info_list);
  }

  if (kernel_info_list->empty()) {
    MS_LOG(EXCEPTION) << "Op " << kernel_node->DebugString() << "kernel query fail!";
  }
}

// Collects the candidate kernel build infos for `kernel_node` into `*kernel_info_list`, then
// drops the candidates whose input/output counts do not match the node.
//
// `kernel_type` selects the provider: AKG_KERNEL uses AkgMetadataInfo, anything else goes through
// KernelQueryAll. A BatchMatMul node is always routed to AKG when graph kernel is enabled in the
// MsContext, regardless of the `kernel_type` passed in.
//
// Raises when an argument or the MsContext instance is null, and raises NotExistsError when the
// selected provider produced no candidate. The list may still end up empty after the final
// filtering step; that case is only logged, not raised.
void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list,
                 KernelType kernel_type) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->enable_graph_kernel() && IsPrimitiveCNode(kernel_node, prim::kPrimBatchMatMul)) {
    kernel_type = KernelType::AKG_KERNEL;
  }

  switch (kernel_type) {
    case KernelType::AKG_KERNEL:
      AkgMetadataInfo(kernel_node, kernel_info_list);
      break;
    default:
      KernelQueryAll(kernel_node, kernel_info_list);
      break;
  }

  if (kernel_info_list->empty()) {
    MS_EXCEPTION(NotExistsError) << "Op[" << kernel_node->DebugString() << "] kernel query fail!";
  }
  // check output
  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}

// Replaces the contents of `*kernel_info_list` with the AICPU candidates for `kernel_node`
// whose input/output counts match the node. The list is cleared first, so previous contents are
// discarded; it may be empty on return.
void AICPUQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  kernel_info_list->clear();
  AicpuMetadataInfo(kernel_node, kernel_info_list);
  FilterInvalidKernelInfo(kernel_node, kernel_info_list);
}

// Returns true when `*select_kernel_build_info` equals one of the (count-filtered) AICPU
// candidates for `kernel_node`. Raises when either argument is null or `kernel_node` is not a
// CNode.
bool IsSupportedByAICPU(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  auto cnode = kernel_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AICPUQuery(cnode, &kernel_info_list);
  return ContainsKernelBuildInfo(kernel_info_list, select_kernel_build_info);
}

// Returns true when `*select_kernel_build_info` equals one of the candidates TbeMetadataInfo
// reports for `kernel_node`. Unlike the AICPU variant, the candidates are not passed through
// FilterInvalidKernelInfo. Raises when either argument is null or `kernel_node` is not a CNode.
bool IsSupportedByAICore(const AnfNodePtr &kernel_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  auto cnode = kernel_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  TbeMetadataInfo(cnode, &kernel_info_list);
  return ContainsKernelBuildInfo(kernel_info_list, select_kernel_build_info);
}
}  // namespace kernel
}  // namespace mindspore