| @@ -44,6 +44,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | ||||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | ||||
| } | } | ||||
| kernel_info_list->clear(); | |||||
| MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" | MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" | ||||
| << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | ||||
| << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; | << "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !"; | ||||
| @@ -54,11 +55,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| MS_EXCEPTION_IF_NULL(kernel_info_list); | MS_EXCEPTION_IF_NULL(kernel_info_list); | ||||
| TbeMetadataInfo(kernel_node, kernel_info_list); | TbeMetadataInfo(kernel_node, kernel_info_list); | ||||
| FilterInvalidKernelInfo(kernel_node, kernel_info_list); | |||||
| if (kernel_info_list->empty()) { | if (kernel_info_list->empty()) { | ||||
| AicpuMetadataInfo(kernel_node, kernel_info_list); | AicpuMetadataInfo(kernel_node, kernel_info_list); | ||||
| if (!kernel_info_list->empty()) { | if (!kernel_info_list->empty()) { | ||||
| MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString() | |||||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | |||||
| MS_LOG(WARNING) << "The node [" << kernel_node->DebugString() | |||||
| << "] cannot find valid TBE kernel info, try to get aicpu kernel info"; | |||||
| AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); | AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node); | ||||
| } | } | ||||
| } | } | ||||
| @@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for | |||||
| bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { | bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| const size_t kCAxis = 1; | const size_t kCAxis = 1; | ||||
| for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { | for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); | auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); | ||||
| @@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: | |||||
| if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { | if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (kernel_name == "ReduceMean") { | |||||
| auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims); | |||||
| if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { | for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); | ||||
| @@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (kernel_name == "ReduceMean") { | |||||
| auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims); | |||||
| if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { | if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { | ||||
| return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && | return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && | ||||