Browse Source

!10691 [PipelineSplit]Opt PipelineSplit

From: @lichen666
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
6e5be437e1
3 changed files with 59 additions and 19 deletions
  1. +35
    -10
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc
  2. +5
    -2
      mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h
  3. +19
    -7
      mindspore/ccsrc/frontend/parallel/step_parallel.cc

+ 35
- 10
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc View File

@@ -286,7 +286,7 @@ void PipelineTransformer::HandleSharedParameter() {
manager_->SetEdge(node, user.second, depend); manager_->SetEdge(node, user.second, depend);
break; break;
} else { } else {
InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
(void)InsertReceive(graph, parameter, node, user.second, stage_, *parameter_stage.begin());
break; break;
} }
} }
@@ -403,8 +403,9 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
return send_out; return send_out;
} }


void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node,
int index, int64_t user_node_stage, int64_t node_stage) {
AnfNodePtr PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node,
const AnfNodePtr &use_node, int index, int64_t user_node_stage,
int64_t node_stage) {
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
int64_t recv_tag; int64_t recv_tag;
if (recv_tag_map.find(src_rank) != recv_tag_map.end()) { if (recv_tag_map.find(src_rank) != recv_tag_map.end()) {
@@ -464,6 +465,28 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
recv->set_user_data<OperatorInfo>(op_info_pair.first); recv->set_user_data<OperatorInfo>(op_info_pair.first);
} }
manager_->SetEdge(use_node, index, recv); manager_->SetEdge(use_node, index, recv);
return recv;
}

// Reports whether `out_input` already holds a Depend for `node` whose attached
// operator targets the rank this node would be sent to.
//
// The target rank is global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_.
// Each Depend CNode in `out_input` whose input(1) is `node` is inspected: the
// primitive of its input(2) (presumably the Send inserted earlier -- confirm
// against InsertSend) must carry a "dest_rank" attribute equal to that rank.
//
// node:            node whose output would be sent.
// next_node_stage: stage of the consumer of `node`.
// node_stage:      stage of `node` itself.
// out_input:       candidate nodes; entries that are not Depend CNodes are skipped.
// Returns true when a matching Depend is found, false otherwise.
bool PipelineTransformer::Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
                                const std::vector<AnfNodePtr> &out_input) {
  const auto dest_rank = global_rank_ + (next_node_stage - node_stage) * per_stage_rank_num_;
  for (const auto &depend : out_input) {
    if (!IsPrimitiveCNode(depend, prim::kPrimDepend)) {
      continue;
    }
    const auto cnode = depend->cast<CNodePtr>();
    if (cnode->input(1) != node) {
      continue;
    }
    // A Depend whose second input is not a primitive CNode with a "dest_rank"
    // attribute cannot be reused; skip it instead of dereferencing a null pointer.
    const auto send_cnode = cnode->input(2)->cast<CNodePtr>();
    if (send_cnode == nullptr) {
      continue;
    }
    const auto prim = GetValueNode<PrimitivePtr>(send_cnode->input(0));
    if (prim == nullptr) {
      continue;
    }
    const auto dest_rank_attr = prim->GetAttr("dest_rank");
    if (dest_rank_attr == nullptr) {
      continue;
    }
    if (GetValue<int64_t>(dest_rank_attr) == dest_rank) {
      return true;
    }
  }
  return false;
}


std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) { std::pair<bool, int64_t> PipelineTransformer::IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users) {
@@ -496,6 +519,7 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
auto shared_min_tag_pair = IsSharedNode(node, node_users); auto shared_min_tag_pair = IsSharedNode(node, node_users);
auto is_shared = shared_min_tag_pair.first; auto is_shared = shared_min_tag_pair.first;
auto min_tag = shared_min_tag_pair.second; auto min_tag = shared_min_tag_pair.second;
AnfNodePtr receive = nullptr;
for (auto &user_pair : node_users) { for (auto &user_pair : node_users) {
auto user_node = user_pair.first; auto user_node = user_pair.first;
auto node_stage = node->stage(); auto node_stage = node->stage();
@@ -508,18 +532,19 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
continue; continue;
} }
if (node_stage == stage_) { if (node_stage == stage_) {
if (Reuse(node, user_node_stage, node_stage, out_input)) {
continue;
}
auto send_out = InsertSend(graph, node, user_node_stage, node_stage); auto send_out = InsertSend(graph, node, user_node_stage, node_stage);
out_input.insert(out_input.begin() + 1, send_out.depend); out_input.insert(out_input.begin() + 1, send_out.depend);
type_ptr_ = send_out.type; type_ptr_ = send_out.type;
shape_ = send_out.shape; shape_ = send_out.shape;
} else { } else {
InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
}
continue;
}
if (node_stage == user_node_stage) {
if (is_shared && (min_tag != node_stage)) {
InsertReceive(graph, node, user_node, user_pair.second, stage_, min_tag);
if (!receive) {
receive = InsertReceive(graph, node, user_node, user_pair.second, user_node_stage, node_stage);
} else {
manager_->SetEdge(user_node, user_pair.second, receive);
}
} }
continue; continue;
} }


