Merge pull request !1868 from huanghui/insert-memcpy-async-passtags/v0.5.0-beta
| @@ -81,7 +81,7 @@ | |||
| #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" | |||
| #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" | |||
| #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | |||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||
| #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | |||
| #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | |||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | |||
| @@ -227,7 +227,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | |||
| if (context_ptr->ir_fusion_flag()) { | |||
| AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); | |||
| @@ -238,6 +237,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| } | |||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>()); | |||
| optimizer->AddPassManager(ir_fusion_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| @@ -22,6 +22,8 @@ | |||
| #include "device/ascend/kernel_select_ascend.h" | |||
| #include "kernel/kernel_query.h" | |||
| #include "kernel/tbe/tbe_kernel_select.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -56,6 +58,17 @@ class KernelQuery { | |||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | |||
| kernel::KernelQuery(kernel_node, kernel_info_list); | |||
| } | |||
| virtual bool IsTbeRef(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| } | |||
| auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); | |||
| if (op_info != nullptr) { | |||
| return op_info->is_ref(); | |||
| } | |||
| return false; | |||
| } | |||
| }; | |||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | |||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | |||
| @@ -1,75 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||
| #include <vector> | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "optimizer/opt.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| bool InputIsParameterOrValueNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||
| } | |||
| const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| const std::vector<AnfNodePtr> &inputs = node->inputs(); | |||
| bool replace = false; | |||
| if (inputs.empty()) { | |||
| MS_LOG(EXCEPTION) << "node[" + AnfAlgo::GetCNodeName(node) + "]'s inputs is empty"; | |||
| } | |||
| std::vector<AnfNodePtr> new_inputs = {inputs[0]}; | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto input = node->input(i); | |||
| if (manager->node_users().find(input) == manager->node_users().end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| // when input is used by others or is a parameter or is a value node, insert a memcpy_async | |||
| if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) { | |||
| replace = true; | |||
| new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); | |||
| } else { | |||
| new_inputs.push_back(input); | |||
| } | |||
| } | |||
| CNodePtr new_node = std::make_shared<CNode>(*node); | |||
| new_node->set_inputs(new_inputs); | |||
| return replace ? new_node : nullptr; | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!AnfAlgo::IsCommunicationOp(node)) { | |||
| return nullptr; | |||
| } | |||
| return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,135 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||
| #include <vector> | |||
| #include <set> | |||
| #include <string> | |||
| #include "utils/utils.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "optimizer/opt.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
// Op names in this set are compared against AnfAlgo::GetCNodeName(input) in NeedInsertMemcpy;
// a match makes NeedInsertMemcpy return true regardless of the IsTbeRef result.
| // insert memcpy for some cnode even if not a Ref cnode | |||
| const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, | |||
| kLambUpdateWithLROpName}; | |||
| bool IsParameterOrValueNode(const AnfNodePtr &node) { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||
| } | |||
| void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| MS_EXCEPTION_IF_NULL(memcpy_async); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(hccl_node); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| // find hccl_node's output which is a control depend | |||
| for (const auto &node_index : iter->second) { | |||
| AnfNodePtr output = node_index.first; | |||
| int output_index = node_index.second; | |||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||
| CNodePtr control_depend = output->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(control_depend); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| for (size_t i = 0; i < control_depend->size(); ++i) { | |||
| if (i == IntToSize(output_index)) { | |||
| new_inputs.push_back(memcpy_async); | |||
| } else { | |||
| new_inputs.push_back(control_depend->input(i)); | |||
| } | |||
| } | |||
| control_depend->set_inputs(new_inputs); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(input); | |||
| // when input is a parameter or is a value node | |||
| if (IsParameterOrValueNode(input)) { | |||
| return true; | |||
| } | |||
| // when input is a Ref or some special cnodes | |||
| if (kernel_query_->IsTbeRef(input) || | |||
| kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { | |||
| return true; | |||
| } | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto &node_users = manager->node_users(); | |||
| auto iter = node_users.find(input); | |||
| if (iter == node_users.end()) { | |||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||
| } | |||
| // when input is used by others | |||
| if (iter->second.size() > 1) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||
| if (hccl_node->size() != 2) { | |||
| MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2"; | |||
| return; | |||
| } | |||
| auto input = hccl_node->input(1); | |||
| if (NeedInsertMemcpy(graph, input)) { | |||
| auto memcpy_async = CreateMemcpyAsyncOp(graph, input); | |||
| CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); | |||
| new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async}); | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; | |||
| (void)manager->Replace(hccl_node, new_hccl_node); | |||
| MS_LOG(DEBUG) << "end replace"; | |||
| // transer hccl op's control to the memcpy_async | |||
| TransferControl(new_hccl_node, memcpy_async, graph); | |||
| } | |||
| } | |||
| const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!AnfAlgo::IsCommunicationOp(node)) { | |||
| return nullptr; | |||
| } | |||
| InsertMemcpyAsync(func_graph, cnode); | |||
| return nullptr; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -13,19 +13,28 @@ | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||
| #include <memory> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class AddMemcpyAsync : public PatternProcessPass { | |||
| class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { | |||
| public: | |||
| explicit AddMemcpyAsync(bool multigraph = true) : PatternProcessPass("add_memcpy_async", multigraph) {} | |||
| ~AddMemcpyAsync() override = default; | |||
| explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) | |||
| : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), | |||
| kernel_query_(std::make_shared<KernelQuery>()) {} | |||
| ~InsertMemcpyAsyncForHcclOp() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; | |||
| bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; | |||
| KernelQueryPtr kernel_query_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||
| @@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s | |||
| bn_outputs->push_back(output); | |||
| output_num++; | |||
| } | |||
| return output_num > kBatchNormLeastOutputNum; | |||
| return output_num >= kBatchNormLeastOutputNum; | |||
| } | |||
| AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { | |||
| @@ -1,58 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/tensor.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "utils/utils.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWAddMemcpyAsync : public BackendCommon { | |||
| public: | |||
| TestHWAddMemcpyAsync() : get_py_fun_("gtest_input.pre_activate.add_memcpy_async", true) {} | |||
| public: | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| TEST_F(TestHWAddMemcpyAsync, test_add_memcpy_async) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto func_graph = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(func_graph, nullptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::AddMemcpyAsync>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(func_graph); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,165 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/backend_common_test.h" | |||
| #include "common/py_func_graph_fetcher.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "operator/ops.h" | |||
| #include "ir/tensor.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #include "utils/utils.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "pre_activate/common/optimizer.h" | |||
| #define private public | |||
| #define protected public | |||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TestHWInsertMemcpyForHccl : public BackendCommon { | |||
| public: | |||
| TestHWInsertMemcpyForHccl() : get_py_fun_("gtest_input.pre_activate.insert_memcpy_async_for_hccl_op", true) {} | |||
| ~TestHWInsertMemcpyForHccl() override = default; | |||
| public: | |||
| UT::PyFuncGraphFetcher get_py_fun_; | |||
| }; | |||
| class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { | |||
| public: | |||
| MockInsertMemcpyForHcclKernelQuery() = default; | |||
| ~MockInsertMemcpyForHcclKernelQuery() override = default; | |||
| bool IsTbeRef(const AnfNodePtr &node) override { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (cnode == nullptr) { | |||
| return false; | |||
| } | |||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||
| return name == "ApplyMomentum"; | |||
| } | |||
| }; | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto origin_graph = std::make_shared<session::KernelGraph>(*kg); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||
| } | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond3) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { | |||
| get_py_fun_.SetDoResolve(true); | |||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "before"); | |||
| ASSERT_TRUE(g != nullptr); | |||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; | |||
| auto kg = GetKernelGraph(g, args_spec_list); | |||
| EXPECT_NE(kg, nullptr); | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>(); | |||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||
| pm->AddPass(pass); | |||
| optimizer->AddPassManager(pm); | |||
| auto new_graph = optimizer->Optimize(kg); | |||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); | |||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,50 +0,0 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
# Primitive handles used by the graph-building function in this fixture.
# NOTE(review): tuple_getitem is not referenced in the visible part of this file.
| all_reduce = P.AllReduce() | |||
| memcpy_async = Primitive('memcpy_async') | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
class FnDict:
    """Name-to-function registry that can be applied as a decorator."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        # Registers fn under its own __name__. Nothing is returned, so a name
        # decorated with an instance is rebound to None; use [] to get it back.
        key = fn.__name__
        self.fnDict[key] = fn

    def __getitem__(self, name):
        registered = self.fnDict
        return registered[name]
def test_add_memcpy_async(tag):
    """Return the graph-building function registered under `tag` ('before' or 'after')."""
    registry = FnDict()

    @registry
    def before(x):
        # x feeds all_reduce and is also returned directly.
        reduced = all_reduce(x)
        return make_tuple(x, reduced)

    @registry
    def after(x):
        # Same graph with a memcpy_async between x and all_reduce.
        copied = memcpy_async(x)
        reduced = all_reduce(copied)
        return make_tuple(make_tuple(x, reduced))

    return registry[tag]
| @@ -0,0 +1,120 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| from mindspore.ops import Primitive | |||
| from mindspore.ops import operations as P | |||
# Primitive handles used by the graph-building functions in this fixture.
# NOTE(review): `apply_momentun` looks like a misspelling of "apply_momentum". The cond3 and
# cond4 builders reference it under this spelling, so a rename must update those uses too.
# NOTE(review): tuple_getitem is not referenced in the visible part of this file.
| all_reduce = P.AllReduce() | |||
| memcpy_async = Primitive('memcpy_async') | |||
| make_tuple = Primitive('make_tuple') | |||
| tuple_getitem = Primitive('tuple_getitem') | |||
| apply_momentun = P.ApplyMomentum() | |||
| control_depend = P.ControlDepend() | |||
| relu = P.ReLU() | |||
class FnDict:
    """Name-to-function registry that can be applied as a decorator."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        # Registers fn under its own __name__. Nothing is returned, so a name
        # decorated with an instance is rebound to None; use [] to get it back.
        key = fn.__name__
        self.fnDict[key] = fn

    def __getitem__(self, name):
        registered = self.fnDict
        return registered[name]
