From: @yangzhenzhang Reviewed-by: @kisnwang,@stsuteng Signed-off-by: @stsuteng tags/v1.1.0
| @@ -197,8 +197,8 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node) { | |||||
| return std::make_pair(shape_list, dtype); | return std::make_pair(shape_list, dtype); | ||||
| } | } | ||||
| SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, | |||||
| const int &user_node_stage, const int &node_stage) { | |||||
| SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, | |||||
| int node_stage) { | |||||
| Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); | Attr attr_tag = std::make_pair("sr_tag", MakeValue(send_tag)); | ||||
| send_tag += 1; | send_tag += 1; | ||||
| auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | ||||
| @@ -221,7 +221,7 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod | |||||
| } | } | ||||
| void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, | void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, | ||||
| const int &index, const int &user_node_stage, const int &node_stage) { | |||||
| int index, int user_node_stage, int node_stage) { | |||||
| Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); | Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); | ||||
| recv_tag += 1; | recv_tag += 1; | ||||
| auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; | ||||
| @@ -369,9 +369,9 @@ void PipelineTransformer::ElimGraphStage() { | |||||
| } | } | ||||
| bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | bool PipelineTransformer::IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { | ||||
| ValueNodePtr anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| auto anf_node = cnode->input(0)->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | MS_EXCEPTION_IF_NULL(anf_node); | ||||
| PrimitivePtr prim = anf_node->value()->cast<PrimitivePtr>(); | |||||
| auto prim = anf_node->value()->cast<PrimitivePtr>(); | |||||
| return (prim->name() == name); | return (prim->name() == name); | ||||
| } | } | ||||
| @@ -34,13 +34,14 @@ typedef struct { | |||||
| class PipelineTransformer { | class PipelineTransformer { | ||||
| public: | public: | ||||
| PipelineTransformer(const FuncGraphManagerPtr &manager, const int &stage, const FuncGraphPtr &root, | |||||
| const int64_t &global_rank, const int64_t &per_stage_rank_num) | |||||
| PipelineTransformer(const FuncGraphManagerPtr &manager, int stage, const FuncGraphPtr &root, int64_t global_rank, | |||||
| int64_t per_stage_rank_num) | |||||
| : manager_(manager), | : manager_(manager), | ||||
| stage_(stage), | stage_(stage), | ||||
| root_(root), | root_(root), | ||||
| global_rank_(global_rank), | global_rank_(global_rank), | ||||
| per_stage_rank_num_(per_stage_rank_num) {} | per_stage_rank_num_(per_stage_rank_num) {} | ||||
| virtual ~PipelineTransformer() = default; | |||||
| void Coloring(); | void Coloring(); | ||||
| void BroadCastColoring(); | void BroadCastColoring(); | ||||
| void HandleSharedParameter(); | void HandleSharedParameter(); | ||||
| @@ -54,10 +55,9 @@ class PipelineTransformer { | |||||
| std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); | std::pair<bool, int> IsSharedNode(const AnfNodePtr &node, const AnfNodeIndexSet &node_users); | ||||
| bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name); | ||||
| void DoBroadCast(const FuncGraphPtr &func); | void DoBroadCast(const FuncGraphPtr &func); | ||||
| SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage, | |||||
| const int &node_stage); | |||||
| void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, const int &index, | |||||
| const int &user_node_stage, const int &node_stage); | |||||
| SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, int user_node_stage, int node_stage); | |||||
| void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, | |||||
| int user_node_stage, int node_stage); | |||||
| void CutBorder(const FuncGraphPtr &graph); | void CutBorder(const FuncGraphPtr &graph); | ||||
| void ElimRootParameter(); | void ElimRootParameter(); | ||||
| bool IsStageNode(const CNodePtr &node); | bool IsStageNode(const CNodePtr &node); | ||||
| @@ -24,9 +24,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace pipeline { | namespace pipeline { | ||||
| static int64_t GetRank(); | |||||
| static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num); | |||||
| static int64_t GetRank() { | static int64_t GetRank() { | ||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| @@ -40,7 +37,7 @@ static int64_t GetRank() { | |||||
| MS_LOG(EXCEPTION) << "Invalid backend: " << backend; | MS_LOG(EXCEPTION) << "Invalid backend: " << backend; | ||||
| } | } | ||||
| int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); | int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank(); | ||||
| uint32_t rank_id; | |||||
| uint32_t rank_id = 0; | |||||
| if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { | if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) { | ||||
| if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { | if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) { | ||||
| MS_LOG(EXCEPTION) << "Get rank id failed."; | MS_LOG(EXCEPTION) << "Get rank id failed."; | ||||
| @@ -50,7 +47,7 @@ static int64_t GetRank() { | |||||
| return global_rank; | return global_rank; | ||||
| } | } | ||||
| static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num) { | |||||
| static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num) { | |||||
| if (device_num % stage_num != 0) { | if (device_num % stage_num != 0) { | ||||
| MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num | MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num | ||||
| << "stage_num: " << stage_num; | << "stage_num: " << stage_num; | ||||
| @@ -75,6 +72,12 @@ bool PipelineSplit(const ResourcePtr &res) { | |||||
| auto root = res->func_graph(); | auto root = res->func_graph(); | ||||
| auto global_rank = GetRank(); | auto global_rank = GetRank(); | ||||
| auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | ||||
| if (device_num < 1) { | |||||
| MS_LOG(EXCEPTION) << "Invalid device num: " << device_num; | |||||
| } | |||||
| if (global_rank < 0) { | |||||
| MS_LOG(EXCEPTION) << "Invalid global rank: " << global_rank; | |||||
| } | |||||
| auto stage = InferStage(global_rank, stage_num, device_num); | auto stage = InferStage(global_rank, stage_num, device_num); | ||||
| auto per_stage_rank_num = device_num / stage_num; | auto per_stage_rank_num = device_num / stage_num; | ||||
| auto transformer = | auto transformer = | ||||