Merge pull request !7147 from jjfeing/mastertags/v1.1.0
| @@ -94,6 +94,22 @@ void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &m | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // NodeUsersMap: when input i of node B uses node A, the map holds an item with key A and value (B, i); node_users is the value set stored for one key. Returns true only if that set has more than one user and at least one user is a real kernel that is not a communication op; otherwise returns false (no memcpy node needed). | |||||
| bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) { | |||||
| if (node_users.size() == 1) { | |||||
| MS_LOG(INFO) << "This node only used once, no need to insert memcpy node."; | |||||
| return false; | |||||
| } | |||||
| for (const auto &node_pair : node_users) { | |||||
| auto node = node_pair.first; | |||||
| if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) { | |||||
| MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope(); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "This node used by other node, but the node is not real kernel, no need to insert memcpy node."; | |||||
| return false; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, | bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, | ||||
| @@ -126,7 +142,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con | |||||
| if (iter == node_users.end()) { | if (iter == node_users.end()) { | ||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | MS_LOG(EXCEPTION) << "node has no output in manager"; | ||||
| } | } | ||||
| if (iter->second.size() > 1) { | |||||
| if (IsNodeOutPutUsedByOtherRealKernel(iter->second)) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -52,28 +52,6 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { | |||||
| } | } | ||||
| }; | }; | ||||
| // Runs the InsertMemcpyAsyncForHcclOp pass on the "before1" graph of test_insert_memcpy_async_for_hccl_op_cond1 and expects the optimized graph to equal the "after" graph. | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { | TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { | ||||
| get_py_fun_.SetDoResolve(true); | get_py_fun_.SetDoResolve(true); | ||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); | FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); | ||||