Browse Source

Add ReduceMean's kernel selection rules

tags/v0.5.0-beta
WilliamLian 5 years ago
parent
commit
cdacd5ca76
2 changed files with 17 additions and 2 deletions
  1. +4
    -2
      mindspore/ccsrc/kernel/kernel_query.cc
  2. +13
    -0
      mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc

+ 4
- 2
mindspore/ccsrc/kernel/kernel_query.cc View File

@@ -44,6 +44,7 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
}
kernel_info_list->clear();
MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
<< AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
<< "input size : [" << AnfAlgo::GetInputTensorNum(kernel_node) << "] cannot match any kernelInfo !";
@@ -54,11 +55,12 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
MS_EXCEPTION_IF_NULL(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_info_list);
TbeMetadataInfo(kernel_node, kernel_info_list);
FilterInvalidKernelInfo(kernel_node, kernel_info_list);
if (kernel_info_list->empty()) {
AicpuMetadataInfo(kernel_node, kernel_info_list);
if (!kernel_info_list->empty()) {
MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
}
}


+ 13
- 0
mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc View File

@@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for

bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
const size_t kCAxis = 1;
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
@@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) {
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
@@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
}
return false;
}
if (kernel_name == "ReduceMean") {
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims);
if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) {
return false;
}
}
}
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&


Loading…
Cancel
Save