|
|
@@ -33,7 +33,25 @@ std::vector<int> TransposeAxis(const std::string &src_format, const std::string |
|
|
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) { |
|
|
} else if ((src_format == kOpFormat_NHWC) && (dst_format == kOpFormat_NCHW)) { |
|
|
return {0, 3, 1, 2}; |
|
|
return {0, 3, 1, 2}; |
|
|
} else { |
|
|
} else { |
|
|
MS_LOG(EXCEPTION) << "Invaild format transform, from " << src_format << " to " << dst_format; |
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid format transform, from " << src_format << " to " << dst_format; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Transpose can be replaceed by nop reshape in some situations. |
|
|
|
|
|
// 1. out_shape [x, 1, 1, y] with transpose perm {0, 2, 3, 1} |
|
|
|
|
|
// 2. out_shape [x, y, 1, 1] with transpose perm {0, 3, 1, 2} |
|
|
|
|
|
bool IsFakeTranspose(const std::vector<size_t> &out_shape, const std::vector<int> &transpose_perm) { |
|
|
|
|
|
if (out_shape.size() != 4) { |
|
|
|
|
|
MS_LOG(EXCEPTION) << "Invalid data shape, 4-D data was needed, but get " << out_shape.size() << "-D."; |
|
|
|
|
|
} |
|
|
|
|
|
std::vector<int> perm1 = {0, 2, 3, 1}; |
|
|
|
|
|
std::vector<int> perm2 = {0, 3, 1, 2}; |
|
|
|
|
|
if (transpose_perm == perm1) { |
|
|
|
|
|
return (out_shape[1] == 1 && out_shape[2] == 1); |
|
|
|
|
|
} else if (transpose_perm == perm2) { |
|
|
|
|
|
return (out_shape[2] == 1 && out_shape[3] == 1); |
|
|
|
|
|
} else { |
|
|
|
|
|
return false; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
@@ -56,8 +74,16 @@ void SetTransposeOpBuildInfo(const std::string &input_format, const std::string |
|
|
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, |
|
|
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node, |
|
|
int used_node_index, const std::vector<int> &transpose_perm) { |
|
|
int used_node_index, const std::vector<int> &transpose_perm) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
// 1.Create a transpose node. |
|
|
|
|
|
auto transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name()); |
|
|
|
|
|
|
|
|
// 0.Judge whether it is a fake transpose |
|
|
|
|
|
auto transed_shape = AnfAlgo::GetInputDeviceShape(used_node, used_node_index); |
|
|
|
|
|
bool is_fake = IsFakeTranspose(transed_shape, transpose_perm); |
|
|
|
|
|
// 1.Create a transpose node or a fake transpose node:reshape. |
|
|
|
|
|
mindspore::PrimitivePtr transpose_prim; |
|
|
|
|
|
if (is_fake) { |
|
|
|
|
|
transpose_prim = std::make_shared<Primitive>(prim::kPrimReshape->name()); |
|
|
|
|
|
} else { |
|
|
|
|
|
transpose_prim = std::make_shared<Primitive>(prim::kPrimTranspose->name()); |
|
|
|
|
|
} |
|
|
MS_EXCEPTION_IF_NULL(transpose_prim); |
|
|
MS_EXCEPTION_IF_NULL(transpose_prim); |
|
|
// 2.Set the input of transpose. |
|
|
// 2.Set the input of transpose. |
|
|
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node}; |
|
|
std::vector<AnfNodePtr> transpose_input = {NewValueNode(transpose_prim), node}; |
|
|
@@ -66,7 +92,9 @@ CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, co |
|
|
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)}; |
|
|
auto transpose_type = {AnfAlgo::GetPrevNodeOutputInferDataType(used_node, used_node_index)}; |
|
|
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)}; |
|
|
auto transpose_shape = {AnfAlgo::GetPrevNodeOutputInferShape(used_node, used_node_index)}; |
|
|
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get()); |
|
|
AnfAlgo::SetOutputInferTypeAndShape(transpose_type, transpose_shape, transpose_op.get()); |
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); |
|
|
|
|
|
|
|
|
if (!is_fake) { |
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(transpose_perm), transpose_op); |
|
|
|
|
|
} |
|
|
// 4.Set the input of used_node. |
|
|
// 4.Set the input of used_node. |
|
|
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() |
|
|
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope() |
|
|
<< ", index: " << used_node_index; |
|
|
<< ", index: " << used_node_index; |
|
|
|