|
|
|
@@ -55,6 +55,24 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g |
|
|
|
kernel_graph_ptr->set_execution_order(new_order_list); |
|
|
|
} |
|
|
|
|
|
|
|
void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); |
|
|
|
const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order(); |
|
|
|
std::vector<CNodePtr> getnext_list; |
|
|
|
std::vector<CNodePtr> other_list; |
|
|
|
for (const auto &cnode : origin_cnode_list) { |
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) { |
|
|
|
getnext_list.emplace_back(cnode); |
|
|
|
} else { |
|
|
|
other_list.emplace_back(cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
std::vector<CNodePtr> new_order_list; |
|
|
|
new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end()); |
|
|
|
new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end()); |
|
|
|
kernel_graph_ptr->set_execution_order(new_order_list); |
|
|
|
} |
|
|
|
|
|
|
|
bool KernelAdjust::NeedInsertSwitch() { |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
@@ -124,6 +142,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); |
|
|
|
ReorderGetNext(kernel_graph_ptr); |
|
|
|
std::map<std::string, mindspore::ParameterPtr> switch_loop_input; |
|
|
|
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input); |
|
|
|
|
|
|
|
|