|
|
|
@@ -581,6 +581,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for |
|
|
|
|
|
|
|
bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); |
|
|
|
const size_t kCAxis = 1; |
|
|
|
for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) { |
|
|
|
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index); |
|
|
|
@@ -593,6 +594,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: |
|
|
|
if (!IsShapeMatchFormat(output_shape, kernel_build_info.GetOutputFormat(index))) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (kernel_name == "ReduceMean") { |
|
|
|
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims); |
|
|
|
if (!keep_dims && kernel_build_info.GetOutputFormat(index) != kOpFormat_DEFAULT) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) { |
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); |
|
|
|
@@ -605,6 +612,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel:: |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (kernel_name == "ReduceMean") { |
|
|
|
auto keep_dims = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrKeepDims); |
|
|
|
if (!keep_dims && kernel_build_info.GetInputFormat(index) != kOpFormat_DEFAULT) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) { |
|
|
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) && |
|
|
|
|