|
|
|
@@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) { |
|
|
|
* return transpose |
|
|
|
*/ |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before"); |
|
|
|
std::vector<int> shp{2, 4, 8, 16}; |
|
|
|
std::vector<int> shp{2, 2, 16, 16}; |
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); |
|
|
|
AbstractBasePtrList args_spec_list{x_abstract}; |
|
|
|
auto kg = GetKernelGraph(g, args_spec_list); |
|
|
|
@@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) { |
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after"); |
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) { |
|
|
|
/* |
|
|
|
* def before(input0, input1): |
|
|
|
* reshape = Reshape(input0, input1) |
|
|
|
* transpose = Transpose(reshape) |
|
|
|
* return transpose |
|
|
|
*/ |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before"); |
|
|
|
std::vector<int> shp{2, 4, 8, 16}; |
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp); |
|
|
|
AbstractBasePtrList args_spec_list{x_abstract}; |
|
|
|
auto kg = GetKernelGraph(g, args_spec_list); |
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>(); |
|
|
|
pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>()); |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kg); |
|
|
|
EXPECT_TRUE(CheckEqualGraph(kg, new_graph)); |
|
|
|
} |
|
|
|
} // namespace opt |
|
|
|
} // namespace mindspore |