def test_insert_memcpy_async_for_hccl_op_cond1(tag):
    """Return the graph-building function registered under `tag`.

    Tags: 'before1' (relu output used by all_reduce and by the returned tuple),
    'before2' (relu output used by all_reduce only), 'after' (expected graph
    for 'before1', with a memcpy_async in front of all_reduce).
    """
    registry = FnDict()

    @registry
    def before1(x):
        activated = relu(x)
        reduced = all_reduce(activated)
        return make_tuple(activated, reduced)

    @registry
    def before2(x):
        activated = relu(x)
        reduced = all_reduce(activated)
        return reduced

    @registry
    def after(x):
        activated = relu(x)
        copied = memcpy_async(activated)
        reduced = all_reduce(copied)
        return make_tuple(make_tuple(activated, reduced))

    return registry[tag]
def test_insert_memcpy_async_for_hccl_op_cond2(tag):
    """Return the graph-building function registered under `tag` ('before' or 'after').

    'before' applies all_reduce directly to the graph argument; 'after' is the
    same graph with a memcpy_async in between.
    """
    registry = FnDict()

    @registry
    def before(x):
        reduced = all_reduce(x)
        return reduced

    @registry
    def after(x):
        copied = memcpy_async(x)
        reduced = all_reduce(copied)
        return make_tuple(reduced)

    return registry[tag]
def test_insert_memcpy_async_for_hccl_op_cond3(tag):
    """Return the graph-building function registered under `tag` ('before' or 'after').

    'before' feeds the ApplyMomentum output into all_reduce; 'after' is the same
    graph with a memcpy_async between the two.
    """
    registry = FnDict()

    @registry
    def before(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        reduced = all_reduce(updated)
        return reduced

    @registry
    def after(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        copied = memcpy_async(updated)
        reduced = all_reduce(copied)
        return make_tuple(reduced)

    return registry[tag]
def test_insert_memcpy_async_for_hccl_op_cond4(tag):
    """Return the graph-building function registered under `tag` ('before' or 'after').

    'before' puts a control_depend between the ApplyMomentum output and an
    all_reduce of `a`; in 'after' the control_depend points at the inserted
    memcpy_async instead of the all_reduce.
    """
    registry = FnDict()

    @registry
    def before(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        reduced = all_reduce(a)
        depend = control_depend(updated, reduced)
        outputs = make_tuple(depend, reduced)
        return outputs

    @registry
    def after(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        copied = memcpy_async(a)
        reduced = all_reduce(copied)
        depend = control_depend(updated, copied)
        outputs = make_tuple(depend, reduced)
        return make_tuple(outputs)

    return registry[tag]