|
|
|
@@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
TEST_F(TestHWInsertMemcpyForHccl, test_cond1) { |
|
|
|
get_py_fun_.SetDoResolve(true); |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1"); |
|
|
|
ASSERT_TRUE(g != nullptr); |
|
|
|
std::vector<int> shp_x{1, 64, 112, 112}; |
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); |
|
|
|
AbstractBasePtrList args_spec_list{x_abstract}; |
|
|
|
auto kg = GetKernelGraph(g, args_spec_list); |
|
|
|
EXPECT_NE(kg, nullptr); |
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>(); |
|
|
|
auto pm = std::make_shared<opt::PassManager>(); |
|
|
|
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); |
|
|
|
pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); |
|
|
|
pm->AddPass(pass); |
|
|
|
optimizer->AddPassManager(pm); |
|
|
|
auto new_graph = optimizer->Optimize(kg); |
|
|
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after"); |
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); |
|
|
|
} |
|
|
|
|
|
|
|
TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { |
|
|
|
get_py_fun_.SetDoResolve(true); |
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); |
|
|
|
|