You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

_dp_allreduce_fusion.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Data parallel allreduce fusion"""
  16. import ctypes
  17. _MAX_GROUP_NAME_LEN = 127
  18. _HCCL_LIB = 'libhccl.so'
  19. def _load_lib():
  20. try:
  21. hccl_lib = ctypes.CDLL(_HCCL_LIB)
  22. except Exception:
  23. raise RuntimeError('Get hccl lib error')
  24. return hccl_lib
  25. def _c_str(string):
  26. """Convert a python string to C string."""
  27. if not isinstance(string, str):
  28. string = string.decode('ascii')
  29. return ctypes.c_char_p(string.encode('utf-8'))
  30. def _c_array(ctype, values):
  31. """Create ctypes array from a python array."""
  32. return (ctype * len(values))(*values)
  33. def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
  34. """
  35. A function set gradient segment strategy according to the index list.
  36. Note:
  37. In the back propagation,
  38. the fusion of the allreduce operators with a fusion attribute equals 1,
  39. will be performed according to the idxList,
  40. to achieve the effect of parallel between calculation and communication.
  41. Args:
  42. idxList (list): The index list of the gradient.
  43. group (str): The hccl communication group.
  44. Raises:
  45. TypeError: If group is not a python str.
  46. TypeError: If IdxList is not a python list.
  47. TypeError: If type of idxList item is not int.
  48. ValueError: If group name length is out of range.
  49. ValueError: If idxList length is 0.
  50. ValueError: If idxList item is less than 0.
  51. RuntimeError: If allreduce split failed.
  52. """
  53. try:
  54. lib_ctype = _load_lib()
  55. except RuntimeError:
  56. import hccl_test.manage.api as hccl
  57. hccl.set_fusion_strategy_by_idx()
  58. return
  59. if isinstance(group, (str)):
  60. group_len = len(group)
  61. if (group_len > _MAX_GROUP_NAME_LEN or group_len == 0):
  62. raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}')
  63. else:
  64. raise TypeError('Group must be a python str')
  65. if isinstance(idxList, (list)):
  66. idx_len = len(idxList)
  67. if idx_len == 0:
  68. raise ValueError('IdxList length is 0')
  69. else:
  70. raise TypeError('IdxList must be a python list')
  71. for idx in idxList:
  72. if isinstance(idx, (int)):
  73. if idx < 0:
  74. raise ValueError('Idx < 0')
  75. else:
  76. raise TypeError('Idx in idxList is invalid')
  77. c_array_idxList = _c_array(ctypes.c_uint, idxList)
  78. c_idx_num = ctypes.c_uint(len(idxList))
  79. c_group = _c_str(group)
  80. ret = lib_ctype.hcom_set_split_strategy_by_index(c_group, c_idx_num, c_array_idxList)
  81. if ret != 0:
  82. raise RuntimeError('Allreduce split error')
  83. def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
  84. """
  85. A function set gradient segment strategy according to the data size percentage list.
  86. Note:
  87. In the back propagation,
  88. the fusion of the allreduce operators with a fusion attribute equals 1,
  89. will be performed according to dataSizeList,
  90. to achieve the effect of parallel between calculation and communication.
  91. Args:
  92. dataSizeList (list): The data size percentage list of the gradient.
  93. group (str): The hccl communication group.
  94. Raises:
  95. TypeError: If group is not a python str.
  96. TypeError: If dataSizeList is not a python list.
  97. TypeError: If type of dataSizeList item is not int or float.
  98. ValueError: If group name length is out of range.
  99. ValueError: If dataSizeList length is 0.
  100. ValueError: If dataSizeList item is less than 0.
  101. RuntimeError: If allreduce split failed.
  102. """
  103. try:
  104. lib_ctype = _load_lib()
  105. except RuntimeError:
  106. import hccl_test.manage.api as hccl
  107. hccl.set_fusion_strategy_by_size()
  108. return
  109. if isinstance(group, (str)):
  110. group_len = len(group)
  111. if group_len > _MAX_GROUP_NAME_LEN or group_len == 0:
  112. raise ValueError('Group name is out of range {_MAX_GROUP_NAME_LEN}')
  113. else:
  114. raise TypeError('Group must be a python str')
  115. if isinstance(dataSizeList, (list)):
  116. len_data_size = len(dataSizeList)
  117. if len_data_size == 0:
  118. raise ValueError('DataSizeList length is 0')
  119. else:
  120. raise TypeError('DataSizeList must be a python list')
  121. for dataSize in dataSizeList:
  122. if not isinstance(dataSize, (int, float)):
  123. raise TypeError('DataSize in dataSizeList is invalid')
  124. c_array_sizeList = _c_array(ctypes.c_float, dataSizeList)
  125. c_size_num = ctypes.c_uint(len(dataSizeList))
  126. c_group = _c_str(group)
  127. ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
  128. if ret != 0:
  129. raise RuntimeError('Allreduce split error')