Browse Source

!12393 update inout format for kernel json

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
1f9a5503d6
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc

+ 6
- 6
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc View File

@@ -595,6 +595,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node,
std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::string format = kOpFormat_NCHW;
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
format = AnfAlgo::GetInputFormat(anf_node, real_index);
if (format == kOpFormat_FRAC_Z) {
@@ -603,9 +606,6 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod
format = kOpFormat_NCHW;
}
}
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
return format;
}

@@ -637,6 +637,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node
std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
MS_EXCEPTION_IF_NULL(anf_node);
std::string format = kOpFormat_NCHW;
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
format = AnfAlgo::GetOutputFormat(anf_node, real_index);
if (format == kOpFormat_FRAC_Z) {
@@ -645,9 +648,6 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
format = kOpFormat_NCHW;
}
}
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
format = kOpFormat_NCDHW;
}
return format;
}



Loading…
Cancel
Save