| @@ -113,20 +113,28 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck | |||
| strategy_ckpt_save_file_ = strategy_ckpt_save_file; | |||
| } | |||
| void ParallelContext::set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices) { | |||
| all_reduce_fusion_split_indices_ = indices; | |||
| void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) { | |||
| all_reduce_fusion_split_indices_[group] = indices; | |||
| } | |||
| const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_indices() const { | |||
| return all_reduce_fusion_split_indices_; | |||
| const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitIndices(const std::string &group) const { | |||
| auto iter = all_reduce_fusion_split_indices_.find(group); | |||
| if (iter != all_reduce_fusion_split_indices_.end()) { | |||
| return iter->second; | |||
| } | |||
| return {}; | |||
| } | |||
| void ParallelContext::set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes) { | |||
| all_reduce_fusion_split_sizes_ = sizes; | |||
| void ParallelContext::SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group) { | |||
| all_reduce_fusion_split_sizes_[group] = sizes; | |||
| } | |||
| const std::vector<uint32_t> ParallelContext::all_reduce_fusion_split_sizes() const { | |||
| return all_reduce_fusion_split_sizes_; | |||
| const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const std::string &group) const { | |||
| auto iter = all_reduce_fusion_split_sizes_.find(group); | |||
| if (iter != all_reduce_fusion_split_sizes_.end()) { | |||
| return iter->second; | |||
| } | |||
| return {}; | |||
| } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <cstdint> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -76,10 +77,10 @@ class ParallelContext { | |||
| bool global_rank_is_set() const { return global_rank_is_set_; } | |||
| bool parameter_broadcast_is_set() const { return parameter_broadcast_is_set_; } | |||
| void set_all_reduce_fusion_split_indices(const std::vector<uint32_t> indices); | |||
| const std::vector<uint32_t> all_reduce_fusion_split_indices() const; | |||
| void set_all_reduce_fusion_split_sizes(const std::vector<uint32_t> sizes); | |||
| const std::vector<uint32_t> all_reduce_fusion_split_sizes() const; | |||
| void SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group); | |||
| const std::vector<uint32_t> GetAllReduceFusionSplitIndices(const std::string &group) const; | |||
| void SetAllReduceFusionSplitSizes(const std::vector<uint32_t> sizes, const std::string &group); | |||
| const std::vector<uint32_t> GetAllReduceFusionSplitSizes(const std::string &group) const; | |||
| void set_enable_all_reduce_fusion(bool enable_all_reduce_fusion) { | |||
| enable_all_reduce_fusion_ = enable_all_reduce_fusion; | |||
| } | |||
| @@ -108,8 +109,8 @@ class ParallelContext { | |||
| bool global_rank_is_set_; | |||
| bool parameter_broadcast_is_set_; | |||
| bool enable_all_reduce_fusion_; | |||
| std::vector<uint32_t> all_reduce_fusion_split_indices_; | |||
| std::vector<uint32_t> all_reduce_fusion_split_sizes_; | |||
| std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_indices_; | |||
| std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_; | |||
| std::string strategy_ckpt_load_file_; | |||
| std::string strategy_ckpt_save_file_; | |||
| }; | |||
| @@ -159,13 +159,13 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.") | |||
| .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.") | |||
| .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.") | |||
| .def("set_all_reduce_fusion_split_indices", &ParallelContext::set_all_reduce_fusion_split_indices, | |||
| .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices, | |||
| "Set all reduce fusion split indices.") | |||
| .def("get_all_reduce_fusion_split_indices", &ParallelContext::all_reduce_fusion_split_indices, | |||
| .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices, | |||
| "Get all reduce fusion split indices.") | |||
| .def("set_all_reduce_fusion_split_sizes", &ParallelContext::set_all_reduce_fusion_split_sizes, | |||
| .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes, | |||
| "Set all reduce fusion split sizes.") | |||
| .def("get_all_reduce_fusion_split_sizes", &ParallelContext::all_reduce_fusion_split_sizes, | |||
| .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes, | |||
| "Get all reduce fusion split sizes.") | |||
| .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion, | |||
| "Set enable/disable all reduce fusion.") | |||
| @@ -92,7 +92,7 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) { | |||
| } // namespace | |||
| bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, | |||
| std::vector<size_t> *segment_index) const { | |||
| std::vector<size_t> *segment_index, const std::string &group) const { | |||
| MS_EXCEPTION_IF_NULL(segment_num); | |||
| MS_EXCEPTION_IF_NULL(segment_index); | |||
| size_t communication_op_node_size = communication_op_info.communication_op_nodes.size(); | |||
| @@ -100,7 +100,7 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic | |||
| auto parallel_context = parallel::ParallelContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(parallel_context); | |||
| const std::vector<uint32_t> split_indices = parallel_context->all_reduce_fusion_split_indices(); | |||
| const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group); | |||
| size_t segments = 0; | |||
| if (split_indices.size() != 0) { | |||
| @@ -255,7 +255,7 @@ bool CommunicationOpFusion::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| size_t segment_num = 0; | |||
| std::vector<size_t> segment_index; | |||
| if (GetSplitSegments(it.second, &segment_num, &segment_index)) { | |||
| if (GetSplitSegments(it.second, &segment_num, &segment_index, it.first)) { | |||
| if (DoFusion(func_graph, it.second, segment_num, segment_index)) { | |||
| changed = true; | |||
| } | |||
| @@ -46,7 +46,7 @@ class CommunicationOpFusion : public Pass { | |||
| const CommunicationOpInfo &communication_op_info, size_t start_index, | |||
| size_t end_index) const; | |||
| bool GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num, | |||
| std::vector<size_t> *segment_index) const; | |||
| std::vector<size_t> *segment_index, const std::string &group) const; | |||
| std::string op_name_; | |||
| size_t groups_ = 1; | |||
| }; | |||
| @@ -19,6 +19,8 @@ from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, | |||
| from mindspore._c_expression import AutoParallelContext | |||
| from mindspore._checkparam import args_type_check | |||
| _MAX_GROUP_NAME_LEN = 127 | |||
| class _AutoParallelContext: | |||
| """ | |||
| @@ -243,51 +245,117 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_parameter_broadcast_is_set() | |||
| def set_all_reduce_fusion_split_indices(self, indices): | |||
| def set_all_reduce_fusion_split_indices(self, indices, group=""): | |||
| """ | |||
| Set allreduce fusion strategy by parameters indices. | |||
| Args: | |||
| indices (list): Indices list. | |||
| group (str): The hccl communication group. | |||
| Raises: | |||
| TypeError: If type of indices item is not int. | |||
| TypeError: If group is not a python str. | |||
| """ | |||
| self.check_context_handle() | |||
| for index in indices: | |||
| if not isinstance(index, int): | |||
| raise TypeError('indices has invalid value') | |||
| self._context_handle.set_all_reduce_fusion_split_indices(indices) | |||
| if isinstance(indices, (list)): | |||
| for index in indices: | |||
| if not isinstance(index, int): | |||
| raise TypeError('indices has invalid value') | |||
| else: | |||
| raise TypeError('indices must be a python list') | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if group_len > _MAX_GROUP_NAME_LEN: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| self._context_handle.set_all_reduce_fusion_split_indices(indices, group) | |||
| if context.get_context("device_target") == "Ascend": | |||
| _set_fusion_strategy_by_idx(indices) | |||
| if group == "": | |||
| _set_fusion_strategy_by_idx(indices) | |||
| else: | |||
| _set_fusion_strategy_by_idx(indices, group) | |||
| def get_all_reduce_fusion_split_indices(self, group=""): | |||
| """ | |||
| Get allreduce fusion split indices. | |||
| Args: | |||
| group (str): The hccl communication group. | |||
| Returns: | |||
| Return split sizes list according to the group. | |||
| def get_all_reduce_fusion_split_indices(self): | |||
| """Get allreduce fusion split indices.""" | |||
| Raises: | |||
| TypeError: If group is not a python str. | |||
| """ | |||
| self.check_context_handle() | |||
| return self._context_handle.get_all_reduce_fusion_split_indices() | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if group_len > _MAX_GROUP_NAME_LEN: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| return self._context_handle.get_all_reduce_fusion_split_indices(group) | |||
| def set_all_reduce_fusion_split_sizes(self, sizes): | |||
| def set_all_reduce_fusion_split_sizes(self, sizes, group=""): | |||
| """ | |||
| Set allreduce fusion strategy by parameters data sizes. | |||
| Args: | |||
| sizes (list): Sizes list. | |||
| group (str): The hccl communication group. | |||
| Raises: | |||
| TypeError: If type of sizes item is not int. | |||
| TypeError: If group is not a python str. | |||
| """ | |||
| self.check_context_handle() | |||
| for size in sizes: | |||
| if not isinstance(size, int): | |||
| raise TypeError('sizes has invalid value') | |||
| self._context_handle.set_all_reduce_fusion_split_sizes(sizes) | |||
| if isinstance(sizes, (list)): | |||
| for size in sizes: | |||
| if not isinstance(size, int): | |||
| raise TypeError('sizes has invalid value') | |||
| else: | |||
| raise TypeError('sizes must be a python list') | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if group_len > _MAX_GROUP_NAME_LEN: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| self._context_handle.set_all_reduce_fusion_split_sizes(sizes, group) | |||
| if context.get_context("device_target") == "Ascend": | |||
| _set_fusion_strategy_by_size(sizes) | |||
| if group == "": | |||
| _set_fusion_strategy_by_size(sizes) | |||
| else: | |||
| _set_fusion_strategy_by_size(sizes, group) | |||
| def get_all_reduce_fusion_split_sizes(self): | |||
| """Get allreduce fusion split sizes.""" | |||
| def get_all_reduce_fusion_split_sizes(self, group=""): | |||
| """ | |||
| Get allreduce fusion split sizes. | |||
| Args: | |||
| group (str): The hccl communication group. | |||
| Returns: | |||
| Return split sizes list according to the group. | |||
| Raises: | |||
| TypeError: If group is not a python str. | |||
| """ | |||
| self.check_context_handle() | |||
| return self._context_handle.get_all_reduce_fusion_split_sizes() | |||
| if isinstance(group, (str)): | |||
| group_len = len(group) | |||
| if group_len > _MAX_GROUP_NAME_LEN: | |||
| raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') | |||
| else: | |||
| raise TypeError('Group must be a python str') | |||
| return self._context_handle.get_all_reduce_fusion_split_sizes(group) | |||
| def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): | |||
| """ | |||