You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

_tensor.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """load tensor and combine tensor"""
  16. import numpy as np
  17. from mindspore.common.tensor import Tensor
  18. from ..communication.management import get_rank, get_group_size
  19. def _get_tensor_strategy(dev_mat, tensor_map):
  20. """
  21. Get split strategy by device arrangement and tensor map.
  22. Args:
  23. dev_mat (list): The device matrix.
  24. tensor_map (list): The map relation between tensor and devices.
  25. Returns:
  26. List, the split strategy with the same size of np_tensor.
  27. """
  28. tensor_strategy = []
  29. for dim in tensor_map:
  30. if dim == -1:
  31. tensor_strategy.append(1)
  32. else:
  33. tensor_strategy.append(dev_mat[-dim-1])
  34. return tensor_strategy
  35. def _get_tensor_slice_index(device_arrangement, tensor_strategy, tensor_map, rank_index):
  36. """
  37. Get the tensor slice index for the local device.
  38. Args:
  39. device_arrangement (list): The device matrix.
  40. tensor_strategy (list): The split strategy with the same size of np_tensor.
  41. tensor_map (list): The map relation between tensor and devices.
  42. rank_index (int): The rank of local device.
  43. Returns:
  44. Integer, the index of the local device for tensor slices.
  45. """
  46. device_coordinate = _rank_to_coordinate(rank_index, device_arrangement)
  47. device_coordinate_new = _convert_to_new_device_coordinate(device_coordinate, tensor_map)
  48. tensor_slice_index = _coordinate_to_rank(device_coordinate_new, tensor_strategy)
  49. return tensor_slice_index
  50. def _rank_to_coordinate(rank_index, device_arrangement):
  51. """
  52. Convert rank index to device coordinate.
  53. Args:
  54. rank_index (int): The index of the local device.
  55. device_arrangement (list): The device matrix.
  56. Returns:
  57. List, the coordinate for local device in the device matrix
  58. """
  59. dim_len = len(device_arrangement)
  60. device_coordinate = np.zeros(dim_len)
  61. for i in range(dim_len):
  62. size = device_arrangement[dim_len - 1 - i]
  63. device_coordinate[dim_len - 1 - i] = rank_index % size
  64. rank_index = int(rank_index / size)
  65. return device_coordinate
  66. def _coordinate_to_rank(device_coordinate, device_arrangement):
  67. """
  68. Convert device coordinate to rank index.
  69. Args:
  70. device_coordinate (list): The coordinate for local device in the device matrix.
  71. device_arrangement (list): The device matrix.
  72. Returns:
  73. Integer, the index of the local device for tensor slices.
  74. """
  75. rank_index = 0
  76. size = 1
  77. for i in range(len(device_coordinate)):
  78. rank_index += size * device_coordinate[len(device_coordinate) - 1 - i]
  79. size *= device_arrangement[len(device_coordinate) - 1 - i]
  80. return rank_index
  81. def _convert_to_new_device_coordinate(device_coordinate, tensor_map):
  82. """
  83. Convert device_coordinate according to the tensor map.
  84. Args:
  85. device_coordinate (list): The coordinate for local device in the device matrix.
  86. tensor_map (list): The map relation between tensor and devices.
  87. Returns:
  88. List, the converted coordinate.
  89. """
  90. device_coordinate_new = []
  91. for i in range(len(tensor_map)):
  92. if tensor_map[len(tensor_map) - 1 - i] != -1:
  93. device_coordinate_new.insert(0, device_coordinate[len(device_coordinate) - 1 -
  94. tensor_map[len(tensor_map) - 1 - i]])
  95. else:
  96. device_coordinate_new.insert(0, 0)
  97. return device_coordinate_new
  98. def _chunk_tensor(np_tensor, strategy, depth):
  99. """
  100. Recursive function to chunk tensor.
  101. Args:
  102. np_tensor (NDarray): The matrix to be split.
  103. strategy (list): The split strategy with the same size of np_tensor.
  104. depth (int): Recursion depth.
  105. Returns:
  106. NDarray, the splited matrix.
  107. Raises:
  108. ValueError: If np_tensor can not be split by strategy.
  109. """
  110. output = []
  111. axis = len(np_tensor.shape) - depth
  112. if np_tensor.shape[axis] % strategy[0] != 0:
  113. raise ValueError("np_tensor can not be split by strategy!")
  114. ret = list(np.split(np_tensor, strategy[0], axis))
  115. if depth == 1:
  116. return ret
  117. for ret_ in ret:
  118. output.extend(
  119. _chunk_tensor(ret_, strategy[len(strategy) - depth + 1:len(strategy)], depth - 1))
  120. return output
  121. def _chunk_tensor_by_strategy(np_tensor, strategy):
  122. """
  123. Split the input by strategy.
  124. Args:
  125. np_tensor (NDarray): The matrix to be split.
  126. strategy (list): The split strategy with the same size of np_tensor.
  127. Returns:
  128. NDarray, the splited matrix.
  129. Raises:
  130. TypeError: If np_tensor is not ndarray
  131. ValueError: If the length of np_tensor does not match the length of strategy.
  132. """
  133. if not isinstance(np_tensor, np.ndarray):
  134. raise TypeError("np_tensor should be ndarray!")
  135. if len(strategy) != len(np_tensor.shape):
  136. raise ValueError("The length of np_tensor does not match the length of strategy!")
  137. return _chunk_tensor(np_tensor, strategy, len(strategy))
  138. def _get_slice_index(dev_mat, tensor_map):
  139. """
  140. Get the slice index for current slice.
  141. Args:
  142. dev_mat (list): The device matrix of devices.
  143. tensor_map (list): The split strategy of tensor.
  144. Returns:
  145. Integer, the slice index for slice on this device.
  146. """
  147. rank = get_rank()
  148. tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
  149. tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
  150. return tensor_slice_index
  151. def _load_tensor(tensor, dev_mat, tensor_map):
  152. """
  153. Get the tensor slice of the local device by the device matrix and the tensor map
  154. Args:
  155. tensor (Tensor): The tensor to be split.
  156. dev_mat (list): The device matrix of devices.
  157. tensor_map (list): The split strategy of tensor.
  158. Returns:
  159. numpy.array, the sliced array.
  160. Examples:
  161. >>> tensor = Tensor(np.ones([32, 32]))
  162. >>> dev_mat = [2, 4]
  163. >>> tensor_map = [1, -1]
  164. >>> tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
  165. """
  166. rank = get_rank()
  167. tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
  168. tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
  169. np_tensor = tensor.asnumpy()
  170. np_tensor_list = _chunk_tensor_by_strategy(np_tensor, tensor_strategy)
  171. np_tensor_slice = np_tensor_list[int(tensor_slice_index)]
  172. return np_tensor_slice
  173. def _load_tensor_by_layout(tensor, layout):
  174. """
  175. Load tensor by layout.
  176. Args:
  177. tensor (Tensor): The input tensor.
  178. layout (list): The tensor layout in auto parallel.
  179. Returns:
  180. Tensor, the sliced tensor.
  181. Raises:
  182. TypeError: If layout is not list.
  183. ValueError: If the length of layout is not 3.
  184. """
  185. if not isinstance(layout, tuple):
  186. raise TypeError("The layout should be tuple! layout is {}".format(layout))
  187. if len(layout) < 6:
  188. raise ValueError("The length of layout must be larger than 5! layout is {}".format(layout))
  189. dev_mat = layout[0]
  190. tensor_map = layout[1]
  191. uniform_split = layout[4]
  192. group = layout[5]
  193. if uniform_split == 0:
  194. raise RuntimeError("The load tensor only support uniform split now")
  195. if tensor.size() == 1:
  196. return tensor
  197. tensor_slice = _load_tensor(tensor, dev_mat, tensor_map)
  198. if group:
  199. # get a totally shard tensor slice for parallel optimizer
  200. rank = get_rank(group)
  201. size = get_group_size(group)
  202. tensor_slice = np.split(tensor_slice, size)[rank]
  203. return Tensor(tensor_slice)
