Browse Source

!1108 Add input shape condition for transpose_reshape fusion pass

Merge pull request !1108 from YuJianfeng/master
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
168dfb2555
5 changed files with 79 additions and 3 deletions
  1. +17
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc
  2. +17
    -0
      mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc
  3. +22
    -1
      tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc
  4. +22
    -1
      tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc
  5. +1
    -1
      tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py

+ 17
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/reshape_transpose_fusion.cc View File

@@ -23,6 +23,18 @@

namespace mindspore {
namespace opt {
namespace {
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
  // A shape qualifies for the transpose/reshape fusion only when its
  // innermost dimension(s) are multiples of kCubeSize (presumably the
  // cube/fractal tile size — TODO confirm); empty shapes never qualify.
  if (shape.empty()) {
    return false;
  }
  const size_t rank = shape.size();
  if (rank == 1) {
    return shape[0] % kCubeSize == 0;
  }
  // Rank >= 2: both trailing dimensions must be cube-aligned.
  return shape[rank - 1] % kCubeSize == 0 && shape[rank - 2] % kCubeSize == 0;
}
} // namespace

const BaseRef ReshapeTransposeFusion::DefinePattern() const {
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
VectorRef reshape({prim_reshape, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(transpose_cnode);
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(reshape_cnode);
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);


+ 17
- 0
mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_reshape_fusion.cc View File

@@ -23,6 +23,18 @@

namespace mindspore {
namespace opt {
namespace {
bool CheckShapeDimInfo(const std::vector<size_t> &shape) {
  // Gate for the fusion: every dimension that participates in the cube
  // layout (the last one for 1-D shapes, the last two otherwise) has to
  // divide evenly by kCubeSize. Empty shapes are rejected outright.
  if (shape.empty()) {
    return false;
  }
  const size_t dims = shape.size();
  const bool last_aligned = (shape[dims - 1] % kCubeSize == 0);
  const bool second_last_aligned = (dims < 2) || (shape[dims - 2] % kCubeSize == 0);
  return last_aligned && second_last_aligned;
}
} // namespace

const BaseRef TransposeReshapeFusion::DefinePattern() const {
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
VectorRef transpose({prim::kPrimTranspose, input_varptr_});
@@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(transpose_cnode);
std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);


+ 22
- 1
tests/ut/cpp/pre_activate/ascend/ir_fusion/reshape_transpose_fusion_test.cc View File

@@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
std::vector<int> shp{2, 2, 16, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
@@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
  /*
   * Graph under test — Reshape followed by Transpose (the pattern this
   * pass matches), driven by the python "before" graph of
   * test_reshape_transpose_fusion:
   *   def before(input0):
   *     reshape = Reshape(input0, ...)
   *     transpose = Transpose(reshape, ...)
   *     return transpose
   * The input shape {2, 4, 8, 16} is chosen so CheckShapeDimInfo rejects
   * it (inner dims not both multiples of kCubeSize — presumably 16, given
   * the fused test uses {2, 2, 16, 16}; TODO confirm), so no fusion fires.
   */
  FuncGraphPtr func_graph = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
  std::vector<int> input_shape{2, 4, 8, 16};
  auto input_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, input_shape);
  AbstractBasePtrList args_spec_list{input_abstract};
  auto kernel_graph = GetKernelGraph(func_graph, args_spec_list);

  auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pass_manager = std::make_shared<opt::PassManager>();
  pass_manager->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
  graph_optimizer->AddPassManager(pass_manager);
  FuncGraphPtr optimized_graph = graph_optimizer->Optimize(kernel_graph);
  // When the pass declines to fuse, the optimized graph must be
  // structurally identical to the input kernel graph.
  EXPECT_TRUE(CheckEqualGraph(kernel_graph, optimized_graph));
}
} // namespace opt
} // namespace mindspore

+ 22
- 1
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_reshape_fusion_test.cc View File

@@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
* return transpose
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
std::vector<int> shp{2, 4, 8, 16};
std::vector<int> shp{2, 2, 16, 16};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
@@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) {
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) {
  /*
   * Graph under test (no fusion expected). NOTE(review): the original
   * comment showed Reshape-then-Transpose, but this pass
   * (TransposeReshapeFusion) matches Transpose followed by Reshape — its
   * Process() fetches the reshape node and takes the transpose as the
   * reshape's input — so the python "before" graph is presumably:
   *   def before(input0):
   *     transpose = Transpose(input0, ...)
   *     reshape = Reshape(transpose, ...)
   *     return reshape
   * TODO confirm against transpose_reshape_fusion_test.py.
   */
  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before");
  // Shape chosen so CheckShapeDimInfo rejects it: inner dims (8, 16) are
  // not both multiples of kCubeSize (presumably 16 — the fused-case test
  // uses {2, 2, 16, 16}; TODO confirm), so the pass must not fire.
  std::vector<int> shp{2, 4, 8, 16};
  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  AbstractBasePtrList args_spec_list{x_abstract};
  auto kg = GetKernelGraph(g, args_spec_list);

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>());
  optimizer->AddPassManager(pm);
  FuncGraphPtr new_graph = optimizer->Optimize(kg);
  // No fusion: the optimized graph must equal the input kernel graph.
  EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
}
} // namespace opt
} // namespace mindspore

+ 1
- 1
tests/ut/cpp/python_input/gtest_input/pre_activate/reshape_transpose_fusion_test.py View File

@@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag):

@fns
def before(input0):
reshape = Reshape(input0, (2, 4, 8, 16))
reshape = Reshape(input0, (2, 2, 16, 16))
transpose = Transpose(reshape, (1, 0, 2, 3))
return transpose



Loading…
Cancel
Save