From: @yangzhenzhang Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng tags/v1.1.0
| @@ -197,8 +197,8 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node) { | |||
| return std::make_pair(shape_list, dtype); | |||
| } | |||
| SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, | |||
| const int &user_node_stage, const int &node_stage) { | |||
| SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, | |||
| int node_stage) { | |||
| Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); | |||
| send_tag += 1; | |||
| auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | |||
| @@ -221,7 +221,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod | |||
| } | |||
| void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, | |||
| const int &index, const int &user_node_stage, const int &node_stage) { | |||
| int index, int user_node_stage, int node_stage) { | |||
| Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); | |||
| recv_tag += 1; | |||
| auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | |||
| @@ -369,9 +369,9 @@ void PipelineTransformer::ElimGraphStage() { | |||
| } | |||
| bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | |||
| ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| auto anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(anf_node); | |||
| PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>(); | |||
| auto prim = anf_node->value()->cast<PrimitivePtr>(); | |||
| return (prim->name() == name); | |||
| } | |||
| @@ -34,13 +34,14 @@ typedef struct { | |||
| class PipelineTransformer { | |||
| public: | |||
| PipelineTransformer(const FuncGraphManagerPtr &manager, const int &stage, const FuncGraphPtr &root, | |||
| const int64_t &global_rank, const int64_t &per_stage_rank_num) | |||
| PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, | |||
| int64_t per_stage_rank_num) | |||
| : manager_(manager), | |||
| stage_(stage), | |||
| root_(root), | |||
| global_rank_(global_rank), | |||
| per_stage_rank_num_(per_stage_rank_num) {} | |||
| virtual ~PipelineTransformer() = default; | |||
| void Coloring(); | |||
| void BroadCastColoring(); | |||
| void HandleSharedParameter(); | |||
| @@ -54,10 +55,9 @@ class PipelineTransformer { | |||
| std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); | |||
| bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | |||
| void DoBroadCast(const FuncGraphPtr &func); | |||
| SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage, | |||
| const int &node_stage); | |||
| void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, const int &index, | |||
| const int &user_node_stage, const int &node_stage); | |||
| SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, int node_stage); | |||
| void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, | |||
| int user_node_stage, int node_stage); | |||
| void CutBorder(const FuncGraphPtr &graph); | |||
| void ElimRootParameter(); | |||
| bool IsStageNode(const CNodePtr &node); | |||
| @@ -24,9 +24,6 @@ | |||
| namespace mindspore { | |||
| namespace pipeline { | |||
| static int64_t GetRank(); | |||
| static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num); | |||
| static int64_t GetRank() { | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| @@ -40,7 +37,7 @@ static int64_t GetRank() { | |||
| MS_LOG(EXCEPTION) << "Invalid backend: " << backend; | |||
| } | |||
| int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); | |||
| uint32_t rank_id; | |||
| uint32_t rank_id = 0; | |||
| if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { | |||
| if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { | |||
| MS_LOG(EXCEPTION) << "Get rank id failed."; | |||
| @@ -50,7 +47,7 @@ static int64_t GetRank() { | |||
| return global_rank; | |||
| } | |||
| static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num) { | |||
| static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { | |||
| if (device_num % stage_num != 0) { | |||
| MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num | |||
| << "stage_num: " << stage_num; | |||
| @@ -75,6 +72,12 @@ bool PipelineSplit(const ResourcePtr &res) { | |||
| auto root = res->func_graph(); | |||
| auto global_rank = GetRank(); | |||
| auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | |||
| if (device_num < 1) { | |||
| MS_LOG(EXCEPTION) << "Invalid device num: " << device_num; | |||
| } | |||
| if (global_rank < 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank; | |||
| } | |||
| auto stage = InferStage(global_rank, stage_num, device_num); | |||
| auto per_stage_rank_num = device_num / stage_num; | |||
| auto transformer = | |||