@@ -87,6 +87,7 @@
 #include "backend/optimizer/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
 #include "backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "backend/optimizer/ascend/format_type/insert_transdata_for_runop.h"
 #include "backend/optimizer/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -340,6 +341,7 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
   other_pm->AddPass(std::make_shared<AllGatherFusion>());
   other_pm->AddPass(std::make_shared<ReduceScatterFusion>());
   other_pm->AddPass(std::make_shared<BroadcastFusion>());
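+  // insert memcpy_async where one communication op consumes only part of another communication op's outputs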
+  other_pm->AddPass(std::make_shared<InsertMemcpyAsyncForCascade>());
   other_pm->AddPass(std::make_shared<ParameterTransOpFusion>());
   other_pm->AddPass(std::make_shared<RefreshParameterFormat>());
   optimizer->AddPassManager(other_pm);
@@ -0,0 +1,114 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_cascade.h"
+#include <vector>
+#include <set>
+#include <string>
+#include "utils/utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "frontend/optimizer/opt.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+namespace {
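+// Returns true when `node` is a tuple_getitem on another communication op and `cur_hccl` consumes
+// only part of that op's outputs.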
+bool IsPartOutputsOfHcclOp(const AnfNodePtr &node, const CNodePtr &cur_hccl, const FuncGraphPtr &graph) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(cur_hccl);
+  MS_EXCEPTION_IF_NULL(graph);
+  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto prev_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
+  MS_EXCEPTION_IF_NULL(prev_node);
+  if (!AnfAlgo::IsCommunicationOp(prev_node)) {
+    return false;
+  }
+  auto prev_hccl_op = prev_node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(prev_hccl_op);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  auto iter = node_users.find(prev_hccl_op);
+  if (iter == node_users.end()) {
+    MS_LOG(EXCEPTION) << "node has no output in manager";
+  }
+  for (const auto &node_index : iter->second) {
+    AnfNodePtr output = node_index.first;
+    MS_EXCEPTION_IF_NULL(output);
+    if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
+      bool is_contain = false;
+      for (size_t i = 1; i < cur_hccl->size(); ++i) {
+        if (cur_hccl->input(i) == output) {
+          is_contain = true;
+          break;
+        }
+      }
+      if (!is_contain) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+} // namespace
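+
+// Rebuild hccl_node with a memcpy_async inserted in front of every input that is a partially
+// consumed output of another hccl op; returns nullptr when nothing needs to be inserted.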
+AnfNodePtr InsertMemcpyAsyncForCascade::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(hccl_node);
+  std::vector<AnfNodePtr> memcpy_async_list;
+  std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
+  for (size_t i = 1; i < hccl_node->size(); ++i) {
+    auto input = hccl_node->input(i);
+    MS_EXCEPTION_IF_NULL(input);
+    // when the input is an output of another hccl op and only part of that op's outputs feed the current one
+    if (IsPartOutputsOfHcclOp(input, hccl_node, graph)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
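+      // the new memcpy_async has no kernel info yet, so build one and run kernel selection for it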
+      auto kernel_info = std::make_shared<device::KernelInfo>();
+      memcpy_async->set_kernel_info(kernel_info);
+      MS_EXCEPTION_IF_NULL(kernel_select_);
+      kernel_select_->SelectKernel(memcpy_async->cast<CNodePtr>());
+      new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
+    } else {
+      new_inputs.push_back(input);
+    }
+  }
+  if (!memcpy_async_list.empty()) {
+    CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
+    new_hccl_node->set_inputs(new_inputs);
+    return new_hccl_node;
+  }
+  return nullptr;
+}
+
+const AnfNodePtr InsertMemcpyAsyncForCascade::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                      const EquivPtr &) const {
+  if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) {
+    return nullptr;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  if (!AnfAlgo::IsCommunicationOp(node)) {
+    return nullptr;
+  }
+  return InsertMemcpyAsync(func_graph, cnode);
+}
+} // namespace opt
+} // namespace mindspore
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
+
+#include <memory>
+#include "backend/optimizer/common/optimizer.h"
+#include "backend/optimizer/ascend/ascend_helper.h"
+
+namespace mindspore {
+namespace opt {
+class InsertMemcpyAsyncForCascade : public PatternProcessPass {
+ public:
+  explicit InsertMemcpyAsyncForCascade(bool multigraph = true)
+      : PatternProcessPass("insert_memcpy_async_for_cascade", multigraph),
+        kernel_select_(std::make_shared<KernelSelect>()) {}
+  ~InsertMemcpyAsyncForCascade() override = default;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
+  KernelSelectPtr kernel_select_;
+};
+} // namespace opt
+} // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_CASCADE_H_
@@ -32,12 +32,17 @@ const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNe
 bool IsParameterOrValueNode(const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
-  return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>();
+  auto real_node = kernel_with_index.first;
+  MS_EXCEPTION_IF_NULL(real_node);
+  if (real_node->isa<Parameter>()) {
+    return true;
+  }
+  return real_node->isa<ValueNode>();
 }

-void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) {
+void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
+                     const FuncGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(hccl_node);
-  MS_EXCEPTION_IF_NULL(memcpy_async);
   MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
@@ -48,49 +53,62 @@ void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async,
   }
   // find hccl_node's output which is a control depend
   for (const auto &node_index : iter->second) {
-    AnfNodePtr output = node_index.first;
-    int output_index = node_index.second;
-    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
-      CNodePtr control_depend = output->cast<CNodePtr>();
-      MS_EXCEPTION_IF_NULL(control_depend);
-      std::vector<AnfNodePtr> new_inputs;
-      for (size_t i = 0; i < control_depend->size(); ++i) {
-        if (i == IntToSize(output_index)) {
-          new_inputs.push_back(memcpy_async);
-        } else {
-          new_inputs.push_back(control_depend->input(i));
-        }
+    if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimControlDepend)) {
+      continue;
+    }
+    CNodePtr control_depend = node_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(control_depend);
+    std::vector<AnfNodePtr> new_inputs;
+    for (size_t i = 0; i < control_depend->size(); ++i) {
+      if (i == IntToSize(node_index.second)) {
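+        // redirect the ControlDepend edge to a tuple that bundles every inserted memcpy_async and the hccl op itself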
+        std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+        make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
+        make_tuple_inputs.emplace_back(hccl_node);
+        auto make_tuple = graph->NewCNode(make_tuple_inputs);
+        MS_EXCEPTION_IF_NULL(make_tuple);
+        new_inputs.push_back(make_tuple);
+      } else {
+        new_inputs.push_back(control_depend->input(i));
       }
-      control_depend->set_inputs(new_inputs);
     }
+    control_depend->set_inputs(new_inputs);
   }
 }
 } // namespace

-bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const {
+bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input,
+                                                  const CNodePtr &cur_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(cur_node);
   // when input is a parameter or is a value node
   if (IsParameterOrValueNode(input)) {
     return true;
   }
-  // when input is a Ref or some special cnodes
-  if (kernel_query_->IsTbeRef(input) ||
-      kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
-    return true;
-  }
   if (input->isa<CNode>()) {
-    auto manager = graph->manager();
-    MS_EXCEPTION_IF_NULL(manager);
-    auto &node_users = manager->node_users();
-    auto iter = node_users.find(input);
-    if (iter == node_users.end()) {
-      MS_LOG(EXCEPTION) << "node has no output in manager";
-    }
-    // when input is used by others
-    if (iter->second.size() > 1) {
-      return true;
+    auto manager = graph->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto &node_users = manager->node_users();
+    // when input is a Ref cnode
+    if (kernel_query_->IsTbeRef(input)) {
+      return true;
+    }
+    // when input is some special cnodes
+    if (kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) {
+      return true;
+    }
+    // when input is used by others
+    auto iter = node_users.find(input);
+    if (iter == node_users.end()) {
+      MS_LOG(EXCEPTION) << "node has no output in manager";
+    }
+    if (iter->second.size() > 1) {
+      return true;
+    }
   }
   return false;
 }
@@ -98,21 +116,20 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
 void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(hccl_node);
-  bool has_insert_memcpy = false;
-  AnfNodePtr memcpy_async = nullptr;
+  std::vector<AnfNodePtr> memcpy_async_list;
   std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
   for (size_t i = 1; i < hccl_node->size(); ++i) {
     auto input = hccl_node->input(i);
-    if (NeedInsertMemcpy(graph, input)) {
-      memcpy_async = CreateMemcpyAsyncOp(graph, input);
-      has_insert_memcpy = true;
+    if (NeedInsertMemcpy(graph, input, hccl_node)) {
+      auto memcpy_async = CreateMemcpyAsyncOp(graph, input);
       new_inputs.push_back(memcpy_async);
+      memcpy_async_list.push_back(memcpy_async);
     } else {
       new_inputs.push_back(input);
     }
   }
-  if (has_insert_memcpy) {
+  if (!memcpy_async_list.empty()) {
     CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
     new_hccl_node->set_inputs(new_inputs);
     auto manager = graph->manager();
@@ -122,9 +139,7 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
     MS_LOG(DEBUG) << "end replace";
     // transfer hccl op's control to the memcpy_async
-    if (hccl_node->size() == 2) {
-      TransferControl(new_hccl_node, memcpy_async, graph);
-    }
+    TransferControl(new_hccl_node, memcpy_async_list, graph);
   }
 }
@@ -32,7 +32,7 @@ class InsertMemcpyAsyncForHcclOp : public PatternProcessPass {
  private:
   void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const;
-  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const;
+  bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input, const CNodePtr &cur_node) const;
   KernelQueryPtr kernel_query_;
 };
 } // namespace opt
@@ -22,6 +22,7 @@
 #include "utils/utils.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "backend/optimizer/common/optimizer.h"
+#include "ir/param_value.h"
 #define private public
 #define protected public
 #include "backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.h"
@@ -44,12 +45,10 @@ class MockInsertMemcpyForHcclKernelQuery : public KernelQuery {
   ~MockInsertMemcpyForHcclKernelQuery() override = default;
   bool IsTbeRef(const AnfNodePtr &node) override {
     MS_EXCEPTION_IF_NULL(node);
-    auto cnode = node->cast<CNodePtr>();
-    if (cnode == nullptr) {
+    if (!node->isa<CNode>()) {
       return false;
     }
-    auto name = AnfAlgo::GetCNodeName(cnode);
-    return name == "ApplyMomentum";
+    return AnfAlgo::GetCNodeName(node->cast<CNodePtr>()) == "ApplyMomentum";
   }
 };
@@ -105,6 +104,11 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
   AbstractBasePtrList args_spec_list{x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);
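+  // give every kernel graph parameter a default value (make it a weight) before running the pass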
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
@@ -146,10 +150,16 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
   ASSERT_TRUE(g != nullptr);
   std::vector<int> shp_x{1, 64, 112, 112};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
-  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract};
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
   auto kg = GetKernelGraph(g, args_spec_list);
   EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }

   auto optimizer = std::make_shared<opt::GraphOptimizer>();
   auto pm = std::make_shared<opt::PassManager>();
   auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
@@ -161,5 +171,34 @@
   FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after");
   EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
 }
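+
+// cond5: the hccl op (Broadcast) has several inputs that each need their own memcpy_async,
+// and its ControlDepend user is redirected to a make_tuple of the new nodes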
+TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
+  get_py_fun_.SetDoResolve(true);
+  FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "before");
+  ASSERT_TRUE(g != nullptr);
+  std::vector<int> shp_x{1, 64, 112, 112};
+  auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
+  AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract};
+  auto kg = GetKernelGraph(g, args_spec_list);
+  EXPECT_NE(kg, nullptr);
+  for (auto p : kg->parameters()) {
+    auto param = p->cast<ParameterPtr>();
+    EXPECT_NE(param, nullptr);
+    param->set_default_param(std::make_shared<ParamValue>());
+  }
+
+  auto optimizer = std::make_shared<opt::GraphOptimizer>();
+  auto pm = std::make_shared<opt::PassManager>();
+  auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>();
+  pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>();
+  pm->AddPass(pass);
+  optimizer->AddPassManager(pm);
+  auto new_graph = optimizer->Optimize(kg);
+  kg->SetExecOrderByDefault();
+
+  FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond5", "after");
+  EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
+}
 } // namespace opt
 } // namespace mindspore
@@ -17,6 +17,7 @@ from mindspore.ops import Primitive
 from mindspore.ops import operations as P

 all_reduce = P.AllReduce()
+broadcast = P.Broadcast(1)
 memcpy_async = Primitive('memcpy_async')
 make_tuple = Primitive('make_tuple')
 tuple_getitem = Primitive('tuple_getitem')
@@ -101,20 +102,40 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag):
     fns = FnDict()

     @fns
-    def before(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = all_reduce(a)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res2)
+    def before(a, b):
+        x = relu(a)
+        y = all_reduce(b)
+        res = control_depend(x, y)
         return res

     @fns
-    def after(a, b, c, d, e):
-        res1 = apply_momentun(a, b, c, d, e)
-        res2 = memcpy_async(a)
-        res3 = all_reduce(res2)
-        res = control_depend(res1, res2)
-        res = make_tuple(res, res3)
+    def after(a, b):
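+        # expected: control_depend's second operand becomes make_tuple(memcpy_async, all_reduce)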
+        x = relu(a)
+        y1 = memcpy_async(b)
+        y2 = all_reduce(y1)
+        res = control_depend(x, make_tuple(y1, y2))
+        return make_tuple(res)
+
+    return fns[tag]
+
+
+def test_insert_memcpy_async_for_hccl_op_cond5(tag):
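+    # cond5: Broadcast takes two inputs; each gets a memcpy_async and the ControlDepend edge is
+    # moved onto a make_tuple of the memcpy_async nodes and the Broadcast itself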
+    fns = FnDict()
+
+    @fns
+    def before(a, b, c):
+        x = relu(a)
+        y = broadcast((b, c))
+        res = control_depend(x, y)
+        return res
+
+    @fns
+    def after(a, b, c):
+        x = relu(a)
+        m1 = memcpy_async(b)
+        m2 = memcpy_async(c)
+        y = broadcast(m1, m2)
+        res = control_depend(x, make_tuple(m1, m2, y))
         return make_tuple(res)

     return fns[tag]