| @@ -23,6 +23,18 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool CheckShapeDimInfo(const std::vector<size_t> &shape) { | |||
| if (shape.empty()) { | |||
| return false; | |||
| } | |||
| if (shape.size() == 1 && shape[0] % kCubeSize != 0) { | |||
| return false; | |||
| } | |||
| return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); | |||
| } | |||
| } // namespace | |||
| const BaseRef ReshapeTransposeFusion::DefinePattern() const { | |||
| const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name()); | |||
| VectorRef reshape({prim_reshape, input_varptr_}); | |||
| @@ -38,6 +50,11 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, | |||
| MS_EXCEPTION_IF_NULL(transpose_cnode); | |||
| auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); | |||
| MS_EXCEPTION_IF_NULL(reshape_cnode); | |||
| std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); | |||
| std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); | |||
| if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | |||
| auto new_node = func_graph->NewCNode(inputs); | |||
| @@ -23,6 +23,18 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool CheckShapeDimInfo(const std::vector<size_t> &shape) { | |||
| if (shape.empty()) { | |||
| return false; | |||
| } | |||
| if (shape.size() == 1 && shape[0] % kCubeSize != 0) { | |||
| return false; | |||
| } | |||
| return !(shape.size() >= 2 && (shape[shape.size() - 1] % kCubeSize != 0 || shape[shape.size() - 2] % kCubeSize != 0)); | |||
| } | |||
| } // namespace | |||
| const BaseRef TransposeReshapeFusion::DefinePattern() const { | |||
| const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name()); | |||
| VectorRef transpose({prim::kPrimTranspose, input_varptr_}); | |||
| @@ -38,6 +50,11 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, | |||
| MS_EXCEPTION_IF_NULL(reshape_cnode); | |||
| auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); | |||
| MS_EXCEPTION_IF_NULL(transpose_cnode); | |||
| std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); | |||
| std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); | |||
| if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { | |||
| return nullptr; | |||
| } | |||
| auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName); | |||
| std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])}; | |||
| auto new_node = func_graph->NewCNode(inputs); | |||
| @@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) { | |||
| * return transpose | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| std::vector<int> shp{2, 2, 16, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| @@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) { | |||
| /* | |||
| * def before(input0, input1): | |||
| * reshape = Reshape(input0, input1) | |||
| * transpose = Transpose(reshape) | |||
| * return transpose | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -39,7 +39,7 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) { | |||
| * return transpose | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| std::vector<int> shp{2, 2, 16, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| @@ -61,5 +61,26 @@ TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_fusion) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWTransposeReshapeFusion, test_transpose_reshape_no_fusion) { | |||
| /* | |||
| * def before(input0, input1): | |||
| * reshape = Reshape(input0, input1) | |||
| * transpose = Transpose(reshape) | |||
| * return transpose | |||
| */ | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_transpose_reshape_fusion", "before"); | |||
| std::vector<int> shp{2, 4, 8, 16}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| pm->AddPass(std::make_shared<opt::TransposeReshapeFusion>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -36,7 +36,7 @@ def test_reshape_transpose_fusion(tag): | |||
| @fns | |||
| def before(input0): | |||
| reshape = Reshape(input0, (2, 4, 8, 16)) | |||
| reshape = Reshape(input0, (2, 2, 16, 16)) | |||
| transpose = Transpose(reshape, (1, 0, 2, 3)) | |||
| return transpose | |||