| @@ -236,7 +236,7 @@ RankList DeviceManager::GetDeviceListBetweenStage() const { | |||
| if (stage_num < 1) { | |||
| MS_LOG(EXCEPTION) << "Stage num got " << stage_num << ", expected a positive integer."; | |||
| } | |||
| auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | |||
| auto device_num = DeviceNum(); | |||
| auto per_stage_rank_num = device_num / stage_num; | |||
| for (int64_t i = 0; i < stage_num; ++i) { | |||
| rank_list.push_back(rank_id + per_stage_rank_num * (i - stage_id)); | |||
| @@ -179,16 +179,7 @@ void PipelineTransformer::LabelMicroBatch() { | |||
| } | |||
| void PipelineTransformer::CreateForwardGroup() { | |||
| std::vector<int64_t> rank_list; | |||
| auto rank_id = g_device_manager->global_rank(); | |||
| auto stage_id = g_device_manager->stage_id(); | |||
| auto stage_num = g_device_manager->stage_num(); | |||
| if (stage_num < 1) { | |||
| MS_LOG(EXCEPTION) << "Stage num got " << stage_num << ", expected a positive integer."; | |||
| } | |||
| for (int64_t i = 0; i < stage_num; ++i) { | |||
| rank_list.push_back(rank_id + per_stage_rank_num_ * (i - stage_id)); | |||
| } | |||
| std::vector<int64_t> rank_list = g_device_manager->GetDeviceListBetweenStage(); | |||
| auto dev_list = g_device_manager->CreateDeviceListByRankList(rank_list); | |||
| auto g = g_device_manager->CreateGroup(rank_list); | |||
| auto g_back_name = g.name() + BACKWARD; | |||
| @@ -208,20 +208,12 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An | |||
| const FuncGraphPtr &root = nullptr) { | |||
| // insert new node before the node | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| auto node_user_map = manager->node_users(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| ScopePtr scope = node->scope(); | |||
| MS_EXCEPTION_IF_NULL(scope); | |||
| std::vector<AnfNodePtr> node_input; | |||
| AnfNodePtr pre_node_ = pre_node; | |||
| if (root && !param_name.empty()) { | |||
| TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(node, node_user_map); | |||
| if (next_node_dtype) { | |||
| MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving" | |||
| << " communication."; | |||
| pre_node_ = CreateFP16Cast(node, pre_node, next_node_dtype); | |||
| } | |||
| node_input = CreateMirrorInput(root, op, pre_node_, instance_name, param_name); | |||
| node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name); | |||
| } else { | |||
| node_input = CreateInput(op, pre_node, instance_name); | |||
| } | |||
| @@ -1572,7 +1564,16 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group | |||
| allgather = ReplaceNode(op, cast_node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, root); | |||
| MS_LOG(INFO) << "Parallel optimizer is applied before Cast for " << param_name; | |||
| } else { | |||
| InsertNode(op, cnode, IntToSize(res.second), node, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, | |||
| auto pre_node = node; | |||
| AnfNodePtr pre_node_ = node; | |||
| auto node_user_map = manager->node_users(); | |||
| TypePtr next_node_dtype = FindChildCastWithFP32ToFP16(cnode, node_user_map); | |||
| if (next_node_dtype) { | |||
| MS_LOG(INFO) << "Inserting Cast from float32 to float16 for node " << node->fullname_with_scope() << " for saving" | |||
| << " communication."; | |||
| pre_node_ = CreateFP16Cast(cnode, pre_node, next_node_dtype); | |||
| } | |||
| InsertNode(op, cnode, IntToSize(res.second), pre_node_, graph, PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE, param_name, | |||
| root); | |||
| allgather = cnode->input(IntToSize(res.second))->cast<CNodePtr>(); | |||
| MS_LOG(INFO) << "Parallel optimizer is applied before " << GetPrimName(cnode) << " for " << param_name; | |||
| @@ -3188,7 +3189,7 @@ static void InsertAllReduceForNormValue(const AnfNodePtr &res_node) { | |||
| } | |||
| } | |||
| AnfNodePtr FindPrimitiveWithAtrribute(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) { | |||
| AnfNodePtr FindExpanDimsWIthGradScale(const AnfNodePtr &node_ptr, const NodeUsersMap &node_users_map, uint32_t limits) { | |||
| std::queue<AnfNodePtr> visited; | |||
| AnfNodePtr queue_node = nullptr; | |||
| CNodePtr cnode = nullptr; | |||
| @@ -3239,7 +3240,7 @@ static void InsertDivAndAllReduceForNorm(const NodeUsersMap &node_user_map, cons | |||
| if (cnode->in_forward_flag()) { | |||
| continue; | |||
| } | |||
| expand_dims_node = FindPrimitiveWithAtrribute(cnode, node_user_map, MAX_BFS_DEPTH); | |||
| expand_dims_node = FindExpanDimsWIthGradScale(cnode, node_user_map, MAX_BFS_DEPTH); | |||
| if (!expand_dims_node) { | |||
| continue; | |||
| } | |||
| @@ -124,8 +124,8 @@ uint32_t GetHcomTaskNum(const CNodePtr &cnode) { | |||
| return kTaskNumPerHcomSendRecvNode; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance()); | |||
| auto device_num = parallel::ParallelContext::GetInstance()->device_num(); | |||
| MS_EXCEPTION_IF_NULL(parallel::g_device_manager); | |||
| auto device_num = parallel::g_device_manager->DeviceNum(); | |||
| auto group_name = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup); | |||
| auto group_info = parallel::g_device_manager->group_info(); | |||
| for (const auto &info : group_info) { | |||
| @@ -117,7 +117,7 @@ class TestSharedParameterCast: | |||
| """ | |||
| auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)), | |||
| interleaved_batch=1) | |||
| self.cat_fp16_from_ir(target_count=28) | |||
| self.cat_fp16_from_ir(target_count=23) | |||
| def test_optimizer_fp16_micro_batch(self): | |||
| """ | |||
| @@ -127,7 +127,7 @@ class TestSharedParameterCast: | |||
| """ | |||
| auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)), | |||
| interleaved_batch=2) | |||
| self.cat_fp16_from_ir(target_count=42) | |||
| self.cat_fp16_from_ir(target_count=39) | |||
| def test_optimizer_fp16_pipeline(self): | |||
| """ | |||
| @@ -138,7 +138,7 @@ class TestSharedParameterCast: | |||
| auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)), | |||
| interleaved_batch=1, | |||
| stages=1, micro_size=1) | |||
| self.cat_fp16_from_ir(target_count=28) | |||
| self.cat_fp16_from_ir(target_count=23) | |||
| def test_optimizer_fp16_pipeline_micro_batch(self): | |||
| """ | |||
| @@ -149,4 +149,4 @@ class TestSharedParameterCast: | |||
| auto_parallel_compile_net("semi_auto_parallel", 8, Net, ((8, 1), (1, 1)), ((8, 1), (1, 1)), | |||
| interleaved_batch=2, | |||
| stages=1, micro_size=1) | |||
| self.cat_fp16_from_ir(target_count=42) | |||
| self.cat_fp16_from_ir(target_count=39) | |||
| @@ -419,7 +419,7 @@ class TestPipelineSplitWithNoOptimizer: | |||
| self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator', | |||
| target_count=2) | |||
| self.cat_fp16_from_ir(pattern='Cast(', | |||
| target_count=16) | |||
| target_count=14) | |||
| def test_pipeline_with_micro_batch_no_parallel_optimizer(self): | |||
| """ | |||
| @@ -438,4 +438,4 @@ class TestPipelineSplitWithNoOptimizer: | |||
| self.cat_fp16_from_ir(pattern='grad_mirror_MirrorMicroStepOperator', | |||
| target_count=2) | |||
| self.cat_fp16_from_ir(pattern='Cast(', | |||
| target_count=28) | |||
| target_count=26) | |||