From 70c2920615c8acfc19ccd0f356369188ba1ccab5 Mon Sep 17 00:00:00 2001 From: William Lian Date: Mon, 19 Oct 2020 19:48:56 +0800 Subject: [PATCH] fix bug of filter kernel info --- .../tbe_kernel_select/tbe_kernel_select.cc | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index a4a41042e0..6b7db95e5d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -15,23 +15,24 @@ */ #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" -#include + #include +#include #include #include -#include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/oplib/oplib.h" -#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" -#include "nlohmann/json.hpp" -#include "backend/optimizer/common/helper.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" -#include "frontend/parallel/ops_info/ops_utils.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" +#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" -#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" -#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" +#include "backend/optimizer/common/helper.h" +#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/kernel_build_client.h" +#include "frontend/parallel/ops_info/ops_utils.h" +#include "nlohmann/json.hpp" namespace mindspore { namespace kernel { @@ -276,12 +277,13 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector &shape, const MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; return false; } + if (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) { + return true; + } // not support format: // 1 NDHWC with shape size != 5 - // 2 FRAC_NZ with shape size < 2 // 3 !NDHWC with shape size > 4 if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || - (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); return false;