From 0c1e391556a739a928bee483e5f08211db699709 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Mon, 13 Jul 2020 19:15:48 +0800 Subject: [PATCH] add atomic clean op for every communication op's input --- .../backend/optimizer/mem_reuse/mem_reuse.h | 2 +- .../device/ascend/kernel_build_ascend.cc | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h index ad884f44b4..14c639fecc 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.h @@ -49,7 +49,7 @@ class MemReuseUtil { } MS_LOG(INFO) << "Total Dynamic Memory Size: " << total_dy_size_; MS_LOG(INFO) << "Total WorkSpace Memory Size: " << total_workspace_size_; - MS_LOG(INFO) << "Total Reused WorkSpafce Memory Size: " << total_reuseworkspace_size_; + MS_LOG(INFO) << "Total Reused WorkSpace Memory Size: " << total_reuseworkspace_size_; } void SetAllInfo(const KernelGraph *graph); diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc index 468879451f..d5b76edcf0 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -205,6 +205,10 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) { } // process output std::vector output_indexs = {}; + if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, kernel_node)) { + output_indexs = AnfAlgo::GetNodeAttr>(kernel_node, kAttrAtomicOutputIndexs); + } + for (size_t i = 0; i < output_num; ++i) { auto param_output = parameters_indexs.at(input_num + workspace_num + i); if (param_output == 1) { @@ -212,7 +216,10 @@ static bool IsAtomicNode(const CNodePtr &kernel_node) { MS_LOG(INFO) << "Atomic clear output index: " << i; } } + if (!output_indexs.empty()) { + std::set s(output_indexs.begin(), output_indexs.end()); + output_indexs.assign(s.begin(), s.end()); AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node); } // process workspace @@ -244,11 +251,49 @@ bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { return ret; } +std::map> GetCommunicationOpInputInfo( + const mindspore::session::KernelGraph *kernel_graph) { + std::map> comm_input_info_map; + for (auto &kernel : kernel_graph->execution_order()) { + auto input_num = AnfAlgo::GetInputTensorNum(kernel); + if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) { + for (size_t i = 0; i < input_num; i++) { + auto input_node = kernel->input(i + 1); + auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); + MS_LOG(INFO) << " Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope() + << " input_node: " << kernel_input.first->fullname_with_scope() + << " index: " << kernel_input.second; + auto iter = comm_input_info_map.find(kernel_input.first); + if (iter != comm_input_info_map.end()) { + iter->second.push_back(kernel_input.second); + } else { + std::vector indexes = {kernel_input.second}; + comm_input_info_map[kernel_input.first] = indexes; + } + } + } + } + + // remove duplicate index + for (auto &info : comm_input_info_map) { + std::set s(info.second.begin(), info.second.end()); + info.second.assign(s.begin(), s.end()); + } + + return comm_input_info_map; +} + void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); std::vector new_nodes; + std::map> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); for (const auto &anf_node : kernel_graph->execution_order()) { std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); + if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { + auto indexes = comm_input_info_map[anf_node]; + AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); + } + if (apply_function_name == prim::kPrimMaxPoolGrad->name() && AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { auto clear_zero_prim = std::make_shared(kClearZeroOpName);