Browse Source

!854 reorder getnext firstly for getnex parallel

Merge pull request !854 from laiyongqiang/reorder_getnext
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 5 years ago
parent
commit
bd6ece2900
2 changed files with 20 additions and 0 deletions
  1. +19
    -0
      mindspore/ccsrc/device/kernel_adjust.cc
  2. +1
    -0
      mindspore/ccsrc/device/kernel_adjust.h

+ 19
- 0
mindspore/ccsrc/device/kernel_adjust.cc View File

@@ -55,6 +55,24 @@ void KernelAdjust::Reorder(const std::shared_ptr<session::KernelGraph> &kernel_g
kernel_graph_ptr->set_execution_order(new_order_list);
}

void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order();
std::vector<CNodePtr> getnext_list;
std::vector<CNodePtr> other_list;
for (const auto &cnode : origin_cnode_list) {
if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
getnext_list.emplace_back(cnode);
} else {
other_list.emplace_back(cnode);
}
}
std::vector<CNodePtr> new_order_list;
new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end());
new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end());
kernel_graph_ptr->set_execution_order(new_order_list);
}

bool KernelAdjust::NeedInsertSwitch() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@@ -124,6 +142,7 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
return;
}
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
ReorderGetNext(kernel_graph_ptr);
std::map<std::string, mindspore::ParameterPtr> switch_loop_input;
CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input);



+ 1
- 0
mindspore/ccsrc/device/kernel_adjust.h View File

@@ -63,6 +63,7 @@ class KernelAdjust {
KernelAdjust() = default;
~KernelAdjust() = default;

void ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);


Loading…
Cancel
Save