def _reshape_param_data(param_data, dev_mat, tensor_map):
    """
    Combine param slice by the device matrix and the tensor map, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped, generated from all the device from AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Tensor, the combined tensor which with the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> tensor_map = [1, 0]
        >>> tensor = _reshape_param_data(tensor_slices, dev_mat, tensor_map)
    """
    # Total device count is the product of the device-matrix extents.
    device_count = 1
    for dim in dev_mat:
        device_count *= dim
    # One slice per device, gathered and stacked along axis 0.
    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
    # get the actual number of slices,as: different devices may load the same slice
    slice_count = 1
    for dim in tensor_strategy:
        slice_count *= dim
    # reorder slices and remove duplicates based on device matrix and tensor_map
    tensor_slices_new = list(range(slice_count))
    for i in range(device_count):
        # Duplicate slices map to the same index, so later devices overwrite earlier ones.
        slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)
        tensor_slices_new[int(slice_index)] = np.array(tensor_slices[i])
    # combine slices to generate complete parameter
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        # Merge consecutive groups of slices along the current axis,
        # processing axes from the last dimension back to the first.
        ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
        tensor_slices_new_inner = []
        for j in range(ele_count):
            new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
            for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                           (j + 1) * tensor_strategy[dim_len - 1 - i]):
                new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
            tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor))
        tensor_slices_new = tensor_slices_new_inner
    # After merging every axis, a single full-size array remains.
    return Tensor(tensor_slices_new[0])
def _reshape_param_data_with_weight(param_data, dev_mat, field_size):
    """
    Combine param slice by the device matrix, used in model parallel scenario.

    Args:
        param_data (Tensor): The tensor to be reshaped and rearrangement,
            generated from all the device from AllGatherParamNet.
        dev_mat (list): The device matrix of devices.
        field_size (list): Leading shape used to reshape each gathered column
            (e.g. [39]); passed directly to numpy reshape.

    Returns:
        Tensor, the combined tensor which with the whole data value.

    Examples:
        >>> param_data = _allgather_param_net(param_data)
        >>> dev_mat = [2, 2]
        >>> field_size = [39]
        >>> tensor = _reshape_param_data_with_weight(param_data, dev_mat, field_size)
    """
    # Total device count is the product of the device-matrix extents.
    device_count = 1
    for dim in dev_mat:
        device_count *= dim
    # One slice per device, gathered and stacked along axis 0.
    tensor_slices = np.split(param_data.asnumpy(), device_count, axis=0)
    tensor_slices_col = []
    # Rebuild each column by concatenating the per-device pieces side by side.
    for i in range(len(tensor_slices[0][0])):
        tensor_slices_new = np.array(tensor_slices[0][:, i]).reshape(field_size, -1)
        for j in range(1, device_count):
            tensor_slices_new = np.concatenate((tensor_slices_new,\
                                               np.array(tensor_slices[j][:, i]).reshape(field_size, -1)), axis=1)
        tensor_slices_col.append(tensor_slices_new)
    # Flatten each rebuilt column and stack the columns back together.
    new_tensor = np.array(tensor_slices_col[0]).reshape(-1, 1)
    for i in range(1, len(tensor_slices_col)):
        new_tensor = np.concatenate((new_tensor, np.array(tensor_slices_col[i]).reshape(-1, 1)), axis=1)
    return Tensor(new_tensor)