Browse Source

!13002 3d format bug fix

From: @liubuyu
Reviewed-by: @zhoufeng54,@kisnwang
Signed-off-by: @kisnwang
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa4c19f938
2 changed files with 4 additions and 3 deletions
  1. +2
    -2
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
  2. +2
    -1
      mindspore/ccsrc/utils/utils.h

+ 2
- 2
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc View File

@@ -241,8 +241,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const
return true;
}
// not support format:
// 1 NCDHW with shape size != 5
if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) {
// 1 3d formats with shape size > 5
if (k3DFormatSet.find(format) != k3DFormatSet.end() && shape.size() > kShape5dDims) {
return false;
}
return true;


+ 2
- 1
mindspore/ccsrc/utils/utils.h View File

@@ -518,7 +518,8 @@ const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat
const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalHitsOpName, kSubAndFilterOpName,
kPadAndShiftOpName, kCTCGreedyDecoderOpName};

const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D,
kOpFormat_NDHWC, kOpFormat_DHWCN, kOpFormat_DHWNC};

const std::set<std::string> DynamicShapeConstInputToAttr = {
kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName,


Loading…
Cancel
Save