|
|
|
@@ -50,9 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, |
|
|
|
MS_EXCEPTION_IF_NULL(reshape_cnode); |
|
|
|
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); |
|
|
|
MS_EXCEPTION_IF_NULL(transpose_cnode); |
|
|
|
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); |
|
|
|
std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); |
|
|
|
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); |
|
|
|
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { |
|
|
|
if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName); |
|
|
|
|