浏览代码

!7470 fix bug of filter kernel info

Merge pull request !7470 from lianliguang/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
6dd56e8d8b
共有 1 个文件被更改,包括 12 次插入10 次删除
  1. +12
    -10
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc

+ 12
- 10
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc 查看文件

@@ -15,23 +15,24 @@
*/

#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
#include <memory>
#include <map>
#include <memory>
#include <set>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/oplib/oplib.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "nlohmann/json.hpp"
#include "backend/optimizer/common/helper.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "nlohmann/json.hpp"

namespace mindspore {
namespace kernel {
@@ -276,12 +277,13 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
MS_LOG(INFO) << "Warning: Server not support format with C04 suffix.";
return false;
}
if (format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) {
return true;
}
// not support format:
// 1 NDHWC with shape size != 5
// 2 FRAC_NZ with shape size < 2
// 3 !NDHWC with shape size > 4
if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
(format == kOpFormat_FRAC_NZ && shape.size() < kShape2dDims) ||
(format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
return false;


正在加载...
取消
保存