|
|
|
@@ -595,6 +595,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, |
|
|
|
std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
std::string format = kOpFormat_NCHW; |
|
|
|
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { |
|
|
|
format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { |
|
|
|
format = AnfAlgo::GetInputFormat(anf_node, real_index); |
|
|
|
if (format == kOpFormat_FRAC_Z) { |
|
|
|
@@ -603,9 +606,6 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_nod |
|
|
|
format = kOpFormat_NCHW; |
|
|
|
} |
|
|
|
} |
|
|
|
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { |
|
|
|
format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
return format; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -637,6 +637,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node |
|
|
|
std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const { |
|
|
|
MS_EXCEPTION_IF_NULL(anf_node); |
|
|
|
std::string format = kOpFormat_NCHW; |
|
|
|
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { |
|
|
|
format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) { |
|
|
|
format = AnfAlgo::GetOutputFormat(anf_node, real_index); |
|
|
|
if (format == kOpFormat_FRAC_Z) { |
|
|
|
@@ -645,9 +648,6 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no |
|
|
|
format = kOpFormat_NCHW; |
|
|
|
} |
|
|
|
} |
|
|
|
if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) { |
|
|
|
format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
return format; |
|
|
|
} |
|
|
|
|
|
|
|
|