You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

_dp_allreduce_fusion.py 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Data parallel allreduce fusion"""
  16. import ctypes
  17. from mindspore import log as logger
  18. _MAX_GROUP_NAME_LEN = 127  # longest group name (in characters) the strategy setters accept
  19. _HCCL_LIB = 'libhccl.so'  # shared-library name handed to ctypes.CDLL by _load_lib
  def _load_lib():
      """
      Load the HCCL shared library through ctypes.

      Returns:
          ctypes.CDLL, the handle of the library named by ``_HCCL_LIB``.

      Raises:
          OSError: If the shared library cannot be found or loaded.
      """
      try:
          return ctypes.CDLL(_HCCL_LIB)
      except OSError:
          # ctypes.CDLL signals a missing or unloadable library with OSError.
          # Log it and let the same exception propagate, so the caller never
          # receives an unusable handle.
          logger.error('Get hccl lib error')
          raise
  def _c_str(string):
      """
      Build a ctypes ``c_char_p`` from a python string.

      Values that are not ``str`` are decoded as ASCII first; the resulting
      text is then UTF-8 encoded for the C side.
      """
      text = string if isinstance(string, str) else string.decode('ascii')
      encoded = text.encode('utf-8')
      return ctypes.c_char_p(encoded)
  def _c_array(ctype, values):
      """
      Build a ctypes array of element type ``ctype`` holding ``values``.

      The array length equals ``len(values)`` and the items are copied in order.
      """
      array_type = ctype * len(values)
      return array_type(*values)
  def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
      """
      Set the gradient segment strategy according to the index list.

      Note:
          In the back propagation,
          the fusion of the allreduce operators with a fusion attribute equals 1,
          will be performed according to the idxList,
          to achieve the effect of parallel between calculation and communication.

      Args:
          idxList (list): The index list of the gradient.
          group (str): The hccl communication group.

      Raises:
          OSError: If the HCCL shared library cannot be loaded.
          TypeError: If group is not a python str.
          TypeError: If idxList is not a python list.
          TypeError: If type of idxList item is not int.
          ValueError: If group name length is out of range.
          ValueError: If idxList length is 0.
          ValueError: If idxList item is less than 0.
          RuntimeError: If allreduce split failed.
      """
      try:
          lib_ctype = _load_lib()
      except (OSError, RuntimeError):
          # Log and re-raise: carrying on without a library handle would only
          # fail later with an unbound ``lib_ctype``.
          logger.error('Load HCCL lib failed')
          raise

      if not isinstance(group, str):
          raise TypeError('Group must be a python str')
      if not 0 < len(group) <= _MAX_GROUP_NAME_LEN:
          # Format the limit into the message (the placeholder used to be
          # emitted literally).
          raise ValueError('Group name len is out of range {}'.format(_MAX_GROUP_NAME_LEN))

      if not isinstance(idxList, list):
          raise TypeError('IdxList must be a python list')
      if not idxList:
          raise ValueError('IdxList length is 0')

      for idx in idxList:
          if not isinstance(idx, int):
              raise TypeError('Idx in idxList is invalid')
          if idx < 0:
              raise ValueError('Idx < 0')

      # Marshal the arguments into the C types passed to the HCCL entry point.
      c_array_idxList = _c_array(ctypes.c_uint, idxList)
      c_idx_num = ctypes.c_uint(len(idxList))
      c_group = _c_str(group)
      ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
      if ret != 0:
          raise RuntimeError('Allreduce split error')
  def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
      """
      Set the gradient segment strategy according to the data size percentage list.

      Note:
          In the back propagation,
          the fusion of the allreduce operators with a fusion attribute equals 1,
          will be performed according to dataSizeList,
          to achieve the effect of parallel between calculation and communication.

      Args:
          dataSizeList (list): The data size percentage list of the gradient.
          group (str): The hccl communication group.

      Raises:
          OSError: If the HCCL shared library cannot be loaded.
          TypeError: If group is not a python str.
          TypeError: If dataSizeList is not a python list.
          TypeError: If type of dataSizeList item is not int or float.
          ValueError: If group name length is out of range.
          ValueError: If dataSizeList length is 0.
          ValueError: If dataSizeList item is less than 0.
          RuntimeError: If allreduce split failed.
      """
      try:
          lib_ctype = _load_lib()
      except (OSError, RuntimeError):
          # Log and re-raise: carrying on without a library handle would only
          # fail later with an unbound ``lib_ctype``.
          logger.error('Load HCCL lib failed')
          raise

      if not isinstance(group, str):
          raise TypeError('Group must be a python str')
      if not 0 < len(group) <= _MAX_GROUP_NAME_LEN:
          # Format the limit into the message (the placeholder used to be
          # emitted literally).
          raise ValueError('Group name is out of range {}'.format(_MAX_GROUP_NAME_LEN))

      if not isinstance(dataSizeList, list):
          raise TypeError('DataSizeList must be a python list')
      if not dataSizeList:
          raise ValueError('DataSizeList length is 0')

      for dataSize in dataSizeList:
          if not isinstance(dataSize, (int, float)):
              raise TypeError('DataSize in dataSizeList is invalid')
          if dataSize < 0:
              # Enforce the documented contract: negative sizes are rejected
              # here instead of being passed to the library.
              raise ValueError('DataSize < 0')

      # Marshal the arguments into the C types passed to the HCCL entry point.
      c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
      c_size_num = ctypes.c_uint(len(dataSizeList))
      c_group = _c_str(group)
      ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
      if ret != 0:
          raise RuntimeError('Allreduce split error')