| @@ -87,6 +87,7 @@ | |||
| #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" | |||
| #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h" | |||
| #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||
| #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" | |||
| #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h" | |||
| #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h" | |||
| #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h" | |||
| @@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern | |||
| other_pm->AddPass(std::make_shared<AllGatherFusion>()); | |||
| other_pm->AddPass(std::make_shared<ReduceScatterFusion>()); | |||
| other_pm->AddPass(std::make_shared<BroadcastFusion>()); | |||
| other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>()); | |||
| other_pm->AddPass(std::make_shared<ParameterTransOpFusion>()); | |||
| other_pm->AddPass(std::make_shared<RefreshParameterFormat>()); | |||
| optimizer->AddPassManager(other_pm); | |||
| @@ -0,0 +1,114 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h" | |||
| #include <vector> | |||
| #include <set> | |||
| #include <string> | |||
| #include "utils/utils.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "frontend/optimizer/opt.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| MS_EXCEPTION_IF_NULL(cur_hccl); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { | |||
| return false; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem); | |||
| MS_EXCEPTION_IF_NULL(prev_node); | |||
| if (!AnfAlgo::IsCommunicationOp(prev_node)) { | |||
| return false; | |||
| } | |||
| auto prev_hccl_op = prev_node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(prev_hccl_op); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(prev_hccl_op); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| MS_EXCEPTION_IF_NULL(output); | |||
| if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { | |||
| bool is_contain = false; | |||
| for (size_t i = 1; i < cur_hccl->size(); ++i) { | |||
| if (cur_hccl->input(i) == output) { | |||
| is_contain = true; | |||
| break; | |||
| } | |||
| } | |||
| if (!is_contain) { | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace | |||
| AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| std::vector<AnfNodePtr> memcpy_async_list; | |||
| std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)}; | |||
| for (size_t i = 1; i < hccl_node->size(); ++i) { | |||
| auto input = hccl_node->input(i); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| // when input is also a hccl op and just part outputs of it linking with cur_hccl_op | |||
| if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) { | |||
| auto memcpy_async = CreateMemcpyAsyncOp(graph, input); | |||
| auto kernel_info = std::make_shared<device::KernelInfo>(); | |||
| memcpy_async->set_kernel_info(kernel_info); | |||
| MS_EXCEPTION_IF_NULL(kernel_select_); | |||
| kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>()); | |||
| new_inputs.push_back(memcpy_async); | |||
| memcpy_async_list.push_back(memcpy_async); | |||
| } else { | |||
| new_inputs.push_back(input); | |||
| } | |||
| } | |||
| if (!memcpy_async_list.empty()) { | |||
| CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); | |||
| new_hccl_node->set_inputs(new_inputs); | |||
| return new_hccl_node; | |||
| } | |||
| return nullptr; | |||
| } | |||
| const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!AnfAlgo::IsCommunicationOp(node)) { | |||
| return nullptr; | |||
| } | |||
| return InsertMemcpyAsync(func_graph, cnode); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_helper.h"
namespace mindspore {
namespace opt {
// Pass that inserts memcpy_async nodes between cascaded HCCL (communication)
// ops when a downstream op consumes only part of an upstream op's outputs.
class InsertMemcpyAsyncForCascade : public PatternProcessPass {
 public:
  // multigraph=true lets the pass visit every sub-graph of the optimized graph.
  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
        kernel_select_(std::make_shared<KernelSelect>()) {}
  ~InsertMemcpyAsyncForCascade() override = default;
  // Returns the rewritten hccl node, or nullptr when `node` is not a
  // communication cnode or no rewrite was needed.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
  // Builds the replacement hccl node with memcpy_async inserted on the
  // affected inputs; nullptr when nothing changed.
  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
  // Selects kernels for the newly created memcpy_async nodes.
  KernelSelectPtr kernel_select_;
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
| @@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe | |||
| bool IsParameterOrValueNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||
| auto real_node = kernel_with_index.first; | |||
| MS_EXCEPTION_IF_NULL(real_node); | |||
| if (real_node->isa<Parameter>()) { | |||
| return true; | |||
| } | |||
| return real_node->isa<ValueNode>(); | |||
| } | |||
| void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { | |||
| void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list, | |||
| const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| MS_EXCEPTION_IF_NULL(memcpy_async); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| @@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, | |||
| } | |||
| // find hccl_node's output which is a control depend | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| int output_index = node_index.second; | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||
| CNodePtr control_depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(control_depend); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| for (size_t i = 0; i < control_depend->size(); ++i) { | |||
| if (i == IntToSize(output_index)) { | |||
| new_inputs.push_back(memcpy_async); | |||
| } else { | |||
| new_inputs.push_back(control_depend->input(i)); | |||
| } | |||
| if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) { | |||
| continue; | |||
| } | |||
| CNodePtr control_depend = node_index.first->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(control_depend); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| for (size_t i = 0; i < control_depend->size(); ++i) { | |||
| if (i == IntToSize(node_index.second)) { | |||
| std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; | |||
| make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); | |||
| make_tuple_inputs.emplace_back(hccl_node); | |||
| auto make_tuple = graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| new_inputs.push_back(make_tuple); | |||
| } else { | |||
| new_inputs.push_back(control_depend->input(i)); | |||
| } | |||
| control_depend->set_inputs(new_inputs); | |||
| } | |||
| control_depend->set_inputs(new_inputs); | |||
| } | |||
| } | |||
| } // namespace | |||
| bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { | |||
| bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, | |||
| const CNodePtr &cur_node) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| MS_EXCEPTION_IF_NULL(cur_node); | |||
| // when input is a parameter or is a value node | |||
| if (IsParameterOrValueNode(input)) { | |||
| return true; | |||
| } | |||
| // when input is a Ref or some special cnodes | |||
| if (kernel_query_->IsTbeRef(input) || | |||
| kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { | |||
| return true; | |||
| } | |||
| if (input->isa<CNode>()) { | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(input); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| // when input is used by others | |||
| if (iter->second.size() > 1) { | |||
| return true; | |||
| // when input is a Ref cnode | |||
| if (kernel_query_->IsTbeRef(input)) { | |||
| return true; | |||
| } | |||
| // when input is some special cnodes | |||
| if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { | |||
| return true; | |||
| } | |||
| // when input is used by others | |||
| auto iter = node_users.find(input); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| if (iter->second.size() > 1) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| @@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con | |||
| void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| bool has_insert_memcpy = false; | |||
| AnfNodePtr memcpy_async = nullptr; | |||
| std::vector<AnfNodePtr> memcpy_async_list; | |||
| std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)}; | |||
| for (size_t i = 1; i < hccl_node->size(); ++i) { | |||
| auto input = hccl_node->input(i); | |||
| if (NeedInsertMemcpy(graph, input)) { | |||
| memcpy_async = CreateMemcpyAsyncOp(graph, input); | |||
| has_insert_memcpy = true; | |||
| if (NeedInsertMemcpy(graph, input, hccl_node)) { | |||
| auto memcpy_async = CreateMemcpyAsyncOp(graph, input); | |||
| new_inputs.push_back(memcpy_async); | |||
| memcpy_async_list.push_back(memcpy_async); | |||
| } else { | |||
| new_inputs.push_back(input); | |||
| } | |||
| } | |||
| if (has_insert_memcpy) { | |||
| if (!memcpy_async_list.empty()) { | |||
| CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); | |||
| new_hccl_node->set_inputs(new_inputs); | |||
| auto manager = graph->manager(); | |||
| @@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co | |||
| MS_LOG(DEBUG) << "end replace"; | |||
| // transer hccl op's control to the memcpy_async | |||
| if (hccl_node->size() == 2) { | |||
| TransferControl(new_hccl_node, memcpy_async, graph); | |||
| } | |||
| TransferControl(new_hccl_node, memcpy_async_list, graph); | |||
| } | |||
| } | |||
| @@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { | |||
| private: | |||
| void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; | |||
| bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; | |||
| bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const; | |||
| KernelQueryPtr kernel_query_; | |||
| }; | |||
| } // namespace opt | |||
| @@ -22,6 +22,7 @@ | |||
| #include "utils/utils.h" | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "ir/param_value.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||
| @@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { | |||
| ~MockInsertMemcpyForHcclKernelQuery() override = default; | |||
| bool IsTbeRef(const AnfNodePtr &node) override { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||
| return name == "ApplyMomentum"; | |||
| return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum"; | |||
| } | |||
| }; | |||
| @@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| for (auto p : kg->parameters()) { | |||
| auto param = p->cast<ParameterPtr>(); | |||
| EXPECT_NE(param, nullptr); | |||
| param->set_default_param(std::make_shared<ParamValue>()); | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| @@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| for (auto p : kg->parameters()) { | |||
| auto param = p->cast<ParameterPtr>(); | |||
| EXPECT_NE(param, nullptr); | |||
| param->set_default_param(std::make_shared<ParamValue>()); | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| @@ -161,5 +171,34 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond5) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| for (auto p : kg->parameters()) { | |||
| auto param = p->cast<ParameterPtr>(); | |||
| EXPECT_NE(param, nullptr); | |||
| param->set_default_param(std::make_shared<ParamValue>()); | |||
| } | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| kg->SetExecOrderByDefault(); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -17,6 +17,7 @@ from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
| all_reduce = P.AllReduce() | |||
| broadcast = P.Broadcast(1) | |||
| memcpy_async = Primitive('memcpy_async') | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
def test_insert_memcpy_async_for_hccl_op_cond4(tag):
    """Fixture pair for cond4: an all_reduce whose input is a graph parameter.

    The pass must insert a memcpy_async before the all_reduce and transfer the
    control_depend edge to a make_tuple of the memcpy_async and the new hccl op.
    NOTE(review): reconstructed from a diff whose +/- markers were stripped;
    only the new (two-parameter) before/after definitions are kept.
    """
    fns = FnDict()

    @fns
    def before(a, b):
        x = relu(a)
        y = all_reduce(b)
        res = control_depend(x, y)
        return res

    @fns
    def after(a, b):
        x = relu(a)
        y1 = memcpy_async(b)
        y2 = all_reduce(y1)
        res = control_depend(x, make_tuple(y1, y2))
        return make_tuple(res)

    return fns[tag]
def test_insert_memcpy_async_for_hccl_op_cond5(tag):
    """Fixture pair for cond5: a broadcast over two graph parameters.

    `before` is the graph handed to the pass; `after` is the expected result,
    with a memcpy_async inserted before each broadcast input and the
    control_depend re-pointed at a make_tuple of the memcpy_asyncs and the
    broadcast.
    """
    fns = FnDict()

    @fns
    def before(a, b, c):
        x = relu(a)
        y = broadcast((b, c))
        res = control_depend(x, y)
        return res

    @fns
    def after(a, b, c):
        x = relu(a)
        m1 = memcpy_async(b)
        m2 = memcpy_async(c)
        y = broadcast(m1, m2)
        res = control_depend(x, make_tuple(m1, m2, y))
        return make_tuple(res)

    return fns[tag]