| @@ -184,11 +184,17 @@ bool IsAtomicNode(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod); | |||
| auto atomic_flag = false; | |||
| std::vector<size_t> clean_output_indexs; | |||
| if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, kernel_node)) { | |||
| clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(kernel_node, kAttrAutomicOutputIndexs); | |||
| atomic_flag = true; | |||
| } | |||
| auto parameters_indexs = kernel_mod->GenParameters(); | |||
| if (parameters_indexs.empty()) { | |||
| return false; | |||
| return atomic_flag; | |||
| } | |||
| auto atomic_flag = false; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| auto workspace_size_list = kernel_mod->GetWorkspaceSizeList(); | |||
| @@ -199,7 +205,7 @@ bool IsAtomicNode(const CNodePtr &kernel_node) { | |||
| parameters_indexs.push_back(0); | |||
| } | |||
| } | |||
| std::vector<size_t> clean_output_indexs; | |||
| // in parameters data sort as input->workspace->output | |||
| size_t index = 0; | |||
| while (index < output_num) { | |||
| @@ -210,6 +216,8 @@ bool IsAtomicNode(const CNodePtr &kernel_node) { | |||
| index++; | |||
| } | |||
| if (atomic_flag) { | |||
| std::set<size_t> s(clean_output_indexs.begin(), clean_output_indexs.end()); | |||
| clean_output_indexs.assign(s.begin(), s.end()); | |||
| AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(clean_output_indexs), kernel_node); | |||
| } | |||
| for (size_t i = 0; i < workspace_num; ++i) { | |||
| @@ -238,11 +246,49 @@ bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) { | |||
| return ret; | |||
| } | |||
| std::map<AnfNodePtr, std::vector<size_t>> GetCommunicationOpInputInfo( | |||
| const mindspore::session::KernelGraph *kernel_graph) { | |||
| std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map; | |||
| for (auto &kernel : kernel_graph->execution_order()) { | |||
| auto input_num = AnfAlgo::GetInputTensorNum(kernel); | |||
| if (mindspore::session::AnfRuntimeAlgorithm::IsCommunicationOp(kernel)) { | |||
| for (size_t i = 0; i < input_num; i++) { | |||
| auto input_node = kernel->input(i + 1); | |||
| auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true); | |||
| MS_LOG(INFO) << " Add atomic clean for single communication op input, comm:" << kernel->fullname_with_scope() | |||
| << " input_node: " << kernel_input.first->fullname_with_scope() | |||
| << " index: " << kernel_input.second; | |||
| auto iter = comm_input_info_map.find(kernel_input.first); | |||
| if (iter != comm_input_info_map.end()) { | |||
| iter->second.push_back(kernel_input.second); | |||
| } else { | |||
| std::vector<size_t> indexes = {kernel_input.second}; | |||
| comm_input_info_map[kernel_input.first] = indexes; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| // remove duplicate index | |||
| for (auto &info : comm_input_info_map) { | |||
| std::set<size_t> s(info.second.begin(), info.second.end()); | |||
| info.second.assign(s.begin(), s.end()); | |||
| } | |||
| return comm_input_info_map; | |||
| } | |||
| void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<CNodePtr> new_nodes; | |||
| std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); | |||
| for (const auto &anf_node : kernel_graph->execution_order()) { | |||
| std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); | |||
| if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { | |||
| auto indexes = comm_input_info_map[anf_node]; | |||
| AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(indexes), anf_node); | |||
| } | |||
| if (apply_function_name == prim::kPrimMaxPoolGrad->name() && | |||
| AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { | |||
| auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); | |||