Browse Source

!7147 insert memcpy in hccl input

Merge pull request !7147 from jjfeing/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7af0d3374f
2 changed files with 17 additions and 23 deletions
  1. +17
    -1
      mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc
  2. +0
    -22
      tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc

+ 17
- 1
mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc View File

@@ -94,6 +94,22 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m
}
}
}
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
if (node_users.size() == 1) {
MS_LOG(INFO) << "This node only used once, no need to insert memcpy node.";
return false;
}
for (const auto &node_pair : node_users) {
auto node = node_pair.first;
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
return true;
}
}
MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert memcpy node.";
return false;
}
} // namespace

bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
@@ -126,7 +142,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
if (iter == node_users.end()) {
MS_LOG(EXCEPTION) << "node has no output in manager";
}
if (iter->second.size() > 1) {
if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) {
return true;
}
}


+ 0
- 22
tests/ut/cpp/pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op_test.cc View File

@@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
}
};

TEST_F(TestHWInsertMemcpyForHccl, test_cond1) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1");
ASSERT_TRUE(g != nullptr);
std::vector<int> shp_x{1, 64, 112, 112};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
EXPECT_NE(kg, nullptr);

auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
pm->AddPass(pass);
optimizer->AddPassManager(pm);
auto new_graph = optimizer->Optimize(kg);

FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after");
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
}

TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) {
get_py_fun_.SetDoResolve(true);
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2");


Loading…
Cancel
Save