From 39306d64fbea61fd3faf7d88847558afc1d84a2c Mon Sep 17 00:00:00 2001
From: lichenever
Date: Mon, 28 Dec 2020 15:09:19 +0800
Subject: [PATCH] opt_pipeline_split
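
Reuse border Send/Receive nodes instead of inserting duplicates when
cutting stage borders. InsertReceive now returns the Receive node it
creates, so CutBorder can rewire further users of the same node to the
first Receive via SetEdge, and the new Reuse() check skips inserting
another Send when a Depend on the same node already carries a Send
with the same dest_rank. HandleRootReshapeAndSaveStrategy now locates
the gradient input with a targeted FindGrad() walk instead of a full
DeepLinkedGraphSearch, and StepParallel skips ParallelInit() when
pipeline_stage_split_num is greater than 1.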
---
 .../pipeline_transformer.cc                  | 45 ++++++++++++++-----
 .../pipeline_transformer.h                   |  7 ++-
 .../ccsrc/frontend/parallel/step_parallel.cc | 26 ++++++++---
 3 files changed, 59 insertions(+), 19 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
index 1382431088..7e490c63fa 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
@@ -286,7 +286,7 @@ void PipelineTransformer::HandleSharedParameter() {
         manager_->SetEdge(node, user.second, depend);
         break;
       } else {
-        InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
+        (void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
         break;
       }
     }
@@ -403,8 +403,9 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
   return send_out;
 }
 
-void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node,
-                                        int index, int64_t user_node_stage, int64_t node_stage) {
+AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                                              const AnfNodePtr &use_node, int index, int64_t user_node_stage,
+                                              int64_t node_stage) {
   auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
   int64_t recv_tag;
   if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@@ -464,6 +465,28 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
     recv->set_user_data(op_info_pair.first);
   }
   manager_->SetEdge(use_node, index, recv);
+  return recv;
+}
+
+bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
+                                const std::vector<AnfNodePtr> &out_input) {
+  auto node_users = manager_->node_users()[node];
+  auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_;
+  for (auto &depend : out_input) {
+    if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
+      continue;
+    }
+    auto cnode = depend->cast<CNodePtr>();
+    if (cnode->input(1) == node) {
+      auto send_cnode = cnode->input(2)->cast<CNodePtr>();
+      auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0));
+      auto dest_rank_send = GetValue<int64_t>(prim->GetAttr("dest_rank"));
+      if (dest_rank_send == dest_rank) {
+        return true;
+      }
+    }
+  }
+  return false;
 }
 
 std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
@@ -496,6 +519,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
     auto shared_min_tag_pair = IsSharedNode(node, node_users);
     auto is_shared = shared_min_tag_pair.first;
     auto min_tag = shared_min_tag_pair.second;
+    AnfNodePtr receive = nullptr;
     for (auto &user_pair : node_users) {
       auto user_node = user_pair.first;
       auto node_stage = node->stage();
@@ -508,18 +532,19 @@
           continue;
         }
         if (node_stage == stage_) {
+          if (Reuse(node, user_node_stage, node_stage, out_input)) {
+            continue;
+          }
           auto send_out = InsertSend(graph, node, user_node_stage, node_stage);
           out_input.insert(out_input.begin() + 1, send_out.depend);
           type_ptr_ = send_out.type;
           shape_ = send_out.shape;
         } else {
-          InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
-        }
-        continue;
-      }
-      if (node_stage == user_node_stage) {
-        if (is_shared && (min_tag != node_stage)) {
-          InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
+          if (!receive) {
+            receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
+          } else {
+            manager_->SetEdge(user_node, user_pair.second, receive);
+          }
         }
         continue;
       }
diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
index 290549eeb0..694c049b11 100644
--- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
+++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
@@ -20,6 +20,7 @@
 #include <utility>
 #include <string>
 #include <memory>
+#include <vector>
 #include "ir/value.h"
 #include "ir/graph_utils.h"
 #include "base/base.h"
@@ -62,11 +63,13 @@ class PipelineTransformer {
   void DoBroadCast(const FuncGraphPtr &func);
   SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
                       int64_t node_stage);
-  void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
-                     int64_t user_node_stage, int64_t node_stage);
+  AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
+                           int64_t user_node_stage, int64_t node_stage);
   void SetNoStageNode(const FuncGraphPtr &func);
   void CutBorder(const FuncGraphPtr &graph);
   bool IsStageNode(const CNodePtr &node);
+  bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
+             const std::vector<AnfNodePtr> &out_input);
   AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
   std::pair<OperatorInfoPtr, CNodePtr> GetOpInfo(const AnfNodePtr &node);
   std::pair<OperatorInfoPtr, CNodePtr> GetParameterPair(const AnfNodePtr &node);
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 79f90bf977..3e69ddc74d 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -2675,6 +2675,20 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
   InsertNode(op, node, 2, pre_node, root, "shape");
 }
 
+static AnfNodePtr FindGrad(const CNodePtr &cnode) {
+  for (auto &node : cnode->inputs()) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    if (!IsPrimitiveCNode(node, prim::kPrimEnvGetItem)) {
+      return FindGrad(node->cast<CNodePtr>());
+    } else {
+      return node;
+    }
+  }
+  return nullptr;
+}
+
 void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
   // If root graph has reshape op. Find the corresponding parameter.
   // Reshape's shape is the shape of the parameter.
@@ -2706,12 +2720,9 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
       continue;
     }
     auto root = node->func_graph();
-    auto all_dfs_nodes = DeepLinkedGraphSearch(node);
-    for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
-      if ((*r_iter)->isa<Parameter>()) {
-        InsertShapeOp(cnode, *r_iter, root);
-        break;
-      }
+    auto grad_node = FindGrad(cnode);
+    if (grad_node) {
+      InsertShapeOp(cnode, grad_node, root);
     }
   }
 }
@@ -3113,7 +3124,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   std::reverse(all_nodes.begin(), all_nodes.end());
   if (parallel_mode != AUTO_PARALLEL) {
     TOTAL_OPS = 0;
-    if (ParallelInit() != SUCCESS) {
+    auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
+    if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
       MS_LOG(EXCEPTION) << "Parallel init failed";
     }