|
|
|
@@ -32,6 +32,7 @@ namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; |
|
|
|
namespace { |
|
|
|
const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; |
|
|
|
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; |
|
|
|
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, |
|
|
|
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) { |
|
|
|
@@ -64,20 +65,30 @@ void SetTransNodeAttr(const CNodePtr &trans_node) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { |
|
|
|
AnfNodePtr trans_node = nullptr; |
|
|
|
CNodePtr trans_data = nullptr; |
|
|
|
std::string InitDefaultFormat(const AnfNodePtr &node) { |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// Init |
|
|
|
std::string default_format = kOpFormat_DEFAULT; |
|
|
|
|
|
|
|
if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) { |
|
|
|
auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format"); |
|
|
|
if (attr == kOpFormat_NCDHW) { |
|
|
|
default_format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
} else if (node->isa<ValueNode>() || node->isa<Parameter>()) { |
|
|
|
auto out_format = AnfAlgo::GetOutputFormat(node, 0); |
|
|
|
if (k3DFormatSet.find(out_format) != k3DFormatSet.end()) { |
|
|
|
default_format = kOpFormat_NCDHW; |
|
|
|
} |
|
|
|
} |
|
|
|
return default_format; |
|
|
|
} |
|
|
|
|
|
|
|
AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, |
|
|
|
const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { |
|
|
|
AnfNodePtr trans_node = nullptr; |
|
|
|
CNodePtr trans_data = nullptr; |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
// Init |
|
|
|
std::string default_format = InitDefaultFormat(node); |
|
|
|
AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node; |
|
|
|
std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); |
|
|
|
std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; |
|
|
|
|