|
|
|
@@ -32,15 +32,15 @@ namespace opt { |
|
|
|
namespace {
|
|
|
|
|
|
|
|
constexpr int kMinDimNeedToTransform = 3;
|
|
|
|
enum FormatTransformDir { ChannelFisrt2ChannelLast = 0, ChannelLast2ChannelFirst };
|
|
|
|
enum FormatTransformDir { ChannelFirst2ChannelLast = 0, ChannelLast2ChannelFirst };
|
|
|
|
|
|
|
|
// get perm between channel-first shape and channel-last shape.
|
|
|
|
// eg. 4D channe-first => channel-last: [0,1,2,3] => [0,2,3,1];
|
|
|
|
// eg. 4D channe-last => channel-first: [0,1,2,3] => [0,3,1,2];
|
|
|
|
// eg. 4D channel-first => channel-last: [0,1,2,3] => [0,2,3,1];
|
|
|
|
// eg. 4D channel-last => channel-first: [0,1,2,3] => [0,3,1,2];
|
|
|
|
std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) {
|
|
|
|
std::vector<int64_t> axis;
|
|
|
|
axis.resize(dim);
|
|
|
|
if (dir == ChannelFisrt2ChannelLast) {
|
|
|
|
if (dir == ChannelFirst2ChannelLast) {
|
|
|
|
std::iota(axis.begin() + 1, axis.end(), 2);
|
|
|
|
axis[dim - 1] = 1;
|
|
|
|
} else {
|
|
|
|
@@ -52,7 +52,7 @@ std::vector<int64_t> TransposeAxis(const int dim, FormatTransformDir dir) { |
|
|
|
|
|
|
|
CNodePtr InsertTransposeOp(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &used_node,
|
|
|
|
int used_node_index, const std::vector<int64_t> &transpose_perm) {
|
|
|
|
MS_LOG(ERROR) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
|
|
|
MS_LOG(DEBUG) << "Node: " << node->fullname_with_scope() << ", used node: " << used_node->fullname_with_scope()
|
|
|
|
<< ", index: " << used_node_index;
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
// 1.Create a transpose node or a fake transpose node:reshape.
|
|
|
|
@@ -127,7 +127,7 @@ void InsertTransformOpForInput(const FuncGraphPtr &graph, const AnfNodePtr &node |
|
|
|
}
|
|
|
|
auto input_node = AnfAlgo::GetInputNode(utils::cast<CNodePtr>(node), i);
|
|
|
|
MS_EXCEPTION_IF_NULL(input_node);
|
|
|
|
auto transpose_perm = TransposeAxis(dim, ChannelFisrt2ChannelLast);
|
|
|
|
auto transpose_perm = TransposeAxis(dim, ChannelFirst2ChannelLast);
|
|
|
|
auto transpose_op = InsertTransposeOp(graph, input_node, node, i, transpose_perm);
|
|
|
|
SetTransposeOpBuildInfo(kOpFormat_DEFAULT, inputs_format[i], transpose_op);
|
|
|
|
}
|
|
|
|
|