You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_utils.py 14 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils of auto parallel"""
  16. import numpy as np
  17. from mindspore import context, log as logger
  18. from mindspore.context import ParallelMode
  19. from mindspore._c_expression import reset_op_id
  20. from mindspore.common.tensor import Tensor
  21. from mindspore.common.dtype import dtype_to_nptype
  22. from mindspore.common import dtype as mstype
  23. from mindspore.communication.management import get_group_size, get_rank
  24. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  25. from mindspore.common.seed import get_seed
  26. def _get_parallel_mode():
  27. """Get parallel mode."""
  28. return auto_parallel_context().get_parallel_mode()
  29. def _is_in_auto_parallel_mode():
  30. return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]
  31. def _get_full_batch():
  32. """Get whether to use full_batch."""
  33. return auto_parallel_context().get_full_batch()
  34. def _get_pipeline_stages():
  35. """Get pipeline stages"""
  36. return auto_parallel_context().get_pipeline_stages()
  37. def _check_task_sink_envs():
  38. """
  39. Check whether task_sink environment variables have been exported or not.
  40. return True if task_sink environment variables have been exported, False otherwise.
  41. """
  42. import os
  43. task_sink = os.getenv("GRAPH_OP_RUN")
  44. if task_sink and task_sink.isdigit() and int(task_sink) == 1:
  45. return False
  46. return True
  47. def _check_full_batch():
  48. """
  49. full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
  50. Raises:
  51. RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel.
  52. """
  53. parallel_mode = _get_parallel_mode()
  54. full_batch = _get_full_batch()
  55. if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch):
  56. raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
  57. def _need_to_full():
  58. """Check whether to convert input to full shape or tensor."""
  59. if _get_parallel_mode() not in ("semi_auto_parallel", "auto_parallel"):
  60. return False
  61. dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
  62. if dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"):
  63. return True
  64. return not _get_full_batch()
  65. def _to_full_shapes(shapes, device_num):
  66. """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
  67. new_shapes = []
  68. dataset_strategy = ()
  69. if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
  70. dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
  71. if dataset_strategy:
  72. if len(shapes) != len(dataset_strategy):
  73. raise ValueError("The input shapes size {} is not equal to "
  74. "dataset strategy size {}".format(len(shapes), len(dataset_strategy)))
  75. for index, shape in enumerate(shapes):
  76. if len(shape) != len(dataset_strategy[index]):
  77. raise ValueError("The input shapes item size {} is not equal to "
  78. "dataset strategy item size {}".format(len(shape), len(dataset_strategy[index])))
  79. new_shape = ()
  80. for i, item in enumerate(shape):
  81. new_shape += (item * dataset_strategy[index][i],)
  82. new_shapes.append(new_shape)
  83. return new_shapes
  84. for shape in shapes:
  85. new_shape = ()
  86. for i, item in enumerate(shape):
  87. if i == 0:
  88. new_shape += (item * device_num,)
  89. else:
  90. new_shape += (item,)
  91. new_shapes.append(new_shape)
  92. return new_shapes
