|
|
|
@@ -101,9 +101,9 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP |
|
|
|
auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); |
|
|
|
auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); |
|
|
|
auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); |
|
|
|
auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, 0); |
|
|
|
auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); |
|
|
|
// insert trans |
|
|
|
if (origin_format != cur_format) { |
|
|
|
if (origin_format != cur_format && cur_shape.size() > 1) { |
|
|
|
auto kernel_select = std::make_shared<KernelSelect>(); |
|
|
|
final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, |
|
|
|
kTransDataOpName, false); |
|
|
|
|