Browse Source

Fix fusion condition of transpose and reshape

tags/v0.5.0-beta
yujianfeng 5 years ago
parent
commit
94818cf255
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
  2. +1
    -1
      tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py

+ 2
- 2
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc View File

@@ -50,9 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);


+ 1
- 1
tests/ut/cpp/python_input/gtest_input/pre_activate/transpose_reshape_fusion_test.py View File

@@ -38,7 +38,7 @@ def test_transpose_reshape_fusion(tag):
@fns
def before(x):
transpose = Transpose(x, (1, 0, 2, 3))
reshape = Reshape(transpose, (2, 4, 8, 16))
reshape = Reshape(transpose, (2, 2, 16, 16))
return reshape

@fns


Loading…
Cancel
Save