def _to_full_tensor(elem, global_device_num, global_rank, scaling_sens=None):
    """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
    from host solution.

    Args:
        elem: A Tensor/ndarray, or a tuple/list of them, one per network input.
        global_device_num (int): Total device count across all pipeline stages.
        global_rank (int): Global rank of this device.
        scaling_sens: Optional loss-scaling value appended as a float32 Tensor.

    Returns:
        tuple, the expanded tensors (plus the scaling_sens tensor when given).

    Raises:
        ValueError: If elements are not Tensors, or the inputs mismatch the
            configured dataset strategy.
    """
    lst = []
    # Devices per pipeline stage; the rank is reduced to its stage-local value.
    device_num = global_device_num // _get_pipeline_stages()
    stage_rank = global_rank % device_num
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    # NOTE(review): stage_rank = global_rank % device_num is always < device_num
    # for device_num > 0, so this guard looks unreachable — confirm intent.
    if stage_rank >= device_num:
        raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                         "the device num is {}".format(stage_rank, device_num))
    dataset_strategy = ()
    if context.get_auto_parallel_context("dataset_strategy") not in ("data_parallel", "full_batch"):
        dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
    if elem and dataset_strategy:
        if len(elem) != len(dataset_strategy):
            raise ValueError("The input size {} is not equal to "
                             "dataset strategy size {}".format(len(elem), len(dataset_strategy)))
    for index, data in enumerate(elem):
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        shape_ = data.shape
        type_ = data.dtype
        new_shape = ()
        if not dataset_strategy:
            # No explicit strategy: scale only the batch (first) dimension and
            # place this rank's batch slice at its offset in the full tensor.
            batchsize_per_device = 1
            for i, item in enumerate(shape_):
                if i == 0:
                    new_shape += (item * device_num,)
                    batchsize_per_device = item
                else:
                    new_shape += (item,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
            start = stage_rank * batchsize_per_device
            new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
        else:
            if len(shape_) != len(dataset_strategy[index]):
                raise ValueError("The input shapes item size {} is not equal to "
                                 "dataset strategy item size {}".format(len(shape_), len(dataset_strategy[index])))
            # Strategy given: scale every dimension by its strategy value and
            # compute this rank's slice along each dimension.
            slice_index = ()
            for i, item in enumerate(shape_):
                new_shape += (item * dataset_strategy[index][i],)
                start = (stage_rank % dataset_strategy[index][i]) * item
                end = (stage_rank % dataset_strategy[index][i] + 1) * item
                s = slice(start, end, 1)
                slice_index += (s,)
            new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
            new_tensor_numpy[slice_index] = data.asnumpy()
        new_tensor = Tensor(new_tensor_numpy)
        lst.append(new_tensor)
    if scaling_sens:
        lst.append(Tensor(scaling_sens, mstype.float32))
    return tuple(lst)
  149. def _get_gradients_mean():
  150. """Get if using gradients_mean."""
  151. return auto_parallel_context().get_gradients_mean()
  152. def _get_device_num():
  153. """Get the device num."""
  154. parallel_mode = auto_parallel_context().get_parallel_mode()
  155. if parallel_mode == "stand_alone":
  156. device_num = 1
  157. return device_num
  158. if auto_parallel_context().get_device_num_is_set() is False:
  159. device_num = get_group_size()
  160. else:
  161. device_num = auto_parallel_context().get_device_num()
  162. return device_num
  163. def _get_global_rank():
  164. """Get the global rank."""
  165. parallel_mode = auto_parallel_context().get_parallel_mode()
  166. if parallel_mode == "stand_alone":
  167. global_rank = 0
  168. return global_rank
  169. if auto_parallel_context().get_global_rank_is_set() is False:
  170. global_rank = get_rank()
  171. else:
  172. global_rank = auto_parallel_context().get_global_rank()
  173. return global_rank
  174. def _get_parameter_broadcast():
  175. """Get the parameter broadcast."""
  176. parallel_mode = auto_parallel_context().get_parallel_mode()
  177. parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
  178. if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
  179. logger.warning("You are suggested to use mindspore.context.set_auto_parallel_context(parameter_broadcast=True)"
  180. " or mindspore.common.set_seed() to share parameters among multi-devices.")
  181. return parameter_broadcast
  182. def _get_enable_parallel_optimizer():
  183. """Get if using parallel optimizer."""
  184. return auto_parallel_context().get_enable_parallel_optimizer()
  185. def _get_grad_accumulation_shard():
  186. """Get if using parallel shard."""
  187. return auto_parallel_context().get_grad_accumulation_shard()
  188. def _device_number_check(parallel_mode, device_number):
  189. """
  190. Check device num.
  191. Args:
  192. parallel_mode (str): The parallel mode.
  193. device_number (int): The device number.
  194. """
  195. if parallel_mode == "stand_alone" and device_number != 1:
  196. raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
  197. "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))
  198. def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
  199. """
  200. Check parameter broadcast.
  201. Note:
  202. If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Using the same
  203. random seed to make sure parameters on multiple devices are the same.
  204. Args:
  205. parallel_mode (str): The parallel mode.
  206. parameter_broadcast (bool): The parameter broadcast.
  207. Raises:
  208. ValueError: If parameter is broadcasted
  209. but the parallel mode is "stand_alone" or "semi_auto_parallel" or "auto_parallel").
  210. """
  211. if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
  212. raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
  213. "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
  214. .format(parallel_mode, parameter_broadcast))
  215. def _get_python_op(op_name, op_path, instance_name, arglist):
  216. """Get python operator."""
  217. module = __import__(op_path, fromlist=["None"])
  218. cls = getattr(module, op_name)
  219. if op_path != "mindspore.ops.functional":
  220. op = cls(*arglist)
  221. else:
  222. op = cls
  223. op.set_prim_instance_name(instance_name)
  224. return op
  225. def _reset_op_id():
  226. """Reset op id."""
  227. reset_op_id()
  228. def _parallel_predict_check():
  229. """validate parallel model prediction"""
  230. if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
  231. dataset_strategy = context.get_auto_parallel_context("dataset_strategy")
  232. is_shard_dataset_mp = (dataset_strategy and dataset_strategy not in ("data_parallel", "full_batch"))
  233. if not context.get_auto_parallel_context("full_batch") and not is_shard_dataset_mp:
  234. raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
  235. def _check_similar_layout(tensor_layout1, tensor_layout2):
  236. """check if two tensor layouts are same"""
  237. if tensor_layout1[1] != tensor_layout2[1]:
  238. return False
  239. for i in tensor_layout1[1]:
  240. if i == -1:
  241. continue
  242. if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
  243. return False
  244. return True
  245. def _check_same_layout(tensor_layout1, tensor_layout2):
  246. """check if two tensor layouts are same"""
  247. return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]
  248. def _remove_repeated_slices(tensor_layout):
  249. """generate unrepeated tensor layout"""
  250. import copy
  251. new_tensor_layout = copy.deepcopy(tensor_layout)
  252. dev_mat = tensor_layout[0][:]
  253. tensor_map = tensor_layout[1]
  254. for dim in range(len(dev_mat)):
  255. if dim not in tensor_map:
  256. dev_mat[-1-dim] = 1
  257. new_tensor_layout[0] = dev_mat
  258. return new_tensor_layout
