|
|
|
@@ -15,23 +15,24 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" |
|
|
|
#include <memory> |
|
|
|
|
|
|
|
#include <map> |
|
|
|
#include <memory> |
|
|
|
#include <set> |
|
|
|
#include <utility> |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" |
|
|
|
#include "nlohmann/json.hpp" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h" |
|
|
|
#include "frontend/parallel/ops_info/ops_utils.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" |
|
|
|
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" |
|
|
|
#include "backend/optimizer/common/helper.h" |
|
|
|
#include "backend/session/anf_runtime_algorithm.h" |
|
|
|
#include "backend/session/kernel_build_client.h" |
|
|
|
#include "frontend/parallel/ops_info/ops_utils.h" |
|
|
|
#include "nlohmann/json.hpp" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace kernel { |
|
|
|
@@ -276,12 +277,13 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const |
|
|
|
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
// not support format: |
|
|
|
// 1 NDHWC with shape size != 5 |
|
|
|
// 2 FRAC_NZ with shape size < 2 |
|
|
|
// 3 !NDHWC with shape size > 4 |
|
|
|
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) || |
|
|
|
(format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) || |
|
|
|
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) { |
|
|
|
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size(); |
|
|
|
return false; |
|
|
|
|