|
|
|
@@ -263,13 +263,17 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { |
|
|
|
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); |
|
|
|
for (const auto &anf_node : kernel_graph->execution_order()) { |
|
|
|
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); |
|
|
|
bool is_comm_input = false; |
|
|
|
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { |
|
|
|
auto indexes = comm_input_info_map[anf_node]; |
|
|
|
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); |
|
|
|
is_comm_input = true; |
|
|
|
} |
|
|
|
|
|
|
|
if (apply_function_name == prim::kPrimMaxPoolGrad->name() && |
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { |
|
|
|
if (is_comm_input) { |
|
|
|
AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); |
|
|
|
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && |
|
|
|
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { |
|
|
|
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); |
|
|
|
MS_EXCEPTION_IF_NULL(clear_zero_prim); |
|
|
|
auto new_value_node = NewValueNode(clear_zero_prim); |
|
|
|
|