def _infer_rank_list(train_map, predict_map=None):
    """Infer which checkpoint slices (source ranks) to load for each parameter.

    Args:
        train_map (dict): Maps parameter name to its training tensor layout
            (device matrix first, tensor map second).
        predict_map (dict): Optional map of parameter name to its prediction
            layout; when absent, every parameter gets its full rank list.

    Returns:
        dict, maps parameter name to (rank_list, bool). The bool is True when
        the layouts match closely enough — presumably meaning the slice can be
        loaded directly; confirm against the checkpoint-loading caller.
    """
    ret = {}
    # Ranks repeat across pipeline stages, so reduce to the stage-local rank.
    if _get_pipeline_stages() > 1:
        local_rank = int(_get_global_rank() % (_get_device_num() / _get_pipeline_stages()))
    else:
        local_rank = _get_global_rank()
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]
        dev_num = np.array(train_dev_mat).prod()
        # Collapse device-matrix dims that do not slice this parameter.
        new_train_layout = _remove_repeated_slices(train_layout)
        array = np.arange(dev_num).reshape(train_dev_mat)
        index = ()
        for i in new_train_layout[0]:
            if i == 1:
                # Repeated dimension: every rank along it holds the same data,
                # so picking index 0 is sufficient.
                index = index + (0,)
            else:
                index = index + (slice(None),)
        rank_list = array[index].flatten()
        if not predict_map:
            ret[param_name] = (rank_list, False)
            continue
        if param_name not in predict_map:
            logger.warning("predict_map does not contain %s", param_name)
            continue
        predict_layout = predict_map[param_name]
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            # Identical layouts: this rank's own slice is the one to load.
            ret[param_name] = ([local_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                ret[param_name] = ([rank_list[local_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:
            ret[param_name] = (rank_list, False)
    return ret