|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Data parallel allreduce fusion"""
-
- import ctypes
-
- from mindspore import log as logger
-
- _MAX_GROUP_NAME_LEN = 127  # Longest group name accepted by the fusion-strategy setters in this module.
- _HCCL_LIB = 'libhccl.so'  # Shared-library name handed to ctypes.CDLL by _load_lib().
-
-
def _load_lib():
    """
    Load the HCCL shared library through ctypes.

    Returns:
        ctypes.CDLL, the handle of the library named by `_HCCL_LIB`.

    Raises:
        OSError: If the shared library cannot be found or loaded.
    """
    try:
        hccl_lib = ctypes.CDLL(_HCCL_LIB)
    except (OSError, RuntimeError):
        # ctypes.CDLL reports a missing or unloadable library as OSError.
        # Log the failure and re-raise, so the function never reaches the
        # return statement with `hccl_lib` unbound.
        logger.error('Get hccl lib error')
        raise

    return hccl_lib
-
-
- def _c_str(string):
- """Convert a python string to C string."""
- if not isinstance(string, str):
- string = string.decode('ascii')
- return ctypes.c_char_p(string.encode('utf-8'))
-
-
- def _c_array(ctype, values):
- """Create ctypes array from a python array."""
- return (ctype * len(values))(*values)
-
-
- def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
- """
- A function set gradient segment strategy according to the index list.
-
- Note:
- In the back propagation,
- the fusion of the allreduce operators with a fusion attribute equals 1,
- will be performed according to the idxList,
- to achieve the effect of parallel between calculation and communication.
-
- Args:
- idxList (list): The index list of the gradient.
- group (str): The hccl communication group.
-
- Raises:
- TypeError: If group is not a python str.
- TypeError: If IdxList is not a python list.
- TypeError: If type of idxList item is not int.
- ValueError: If group name length is out of range.
- ValueError: If idxList length is 0.
- ValueError: If idxList item is less than 0.
- RuntimeError: If allreduce split failed.
- """
- try:
- lib_ctype = _load_lib()
- except RuntimeError:
- logger.error('Load HCCL lib failed')
-
- if isinstance(group, (str)):
- group_len = len(group)
- if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
- raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
- else:
- raise TypeError('Group must be a python str')
-
- if isinstance(idxList, (list)):
- idx_len = len(idxList)
- if idx_len == 0:
- raise ValueError('IdxList length is 0')
- else:
- raise TypeError('IdxList must be a python list')
-
- for idx in idxList:
- if isinstance(idx, (int)):
- if idx < 0:
- raise ValueError('Idx < 0')
- else:
- raise TypeError('Idx in idxList is invalid')
-
- c_array_idxList = _c_array(ctypes.c_uint, idxList)
- c_idx_num = ctypes.c_uint(len(idxList))
- c_group = _c_str(group)
- ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
- if ret != 0:
- raise RuntimeError('Allreduce split error')
-
-
- def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
- """
- A function set gradient segment strategy according to the data size percentage list.
-
- Note:
- In the back propagation,
- the fusion of the allreduce operators with a fusion attribute equals 1,
- will be performed according to dataSizeList,
- to achieve the effect of parallel between calculation and communication.
-
- Args:
- dataSizeList (list): The data size percentage list of the gradient.
- group (str): The hccl communication group.
-
- Raises:
- TypeError: If group is not a python str.
- TypeError: If dataSizeList is not a python list.
- TypeError: If type of dataSizeList item is not int or float.
- ValueError: If group name length is out of range.
- ValueError: If dataSizeList length is 0.
- ValueError: If dataSizeList item is less than 0.
- RuntimeError: If allreduce split failed.
- """
- try:
- lib_ctype = _load_lib()
- except RuntimeError:
- logger.error('Load HCCL lib failed')
- if isinstance(group, (str)):
- group_len = len(group)
- if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
- raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
- else:
- raise TypeError('Group must be a python str')
- if isinstance(dataSizeList, (list)):
- len_data_size = len(dataSizeList)
- if len_data_size == 0:
- raise ValueError('DataSizeList length is 0')
- else:
- raise TypeError('DataSizeList must be a python list')
- for dataSize in dataSizeList:
- if not isinstance(dataSize, (int, float)):
- raise TypeError('DataSize in dataSizeList is invalid')
-
- c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
- c_size_num = ctypes.c_uint(len(dataSizeList))
- c_group = _c_str(group)
- ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
- if ret != 0:
- raise RuntimeError('Allreduce split error')
|