| @@ -33,6 +33,9 @@ std::vector<std::string> PARALLEL_MODE_LIST = {STAND_ALONE, DATA_PARALLEL, HYBRI | |||
| AUTO_PARALLEL}; | |||
| std::vector<std::string> STRATEGY_SEARCH_MODE_LIST = {DYNAMIC_PROGRAMMING, RECURSIVE_PROGRAMMING}; | |||
| std::vector<std::string> COMMUNI_PARALLEL_MODE_LIST = {ALL_GROUP_PARALLEL, SAME_SERVER_GROUP_PARALLEL, | |||
| NO_GROUP_PARALLEL}; | |||
| std::shared_ptr<ParallelContext> ParallelContext::inst_context_ = nullptr; | |||
| std::shared_ptr<ParallelContext> ParallelContext::GetInstance() { | |||
| @@ -65,6 +68,7 @@ void ParallelContext::Reset() { | |||
| strategy_search_mode_ = DYNAMIC_PROGRAMMING; | |||
| pipeline_stage_split_num_ = 1; | |||
| grad_accumulation_step_ = 1; | |||
| communi_parallel_mode_ = ALL_GROUP_PARALLEL; | |||
| } | |||
| void ParallelContext::set_device_num(int64_t device_num) { | |||
| @@ -152,6 +156,17 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const | |||
| return {}; | |||
| } | |||
| // Validate and apply the communication parallel mode; returns false if the mode is unknown. | |||
| bool ParallelContext::set_communi_parallel_mode(const std::string &communi_parallel_mode) { | |||
| const bool is_valid_mode = | |||
| std::find(COMMUNI_PARALLEL_MODE_LIST.begin(), COMMUNI_PARALLEL_MODE_LIST.end(), communi_parallel_mode) != | |||
| COMMUNI_PARALLEL_MODE_LIST.end(); | |||
| if (is_valid_mode) { | |||
| communi_parallel_mode_ = communi_parallel_mode; | |||
| return true; | |||
| } | |||
| MS_LOG(INFO) << "Invalid communication parallel mode:" << communi_parallel_mode; | |||
| return false; | |||
| } | |||
| // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode | |||
| void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| @@ -46,6 +46,10 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; | |||
| constexpr char TRAINING[] = "training"; | |||
| constexpr char ACCUMULATION[] = "accumulation"; | |||
| constexpr char ALL_GROUP_PARALLEL[] = "all_group_parallel"; | |||
| constexpr char SAME_SERVER_GROUP_PARALLEL[] = "same_server_group_parallel"; | |||
| constexpr char NO_GROUP_PARALLEL[] = "no_group_parallel"; | |||
| class ParallelContext { | |||
| public: | |||
| ~ParallelContext() = default; | |||
| @@ -112,6 +116,9 @@ class ParallelContext { | |||
| } | |||
| bool enable_parallel_optimizer() const { return enable_parallel_optimizer_; } | |||
| bool set_communi_parallel_mode(const std::string &communi_parallel_mode); | |||
| std::string communi_parallel_mode() const { return communi_parallel_mode_; } | |||
| void Reset(); | |||
| void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph); | |||
| void ParallelParameterContextRestoreShape(const FuncGraphPtr &func_graph, const ParameterPtr ¶m_node, | |||
| @@ -144,6 +151,7 @@ class ParallelContext { | |||
| std::string group_ckpt_save_file_; | |||
| bool enable_parallel_optimizer_; | |||
| bool init_param_shape_; | |||
| std::string communi_parallel_mode_; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -169,6 +169,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| "Set enable/disable parallel optimizer.") | |||
| .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer, | |||
| "Get enable/disable parallel optimizer.") | |||
| .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.") | |||
| .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -40,6 +40,7 @@ namespace ascend { | |||
| namespace { | |||
| constexpr uint32_t kDeviceNumOfServer = 8; | |||
| constexpr uint32_t kDeviceNumThreshold = 1024; | |||
| const char kDefaultGroup[] = "__default_group"; | |||
| constexpr uint32_t kMaxStreamNum = 1024; | |||
| constexpr uint32_t kHcomSecondaryStreamNum = 3; | |||
| @@ -60,13 +61,48 @@ bool IsSameServer(const std::vector<uint32_t> &rank_ids) { | |||
| return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer)); | |||
| } | |||
| // Decide which communication group an hcom op should actually use, based on the | |||
| // configured communication parallel mode. | |||
| string DoGetHcomGroup(const string &original_group) { | |||
| const string mode = parallel::ParallelContext::GetInstance()->communi_parallel_mode(); | |||
| if (mode == parallel::ALL_GROUP_PARALLEL) { | |||
| return original_group; | |||
| } | |||
| if (mode == parallel::NO_GROUP_PARALLEL) { | |||
| return kDefaultGroup; | |||
| } | |||
| // same_server_group_parallel: keep the original group only when all of its ranks | |||
| // are on one server; otherwise fall back to the default group. | |||
| MS_EXCEPTION_IF_NULL(parallel::g_device_manager); | |||
| auto group_info = parallel::g_device_manager->group_info(); | |||
| for (const auto &info : group_info) { | |||
| if (info.first == original_group) { | |||
| return IsSameServer(info.second) ? original_group : kDefaultGroup; | |||
| } | |||
| } | |||
| // The world group is not recorded in group_info, so it also maps to the default group. | |||
| return kDefaultGroup; | |||
| } | |||
| // Return the (possibly remapped) communication group of an hcom node. | |||
| // Raises via MS_LOG(EXCEPTION) if the node has no group attribute. | |||
| string GetHcomGroup(const CNodePtr &cnode) { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) { | |||
| MS_LOG(EXCEPTION) << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute."; | |||
| } | |||
| // NOTE(review): a leftover early `return AnfAlgo::GetNodeAttr<...>` here made the | |||
| // remapping below unreachable; removed so DoGetHcomGroup actually takes effect. | |||
| auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup); | |||
| auto new_group = DoGetHcomGroup(group_name); | |||
| MS_LOG(INFO) << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name | |||
| << ", new group: " << new_group; | |||
| return new_group; | |||
| } | |||
| uint32_t GetHcomTaskNum(const CNodePtr &cnode) { | |||
| @@ -167,6 +203,9 @@ StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, u | |||
| void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { | |||
| if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) { | |||
| MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode() | |||
| << "."; | |||
| Reset(); | |||
| SetLoopSink(); | |||
| ReorderIndependentOrders(graph_ptr); | |||
| @@ -480,6 +480,26 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_enable_parallel_optimizer() | |||
| def set_communi_parallel_mode(self, communi_parallel_mode): | |||
| """ | |||
| Set communication parallel mode. | |||
| Args: | |||
| communi_parallel_mode (str): The communication parallel mode. | |||
| Raises: | |||
| ValueError: If parallel mode is not supported. | |||
| """ | |||
| self.check_context_handle() | |||
| ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode) | |||
| if ret is False: | |||
| raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode)) | |||
| def get_communi_parallel_mode(self): | |||
| """Get the current communication parallel mode (str) from the underlying context handle.""" | |||
| self.check_context_handle() | |||
| return self._context_handle.get_communi_parallel_mode() | |||
| def reset(self): | |||
| """Reset all settings.""" | |||
| self.check_context_handle() | |||
| @@ -518,7 +538,8 @@ _set_auto_parallel_context_func_map = { | |||
| "full_batch": auto_parallel_context().set_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, | |||
| "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, | |||
| "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} | |||
| "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices, | |||
| "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode} | |||
| _get_auto_parallel_context_func_map = { | |||
| @@ -536,14 +557,16 @@ _get_auto_parallel_context_func_map = { | |||
| "full_batch": auto_parallel_context().get_full_batch, | |||
| "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, | |||
| "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, | |||
| "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} | |||
| "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices, | |||
| "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode} | |||
| @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, | |||
| loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, | |||
| parameter_broadcast=bool, strategy_ckpt_load_file=str, | |||
| strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str) | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, | |||
| communi_parallel_mode=str) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| @@ -592,6 +615,14 @@ def _set_auto_parallel_context(**kwargs): | |||
| the devices are distributed along the pipeline. The total devices will be divided into | |||
| 'pipeline_stages' stages. This currently could only be used when | |||
| parallel mode semi_auto_parallel is enabled. Default: 0 | |||
| communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel", | |||
| "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel". | |||
| - all_group_parallel: All communication groups are in parallel. | |||
| - same_server_group_parallel: Only the communication groups within the same server are parallel. | |||
| - no_group_parallel: All communication groups are not parallel. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -21,19 +21,22 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context | |||
| def test_set_auto_parallel_context(): | |||
| context.set_auto_parallel_context(device_num=4, global_rank=3, gradients_mean=True, gradient_fp32_sync=False, | |||
| parallel_mode="auto_parallel", parameter_broadcast=False) | |||
| parallel_mode="auto_parallel", parameter_broadcast=False, | |||
| communi_parallel_mode="same_server_group_parallel") | |||
| device_num = context.get_auto_parallel_context("device_num") | |||
| global_rank = context.get_auto_parallel_context("global_rank") | |||
| gradients_mean = context.get_auto_parallel_context("gradients_mean") | |||
| gradient_fp32_sync = context.get_auto_parallel_context("gradient_fp32_sync") | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| parameter_broadcast = context.get_auto_parallel_context("parameter_broadcast") | |||
| communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode") | |||
| assert device_num == 4 | |||
| assert global_rank == 3 | |||
| assert gradients_mean | |||
| assert not gradient_fp32_sync | |||
| assert parallel_mode == "auto_parallel" | |||
| assert not parameter_broadcast | |||
| assert communi_parallel_mode == "same_server_group_parallel" | |||
| auto_parallel_context().set_device_num(4) | |||
| device_num = auto_parallel_context().get_device_num() | |||
| @@ -77,6 +80,9 @@ def test_set_auto_parallel_context(): | |||
| with pytest.raises(ValueError): | |||
| set_algo_parameters(tensor_slice_align_size=1025) | |||
| with pytest.raises(ValueError): | |||
| context.set_auto_parallel_context(communi_parallel_mode="wrong_mode") | |||
| context.set_auto_parallel_context(enable_parallel_optimizer=True) | |||
| assert context.get_auto_parallel_context("enable_parallel_optimizer") | |||
| assert not auto_parallel_context().get_all_reduce_fusion_split_indices() | |||
| @@ -98,6 +104,7 @@ def test_reset_auto_parallel_context(): | |||
| device_num_is_set = auto_parallel_context().get_device_num_is_set() | |||
| parameter_broadcast_is_set = auto_parallel_context().get_parameter_broadcast_is_set() | |||
| stage = auto_parallel_context().get_pipeline_stages() | |||
| communi_parallel_mode = context.get_auto_parallel_context("communi_parallel_mode") | |||
| assert device_num == 1 | |||
| assert global_rank == 0 | |||
| @@ -108,3 +115,4 @@ def test_reset_auto_parallel_context(): | |||
| assert not device_num_is_set | |||
| assert not parameter_broadcast_is_set | |||
| assert stage == 1 | |||
| assert communi_parallel_mode == "all_group_parallel" | |||