+ 5
- 2
mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h View File

@@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <memory> #include <memory>
#include <vector>
#include "ir/value.h" #include "ir/value.h"
#include "ir/graph_utils.h" #include "ir/graph_utils.h"
#include "base/base.h" #include "base/base.h"
@@ -62,11 +63,13 @@ class PipelineTransformer {
void DoBroadCast(const FuncGraphPtr &func); void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage, SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int64_t user_node_stage,
int64_t node_stage); int64_t node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage);
AnfNodePtr InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int64_t user_node_stage, int64_t node_stage);
void SetNoStageNode(const FuncGraphPtr &func); void SetNoStageNode(const FuncGraphPtr &func);
void CutBorder(const FuncGraphPtr &graph); void CutBorder(const FuncGraphPtr &graph);
bool IsStageNode(const CNodePtr &node); bool IsStageNode(const CNodePtr &node);
bool Reuse(const AnfNodePtr &node, int64_t next_node_stage, int64_t node_stage,
const std::vector<AnfNodePtr> &out_input);
AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node); std::pair<OperatorInfoPtr, TensorInfoPtr> GetParameterPair(const AnfNodePtr &node);


+ 19
- 7
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -2675,6 +2675,20 @@ void InsertShapeOp(const CNodePtr &node, const AnfNodePtr &pre_node, const FuncG
InsertNode(op, node, 2, pre_node, root, "shape"); InsertNode(op, node, 2, pre_node, root, "shape");
} }


// Follows the chain of first CNode inputs starting at `cnode` until an
// EnvGetItem CNode is reached.
//
// At every step only the first input that is a CNode is considered: if it is
// an EnvGetItem it is returned, otherwise the walk continues from that input.
// Returns nullptr as soon as the current node has no CNode input at all.
static AnfNodePtr FindGrad(const CNodePtr &cnode) {
  CNodePtr current = cnode;
  while (true) {
    // Pick the first input that is itself a CNode; later inputs are never visited.
    AnfNodePtr first_cnode_input = nullptr;
    for (auto &input : current->inputs()) {
      if (input->isa<CNode>()) {
        first_cnode_input = input;
        break;
      }
    }
    if (first_cnode_input == nullptr) {
      return nullptr;
    }
    if (IsPrimitiveCNode(first_cnode_input, prim::kPrimEnvGetItem)) {
      return first_cnode_input;
    }
    current = first_cnode_input->cast<CNodePtr>();
  }
}

void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) { void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes) {
// If root graph has reshape op. Find the corresponding parameter. // If root graph has reshape op. Find the corresponding parameter.
// Reshape's shape is the shape of the parameter. // Reshape's shape is the shape of the parameter.
@@ -2706,12 +2720,9 @@ void HandleRootReshapeAndSaveStrategy(const std::vector<AnfNodePtr> &all_nodes)
continue; continue;
} }
auto root = node->func_graph(); auto root = node->func_graph();
auto all_dfs_nodes = DeepLinkedGraphSearch(node);
for (auto r_iter = all_dfs_nodes.rbegin(); r_iter != all_dfs_nodes.rend(); ++r_iter) {
if ((*r_iter)->isa<Parameter>()) {
InsertShapeOp(cnode, *r_iter, root);
break;
}
auto grad_node = FindGrad(cnode);
if (grad_node) {
InsertShapeOp(cnode, grad_node, root);
} }
} }
} }
@@ -3113,7 +3124,8 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
std::reverse(all_nodes.begin(), all_nodes.end()); std::reverse(all_nodes.begin(), all_nodes.end());
if (parallel_mode != AUTO_PARALLEL) { if (parallel_mode != AUTO_PARALLEL) {
TOTAL_OPS = 0; TOTAL_OPS = 0;
if (ParallelInit() != SUCCESS) {
auto pipeline_stages = ParallelContext::GetInstance()->pipeline_stage_split_num();
if (pipeline_stages <= 1 && ParallelInit() != SUCCESS) {
MS_LOG(EXCEPTION) << "Parallel init failed"; MS_LOG(EXCEPTION) << "Parallel init failed";
} }




Loading…
Cancel
Save