You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_tensor.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """load tensor and combine tensor"""
  16. import numpy as np
  17. from mindspore.common.tensor import Tensor
  18. from ..communication.management import get_rank
  19. def _get_tensor_strategy(dev_mat, tensor_map):
  20. """
  21. Get split strategy by device arrangement and tensor map.
  22. Args:
  23. dev_mat (list): The device matrix.
  24. tensor_map (list): The map relation between tensor and devices.
  25. Returns:
  26. List, the split strategy with the same size of np_tensor.
  27. """
  28. tensor_strategy = []
  29. for dim in tensor_map:
  30. if dim == -1:
  31. tensor_strategy.append(1)
  32. else:
  33. tensor_strategy.append(dev_mat[-dim-1])
  34. return tensor_strategy
  35. def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
  36. """
  37. Get the tensor slice index for the local device.
  38. Args:
  39. device_arrangement (list): The device matrix.
  40. tensor_strategy (list): The split strategy with the same size of np_tensor.
  41. tensor_map (list): The map relation between tensor and devices.
  42. rank_index (int): The rank of local device.
  43. Returns:
  44. Integer, the index of the local device for tensor slices.
  45. """
  46. device_coordinate = _rank_to_coordinate(rank_index, device_arrangement)
  47. device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map)
  48. tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy)
  49. return tensor_slice_index
  50. def _rank_to_coordinate(rank_index, device_arrangement):
  51. """
  52. Convert rank index to device coordinate.
  53. Args:
  54. rank_index (int): The index of the local device.
  55. device_arrangement (list): The device matrix.
  56. Returns:
  57. List, the coordinate for local device in the device matrix
  58. """
  59. dim_len = len(device_arrangement)
  60. device_coordinate = np.zeros(dim_len)
  61. for i in range(dim_len):
  62. size = device_arrangement[dim_len - 1 - i]
  63. device_coordinate[dim_len - 1 - i] = rank_index % size
  64. rank_index = int(rank_index / size)
  65. return device_coordinate
  66. def _coordinate_to_rank(device_coordinate, device_arrangement):
  67. """
  68. Convert device coordinate to rank index.
  69. Args:
  70. device_coordinate (list): The coordinate for local device in the device matrix.
  71. device_arrangement (list): The device matrix.
  72. Returns:
  73. Integer, the index of the local device for tensor slices.
  74. """
  75. rank_index = 0
  76. size = 1
  77. for i in range(len(device_coordinate)):
  78. rank_index += size * device_coordinate[len(device_coordinate) - 1 - i]
  79. size *= device_arrangement[len(device_coordinate) - 1 - i]
  80. return rank_index
  81. def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
  82. """
  83. Convert device_coordinate according to the tensor map.
  84. Args:
  85. device_coordinate (list): The coordinate for local device in the device matrix.
  86. tensor_map (list): The map relation between tensor and devices.
  87. Returns:
  88. List, the converted coordinate.
  89. """
  90. device_coordinate_new = []
  91. for i in range(len(tensor_map)):
  92. if tensor_map[len(tensor_map) - 1 - i] != -1:
  93. device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 -
  94. tensor_map[len(tensor_map) - 1 - i]])
  95. else:
  96. device_coordinate_new.insert(0, 0)
  97. return device_coordinate_new
  98. def _chunk_tensor(np_tensor, strategy, depth):
  99. """
  100. Recursive function to chunk tensor.
  101. Args:
  102. np_tensor (NDarray): The matrix to be split.
  103. strategy (list): The split strategy with the same size of np_tensor.
  104. depth (int): Recursion depth.
  105. Returns:
  106. NDarray, the splited matrix.
  107. Raises:
  108. ValueError: If np_tensor can not be split by strategy.
  109. """
  110. output = []
  111. axis = len(np_tensor.shape) - depth
  112. if np_tensor.shape[axis] % strategy[0] != 0:
  113. raise ValueError("np_tensor can not be split by strategy!")
  114. ret = list(np.split(np_tensor, strategy[0], axis))
  115. if depth == 1:
  116. return ret
  117. for ret_ in ret:
  118. output.extend(
  119. _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
  120. return output
  121. def _chunk_tensor_by_strategy(np_tensor, strategy):
  122. """
  123. Split the input by strategy.
  124. Args:
  125. np_tensor (NDarray): The matrix to be split.
  126. strategy (list): The split strategy with the same size of np_tensor.
  127. Returns:
  128. NDarray, the splited matrix.
  129. Raises:
  130. TypeError: If np_tensor is not ndarray
  131. ValueError: If the length of np_tensor does not match the length of strategy.
  132. """
  133. if not isinstance(np_tensor, np.ndarray):
  134. raise TypeError("np_tensor should be ndarray!")
  135. if len(strategy) != len(np_tensor.shape):
  136. raise ValueError("The length of np_tensor does not match the length of strategy!")
  137. return _chunk_tensor(np_tensor, strategy, len(strategy))
  138. def _get_seed(dev_mat, tensor_map):
  139. """
  140. Get the random seed for current slice.
  141. Args:
  142. dev_mat (list): The device matrix of devices.
  143. tensor_map (list): The split strategy of tensor.
  144. Returns:
  145. Integer, the local random seed for this device.
  146. """
  147. rank = get_rank()
  148. tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
  149. tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
  150. return tensor_slice_seed
  151. def _load_tensor(tensor, dev_mat, tensor_map):
  152. """
  153. Get the tensor slice of the local device by the device matrix and the tensor map
  154. Args:
  155. tensor (Tensor): The tensor to be split.
  156. dev_mat (list): The device matrix of devices.
  157. tensor_map (list): The split strategy of tensor.
  158. Returns:
  159. Tensor, the sliced tensor.
  160. Examples:
  161. >>> tensor = Tensor(np.ones([32, 32]))
  162. >>> dev_mat = [2, 4]
  163. >>> tensor_map = [1, -1]
  164. >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
  165. """
  166. rank = get_rank()
  167. tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
  168. tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
  169. np_tensor = tensor.asnumpy()
  170. np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
  171. np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
  172. tensor_slice = Tensor(np_tensor_slice)
  173. return tensor_slice
  174. def _load_tensor_by_layout(tensor, layout):
  175. """
  176. Load tensor by layout.
  177. Args:
  178. tensor (Tensor): The input tensor.
  179. layout (list): The tensor layout in auto parallel.
  180. Returns:
  181. Tensor, the sliced tensor.
  182. Raises:
  183. TypeError: If layout is not list.
  184. ValueError: If the length of layout is not 3.
  185. """
  186. if not isinstance(layout, list):
  187. raise TypeError("The layout should be list! layout is {}".format(layout))
  188. if len(layout) != 3:
  189. raise ValueError("The length of layout must be 3! layout is {}".format(layout))
  190. dev_mat = layout[0]
  191. tensor_map = layout[1]
  192. if tensor.size() == 1:
  193. return tensor
  194. return _load_tensor(tensor, dev_mat, tensor_map)
def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine param slice by the device matrix and the tensor map, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, generated from all the device from AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Tensor, the combined tensor which with the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(tensor_slices, dev_mat, tensor_map)
    """
    # Total number of devices = product of the device-matrix dimensions.
    # NOTE(review): assumes param_data's leading axis stacks one slice per
    # device in rank order (AllGather output layout) — confirm with caller.
    device_count = 1
    for dim in dev_mat:
        device_count *= dim
    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    # Get the actual number of distinct slices; different devices may load
    # the same slice, so slice_count can be smaller than device_count.
    slice_count = 1
    for dim in tensor_strategy:
        slice_count *= dim
    # Reorder slices and remove duplicates based on device matrix and
    # tensor_map: each device's slice is written to its slice index, so
    # duplicates simply overwrite each other with identical data.
    tensor_slices_new = list(range(slice_count))
    for i in range(device_count):
        slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)
        tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i])
    # Combine slices to generate the complete parameter. Each pass i
    # concatenates groups of adjacent slices along tensor axis
    # (dim_len - 1 - i), collapsing the last-split dimension first, until a
    # single full tensor remains.
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        # Number of groups left after merging along this axis.
        ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
        tensor_slices_new_inner = []
        for j in range(ele_count):
            # Start from the group's first slice and append the rest of the
            # group along the current axis.
            new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
            for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                           (j + 1) * tensor_strategy[dim_len - 1 - i]):
                new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
            tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
        tensor_slices_new = tensor_slices_new_inner
    # After all passes exactly one element remains: the full parameter.
    return Tensor(tensor_slices_new[0])