From c0f4a2a28c2812749cdcb640e0e6f26930e0cd5c Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Tue, 1 Dec 2020 19:39:46 +0800 Subject: [PATCH] review pipeline parallel --- .../pipeline_transformer/pipeline_transformer.cc | 10 +++++----- .../pipeline_transformer/pipeline_transformer.h | 12 ++++++------ mindspore/ccsrc/pipeline/jit/pipeline_split.cc | 13 ++++++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 7301a1503d..3ffbabcde4 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -197,8 +197,8 @@ static std::pair GetShapeType(const AnfNodePtr &node) { return std::make_pair(shape_list, dtype); } -SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, - const int &user_node_stage, const int &node_stage) { +SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int user_node_stage, + int node_stage) { Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); send_tag += 1; auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; @@ -221,7 +221,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod } void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, - const int &index, const int &user_node_stage, const int &node_stage) { + int index, int user_node_stage, int node_stage) { Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); recv_tag += 1; auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; @@ -369,9 +369,9 @@ void PipelineTransformer::ElimGraphStage() { } bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { - ValueNodePtr anf_node = cnode->input(0)->cast(); + auto anf_node = cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(anf_node); - PrimitivePtr prim = anf_node->value()->cast(); + auto prim = anf_node->value()->cast(); return (prim->name() == name); } diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h index a9034c8a5a..cdfaf040e9 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h @@ -34,13 +34,14 @@ typedef struct { class PipelineTransformer { public: - PipelineTransformer(const FuncGraphManagerPtr &manager, const int &stage, const FuncGraphPtr &root, - const int64_t &global_rank, const int64_t &per_stage_rank_num) + PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, + int64_t per_stage_rank_num) : manager_(manager), stage_(stage), root_(root), global_rank_(global_rank), per_stage_rank_num_(per_stage_rank_num) {} + virtual ~PipelineTransformer() = default; void Coloring(); void BroadCastColoring(); void HandleSharedParameter(); @@ -54,10 +55,9 @@ class PipelineTransformer { std::pair IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); void DoBroadCast(const FuncGraphPtr &func); - SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, const int &user_node_stage, - const int &node_stage); - void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, const int &index, - const int &user_node_stage, const int &node_stage); + SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr ¶meter, int user_node_stage, int node_stage); + void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, + int user_node_stage, int node_stage); void CutBorder(const FuncGraphPtr &graph); void ElimRootParameter(); bool IsStageNode(const CNodePtr &node); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc index fc7f9b90ad..594c277a0e 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc @@ -24,9 +24,6 @@ namespace mindspore { namespace pipeline { - -static int64_t GetRank(); -static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num); static int64_t GetRank() { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); @@ -40,7 +37,7 @@ static int64_t GetRank() { MS_LOG(EXCEPTION) << "Invalid backend: " << backend; } int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); - uint32_t rank_id; + uint32_t rank_id = 0; if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { MS_LOG(EXCEPTION) << "Get rank id failed."; @@ -50,7 +47,7 @@ static int64_t GetRank() { return global_rank; } -static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num) { +static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { if (device_num % stage_num != 0) { MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num << "stage_num: " << stage_num; @@ -75,6 +72,12 @@ bool PipelineSplit(const ResourcePtr &res) { auto root = res->func_graph(); auto global_rank = GetRank(); auto device_num = parallel::ParallelContext::GetInstance()->device_num(); + if (device_num < 1) { + MS_LOG(EXCEPTION) << "Invalid device num: " << device_num; + } + if (global_rank < 0) { + MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank; + } auto stage = InferStage(global_rank, stage_num, device_num); auto per_stage_rank_num = device_num / stage_num; auto transformer =