Merge pull request !1868 from huanghui/insert-memcpy-async-passtags/v0.5.0-beta
| @@ -81,7 +81,7 @@ | |||||
| #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h" | ||||
| #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" | #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h" | ||||
| #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | ||||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||||
| #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | ||||
| #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | ||||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | ||||
| @@ -227,7 +227,6 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion0>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion1>()); | ||||
| } | } | ||||
| ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>()); | |||||
| ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); | ||||
| if (context_ptr->ir_fusion_flag()) { | if (context_ptr->ir_fusion_flag()) { | ||||
| AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); | AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); | ||||
| @@ -238,6 +237,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap | |||||
| ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); | ||||
| ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>()); | ||||
| } | } | ||||
| ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForHcclOp>()); | |||||
| optimizer->AddPassManager(ir_fusion_pm); | optimizer->AddPassManager(ir_fusion_pm); | ||||
| (void)optimizer->Optimize(kernel_graph); | (void)optimizer->Optimize(kernel_graph); | ||||
| kernel_graph->SetExecOrderByDefault(); | kernel_graph->SetExecOrderByDefault(); | ||||
| @@ -22,6 +22,8 @@ | |||||
| #include "device/ascend/kernel_select_ascend.h" | #include "device/ascend/kernel_select_ascend.h" | ||||
| #include "kernel/kernel_query.h" | #include "kernel/kernel_query.h" | ||||
| #include "kernel/tbe/tbe_kernel_select.h" | #include "kernel/tbe/tbe_kernel_select.h" | ||||
| #include "kernel/oplib/oplib.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -56,6 +58,17 @@ class KernelQuery { | |||||
| std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) { | ||||
| kernel::KernelQuery(kernel_node, kernel_info_list); | kernel::KernelQuery(kernel_node, kernel_info_list); | ||||
| } | } | ||||
| virtual bool IsTbeRef(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| if (!node->isa<CNode>()) { | |||||
| return false; | |||||
| } | |||||
| auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); | |||||
| if (op_info != nullptr) { | |||||
| return op_info->is_ref(); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| }; | }; | ||||
| using KernelQueryPtr = std::shared_ptr<KernelQuery>; | using KernelQueryPtr = std::shared_ptr<KernelQuery>; | ||||
| void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, | ||||
| @@ -1,75 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||||
| #include <vector> | |||||
| #include "utils/utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "optimizer/opt.h" | |||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| bool InputIsParameterOrValueNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||||
| } | |||||
| const AnfNodePtr AddMemcpyAsyncIfInputIsUsedByOthers(const FuncGraphPtr &graph, const CNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| const std::vector<AnfNodePtr> &inputs = node->inputs(); | |||||
| bool replace = false; | |||||
| if (inputs.empty()) { | |||||
| MS_LOG(EXCEPTION) << "node[" + AnfAlgo::GetCNodeName(node) + "]'s inputs is empty"; | |||||
| } | |||||
| std::vector<AnfNodePtr> new_inputs = {inputs[0]}; | |||||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||||
| auto input = node->input(i); | |||||
| if (manager->node_users().find(input) == manager->node_users().end()) { | |||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| } | |||||
| // when input is used by others or is a parameter or is a value node, insert a memcpy_async | |||||
| if (manager->node_users()[input].size() > 1 || InputIsParameterOrValueNode(input)) { | |||||
| replace = true; | |||||
| new_inputs.push_back(CreateMemcpyAsyncOp(graph, input)); | |||||
| } else { | |||||
| new_inputs.push_back(input); | |||||
| } | |||||
| } | |||||
| CNodePtr new_node = std::make_shared<CNode>(*node); | |||||
| new_node->set_inputs(new_inputs); | |||||
| return replace ? new_node : nullptr; | |||||
| } | |||||
| } // namespace | |||||
| const AnfNodePtr AddMemcpyAsync::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (!AnfAlgo::IsCommunicationOp(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| return AddMemcpyAsyncIfInputIsUsedByOthers(func_graph, cnode); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,135 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "utils/utils.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "optimizer/opt.h" | |||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| namespace { | |||||
| // insert memcpy for some cnode even if not a Ref cnode | |||||
| const std::set<std::string> kNeedInsertMemcpyOpSet = {kLambNextMVOpName, kLambNextMVWithDecayOpName, | |||||
| kLambUpdateWithLROpName}; | |||||
| bool IsParameterOrValueNode(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true); | |||||
| return kernel_with_index.first->isa<Parameter>() || kernel_with_index.first->isa<ValueNode>(); | |||||
| } | |||||
| void TransferControl(const CNodePtr &hccl_node, const AnfNodePtr &memcpy_async, const FuncGraphPtr &graph) { | |||||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||||
| MS_EXCEPTION_IF_NULL(memcpy_async); | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| auto iter = node_users.find(hccl_node); | |||||
| if (iter == node_users.end()) { | |||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| } | |||||
| // find hccl_node's output which is a control depend | |||||
| for (const auto &node_index : iter->second) { | |||||
| AnfNodePtr output = node_index.first; | |||||
| int output_index = node_index.second; | |||||
| if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { | |||||
| CNodePtr control_depend = output->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(control_depend); | |||||
| std::vector<AnfNodePtr> new_inputs; | |||||
| for (size_t i = 0; i < control_depend->size(); ++i) { | |||||
| if (i == IntToSize(output_index)) { | |||||
| new_inputs.push_back(memcpy_async); | |||||
| } else { | |||||
| new_inputs.push_back(control_depend->input(i)); | |||||
| } | |||||
| } | |||||
| control_depend->set_inputs(new_inputs); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | |||||
| bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(input); | |||||
| // when input is a parameter or is a value node | |||||
| if (IsParameterOrValueNode(input)) { | |||||
| return true; | |||||
| } | |||||
| // when input is a Ref or some special cnodes | |||||
| if (kernel_query_->IsTbeRef(input) || | |||||
| kNeedInsertMemcpyOpSet.find(AnfAlgo::GetCNodeName(input)) != kNeedInsertMemcpyOpSet.end()) { | |||||
| return true; | |||||
| } | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| auto &node_users = manager->node_users(); | |||||
| auto iter = node_users.find(input); | |||||
| if (iter == node_users.end()) { | |||||
| MS_LOG(EXCEPTION) << "node has no output in manager"; | |||||
| } | |||||
| // when input is used by others | |||||
| if (iter->second.size() > 1) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_EXCEPTION_IF_NULL(hccl_node); | |||||
| if (hccl_node->size() != 2) { | |||||
| MS_LOG(INFO) << "node[" + AnfAlgo::GetCNodeName(hccl_node) + "]'s inputs size not equal 2"; | |||||
| return; | |||||
| } | |||||
| auto input = hccl_node->input(1); | |||||
| if (NeedInsertMemcpy(graph, input)) { | |||||
| auto memcpy_async = CreateMemcpyAsyncOp(graph, input); | |||||
| CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node); | |||||
| new_hccl_node->set_inputs({hccl_node->input(0), memcpy_async}); | |||||
| auto manager = graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; | |||||
| (void)manager->Replace(hccl_node, new_hccl_node); | |||||
| MS_LOG(DEBUG) << "end replace"; | |||||
| // transer hccl op's control to the memcpy_async | |||||
| TransferControl(new_hccl_node, memcpy_async, graph); | |||||
| } | |||||
| } | |||||
| const AnfNodePtr InsertMemcpyAsyncForHcclOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (func_graph == nullptr || node == nullptr || !node->isa<CNode>()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (!AnfAlgo::IsCommunicationOp(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| InsertMemcpyAsync(func_graph, cnode); | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -13,19 +13,28 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||||
| #include <memory> | #include <memory> | ||||
| #include "pre_activate/common/optimizer.h" | #include "pre_activate/common/optimizer.h" | ||||
| #include "pre_activate/ascend/ascend_helper.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class AddMemcpyAsync : public PatternProcessPass { | |||||
| class InsertMemcpyAsyncForHcclOp : public PatternProcessPass { | |||||
| public: | public: | ||||
| explicit AddMemcpyAsync(bool multigraph = true) : PatternProcessPass("add_memcpy_async", multigraph) {} | |||||
| ~AddMemcpyAsync() override = default; | |||||
| explicit InsertMemcpyAsyncForHcclOp(bool multigraph = true) | |||||
| : PatternProcessPass("insert_memcpy_async_for_hccl_op", multigraph), | |||||
| kernel_query_(std::make_shared<KernelQuery>()) {} | |||||
| ~InsertMemcpyAsyncForHcclOp() override = default; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | |||||
| void InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const; | |||||
| bool NeedInsertMemcpy(const FuncGraphPtr &graph, const AnfNodePtr &input) const; | |||||
| KernelQueryPtr kernel_query_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_ADD_MEMCPY_ASYNC_H_ | |||||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_MEMCPY_ASYNC_FOR_HCCL_OP_H_ | |||||
| @@ -56,7 +56,7 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s | |||||
| bn_outputs->push_back(output); | bn_outputs->push_back(output); | ||||
| output_num++; | output_num++; | ||||
| } | } | ||||
| return output_num > kBatchNormLeastOutputNum; | |||||
| return output_num >= kBatchNormLeastOutputNum; | |||||
| } | } | ||||
| AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { | AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &bn) { | ||||
| @@ -1,58 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "utils/utils.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TestHWAddMemcpyAsync : public BackendCommon { | |||||
| public: | |||||
| TestHWAddMemcpyAsync() : get_py_fun_("gtest_input.pre_activate.add_memcpy_async", true) {} | |||||
| public: | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | |||||
| }; | |||||
| TEST_F(TestHWAddMemcpyAsync, test_add_memcpy_async) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "before"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto func_graph = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(func_graph, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::AddMemcpyAsync>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(func_graph); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_add_memcpy_async", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,165 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/backend_common_test.h" | |||||
| #include "common/py_func_graph_fetcher.h" | |||||
| #include "session/anf_runtime_algorithm.h" | |||||
| #include "operator/ops.h" | |||||
| #include "ir/tensor.h" | |||||
| #include "debug/anf_ir_dump.h" | |||||
| #include "utils/utils.h" | |||||
| #include "kernel/kernel_build_info.h" | |||||
| #include "pre_activate/common/optimizer.h" | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "pre_activate/ascend/enhancer/insert_memcpy_async_for_hccl_op.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class TestHWInsertMemcpyForHccl : public BackendCommon { | |||||
| public: | |||||
| TestHWInsertMemcpyForHccl() : get_py_fun_("gtest_input.pre_activate.insert_memcpy_async_for_hccl_op", true) {} | |||||
| ~TestHWInsertMemcpyForHccl() override = default; | |||||
| public: | |||||
| UT::PyFuncGraphFetcher get_py_fun_; | |||||
| }; | |||||
| class MockInsertMemcpyForHcclKernelQuery : public KernelQuery { | |||||
| public: | |||||
| MockInsertMemcpyForHcclKernelQuery() = default; | |||||
| ~MockInsertMemcpyForHcclKernelQuery() override = default; | |||||
| bool IsTbeRef(const AnfNodePtr &node) override { | |||||
| MS_EXCEPTION_IF_NULL(node); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto name = AnfAlgo::GetCNodeName(cnode); | |||||
| return name == "ApplyMomentum"; | |||||
| } | |||||
| }; | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before1"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond1_no_insert) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond1", "before2"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto origin_graph = std::make_shared<session::KernelGraph>(*kg); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| EXPECT_TRUE(CheckEqualGraph(origin_graph, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "before"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond2", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond3) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "before"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond3", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { | |||||
| get_py_fun_.SetDoResolve(true); | |||||
| FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "before"); | |||||
| ASSERT_TRUE(g != nullptr); | |||||
| std::vector<int> shp_x{1, 64, 112, 112}; | |||||
| auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x); | |||||
| AbstractBasePtrList args_spec_list{x_abstract, x_abstract, x_abstract, x_abstract, x_abstract}; | |||||
| auto kg = GetKernelGraph(g, args_spec_list); | |||||
| EXPECT_NE(kg, nullptr); | |||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||||
| auto pm = std::make_shared<opt::PassManager>(); | |||||
| auto pass = std::make_shared<opt::InsertMemcpyAsyncForHcclOp>(); | |||||
| pass->kernel_query_ = std::make_shared<MockInsertMemcpyForHcclKernelQuery>(); | |||||
| pm->AddPass(pass); | |||||
| optimizer->AddPassManager(pm); | |||||
| auto new_graph = optimizer->Optimize(kg); | |||||
| FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_insert_memcpy_async_for_hccl_op_cond4", "after"); | |||||
| EXPECT_TRUE(CheckEqualGraph(g_after, new_graph)); | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,50 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.ops import operations as P | |||||
| all_reduce = P.AllReduce() | |||||
| memcpy_async = Primitive('memcpy_async') | |||||
| make_tuple = Primitive('make_tuple') | |||||
| tuple_getitem = Primitive('tuple_getitem') | |||||
class FnDict:
    """Registry of functions keyed by their ``__name__``; instances are used as decorators."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        # Record the function under its own name. Nothing is returned, so a name
        # decorated with an instance is rebound to None; fetch it via indexing.
        self.fnDict.update({fn.__name__: fn})

    def __getitem__(self, name):
        # Raises KeyError for a name that was never registered.
        return self.fnDict[name]
def test_add_memcpy_async(tag):
    """Return the graph-building function registered under ``tag`` ("before" or "after")."""
    fns = FnDict()

    @fns
    def before(x):
        reduced = all_reduce(x)
        return make_tuple(x, reduced)

    @fns
    def after(x):
        copied = memcpy_async(x)
        reduced = all_reduce(copied)
        return make_tuple(make_tuple(x, reduced))

    return fns[tag]
| @@ -0,0 +1,120 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| from mindspore.ops import Primitive | |||||
| from mindspore.ops import operations as P | |||||
| all_reduce = P.AllReduce() | |||||
| memcpy_async = Primitive('memcpy_async') | |||||
| make_tuple = Primitive('make_tuple') | |||||
| tuple_getitem = Primitive('tuple_getitem') | |||||
| apply_momentun = P.ApplyMomentum() | |||||
| control_depend = P.ControlDepend() | |||||
| relu = P.ReLU() | |||||
class FnDict:
    """Registry of functions keyed by their ``__name__``; instances are used as decorators."""

    def __init__(self):
        self.fnDict = dict()

    def __call__(self, fn):
        # Record the function under its own name. Nothing is returned, so a name
        # decorated with an instance is rebound to None; fetch it via indexing.
        self.fnDict.update({fn.__name__: fn})

    def __getitem__(self, name):
        # Raises KeyError for a name that was never registered.
        return self.fnDict[name]
def test_insert_memcpy_async_for_hccl_op_cond1(tag):
    """Return the graph builder registered under ``tag``: "before1", "before2" or "after".

    "before1" returns the ReLU output alongside the AllReduce result, while
    "before2" returns only the AllReduce result.
    """
    fns = FnDict()

    @fns
    def before1(x):
        activated = relu(x)
        reduced = all_reduce(activated)
        return make_tuple(activated, reduced)

    @fns
    def before2(x):
        activated = relu(x)
        reduced = all_reduce(activated)
        return reduced

    @fns
    def after(x):
        activated = relu(x)
        copied = memcpy_async(activated)
        reduced = all_reduce(copied)
        return make_tuple(make_tuple(activated, reduced))

    return fns[tag]
def test_insert_memcpy_async_for_hccl_op_cond2(tag):
    """Return the graph builder registered under ``tag`` ("before" or "after").

    Here AllReduce is applied directly to the graph input ``x``.
    """
    fns = FnDict()

    @fns
    def before(x):
        reduced = all_reduce(x)
        return reduced

    @fns
    def after(x):
        copied = memcpy_async(x)
        reduced = all_reduce(copied)
        return make_tuple(reduced)

    return fns[tag]
def test_insert_memcpy_async_for_hccl_op_cond3(tag):
    """Return the graph builder registered under ``tag`` ("before" or "after").

    Here AllReduce consumes the output of ApplyMomentum.
    """
    fns = FnDict()

    @fns
    def before(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        reduced = all_reduce(updated)
        return reduced

    @fns
    def after(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        copied = memcpy_async(updated)
        reduced = all_reduce(copied)
        return make_tuple(reduced)

    return fns[tag]
def test_insert_memcpy_async_for_hccl_op_cond4(tag):
    """Return the graph builder registered under ``tag`` ("before" or "after").

    In "before" the ControlDepend links ApplyMomentum to the AllReduce output;
    in "after" it links ApplyMomentum to the memcpy_async output instead.
    """
    fns = FnDict()

    @fns
    def before(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        reduced = all_reduce(a)
        ordering = control_depend(updated, reduced)
        outputs = make_tuple(ordering, reduced)
        return outputs

    @fns
    def after(a, b, c, d, e):
        updated = apply_momentun(a, b, c, d, e)
        copied = memcpy_async(a)
        reduced = all_reduce(copied)
        ordering = control_depend(updated, copied)
        outputs = make_tuple(ordering, reduced)
        return make_tuple(outputs)

    return fns[tag]