|
|
|
@@ -55,6 +55,24 @@ bool RunOpInsertTransData::InsertTransdataForOutput(const FuncGraphPtr &graph) { |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
bool RunOpInsertTransData::ConvertNodeFormat(const FuncGraphPtr &graph, const AnfNodePtr &node, |
|
|
|
const std::string &format, size_t insert_index, size_t input_index, |
|
|
|
bool is_insert) const { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
bool changed = false; |
|
|
|
// convert the format of node to default |
|
|
|
if (kCommonFormatSet.find(format) == kCommonFormatSet.end() && (input_size_ > 1 || format == kOpFormat_ND_RNN_BIAS)) { |
|
|
|
auto input_node = (!is_insert) ? common::AnfAlgo::GetInputNode(cnode, input_index) : node; |
|
|
|
auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, insert_index, is_insert); |
|
|
|
common::AnfAlgo::SetNodeInput(cnode, trans_node, input_index); |
|
|
|
changed = true; |
|
|
|
} |
|
|
|
return changed; |
|
|
|
} |
|
|
|
|
|
|
|
bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
bool changed = false; |
|
|
|
@@ -65,29 +83,19 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) { |
|
|
|
if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode); |
|
|
|
size_t input_num = common::AnfAlgo::GetInputTensorNum(node); |
|
|
|
for (size_t index = 0; index < input_num; ++index) { |
|
|
|
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); |
|
|
|
auto prev_node_out_infer_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); |
|
|
|
auto input_format = AnfAlgo::GetInputFormat(cnode, index); |
|
|
|
auto input_node = common::AnfAlgo::GetInputNode(cnode, index); |
|
|
|
// convert the format of node's input node to default |
|
|
|
if (kCommonFormatSet.find(prev_input_format) == kCommonFormatSet.end() && |
|
|
|
(prev_node_out_infer_shape.size() > 1 || prev_input_format == kOpFormat_ND_RNN_BIAS)) { |
|
|
|
auto trans_node = AddTransOpNodeToGraph(graph, input_node, kernel_select_, 0, false); |
|
|
|
common::AnfAlgo::SetNodeInput(cnode, trans_node, index); |
|
|
|
has_changed = true; |
|
|
|
} |
|
|
|
// convert node's output format |
|
|
|
if (kCommonFormatSet.find(input_format) == kCommonFormatSet.end() && |
|
|
|
(prev_node_out_infer_shape.size() > 1 || input_format == kOpFormat_ND_RNN_BIAS)) { |
|
|
|
auto trans_node = AddTransOpNodeToGraph(graph, cnode, kernel_select_, index, true); |
|
|
|
common::AnfAlgo::SetNodeInput(cnode, trans_node, index); |
|
|
|
has_changed = true; |
|
|
|
} |
|
|
|
auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(node, index); |
|
|
|
auto prev_node_out_infer_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(node, index); |
|
|
|
input_size_ = prev_node_out_infer_shape.size(); |
|
|
|
auto input_format = AnfAlgo::GetInputFormat(node, index); |
|
|
|
// convert the format of node's input or output |
|
|
|
auto input_changed = ConvertNodeFormat(graph, node, prev_input_format, 0, index, false); |
|
|
|
auto output_changed = ConvertNodeFormat(graph, node, input_format, index, index, true); |
|
|
|
has_changed = input_changed || output_changed; |
|
|
|
} |
|
|
|
if (has_changed) { |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
auto kernel_graph = graph->cast<KernelGraphPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
auto new_node = kernel_graph->NewCNode(cnode); |
|
|
|
|