diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index e8f38aa339..596cf6790d 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -55,6 +55,24 @@ void KernelAdjust::Reorder(const std::shared_ptr &kernel_g kernel_graph_ptr->set_execution_order(new_order_list); } +void KernelAdjust::ReorderGetNext(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + const std::vector &origin_cnode_list = kernel_graph_ptr->execution_order(); + std::vector getnext_list; + std::vector other_list; + for (const auto &cnode : origin_cnode_list) { + if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { + getnext_list.emplace_back(cnode); + } else { + other_list.emplace_back(cnode); + } + } + std::vector new_order_list; + new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); + new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); + kernel_graph_ptr->set_execution_order(new_order_list); +} + bool KernelAdjust::NeedInsertSwitch() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -124,6 +142,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr return; } MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + ReorderGetNext(kernel_graph_ptr); std::map switch_loop_input; CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index 3dced257c1..4c69641a34 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -63,6 +63,7 @@ class KernelAdjust { KernelAdjust() = default; ~KernelAdjust() = default; + void ReorderGetNext(const std::shared_ptr &kernel_graph_ptr); CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr);