You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

_utils.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils of auto parallel"""
  16. import numpy as np
  17. from mindspore import log as logger
  18. from mindspore._c_expression import reset_op_id
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.dtype import dtype_to_nptype
  21. from mindspore.common import dtype as mstype
  22. from mindspore.communication.management import get_group_size, get_rank
  23. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  24. from mindspore.common.seed import get_seed
  25. def _get_parallel_mode():
  26. """Get parallel mode."""
  27. return auto_parallel_context().get_parallel_mode()
  28. def _get_full_batch():
  29. """Get whether to use full_batch."""
  30. return auto_parallel_context().get_full_batch()
  31. def _check_full_batch():
  32. """
  33. full_batch could only be used under semi_auto_parallel or auto_parallel, check it.
  34. Raises:
  35. RuntimeError: Using full_batch under neither semi_auto_parallel nor auto_parallel.
  36. """
  37. parallel_mode = _get_parallel_mode()
  38. full_batch = _get_full_batch()
  39. if ((parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch):
  40. raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
  41. def _need_to_full():
  42. """Check whether to convert input to full shape or tensor."""
  43. parallel_mode = _get_parallel_mode()
  44. full_batch = _get_full_batch()
  45. need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel"))
  46. and (not full_batch))
  47. return need
  48. def _to_full_shapes(shapes, device_num):
  49. """Expanding batch dimension according to device_num, adapt to mindspore minddata graph solution."""
  50. new_shapes = []
  51. for shape in shapes:
  52. new_shape = ()
  53. for i, item in enumerate(shape):
  54. if i == 0:
  55. new_shape += (item * device_num,)
  56. else:
  57. new_shape += (item,)
  58. new_shapes.append(new_shape)
  59. return new_shapes
  60. def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
  61. """Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
  62. from host solution."""
  63. lst = []
  64. if not isinstance(elem, (tuple, list)):
  65. elem = [elem]
  66. if global_rank >= device_num:
  67. raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
  68. "the device num is {}".format(global_rank, device_num))
  69. for data in elem:
  70. if isinstance(data, np.ndarray):
  71. data = Tensor(data)
  72. if not isinstance(data, Tensor):
  73. raise ValueError("elements in tensors must be Tensor")
  74. shape_ = data.shape
  75. type_ = data.dtype
  76. new_shape = ()
  77. batchsize_per_device = 1
  78. for i, item in enumerate(shape_):
  79. if i == 0:
  80. new_shape += (item * device_num,)
  81. batchsize_per_device = item
  82. else:
  83. new_shape += (item,)
  84. new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
  85. start = global_rank * batchsize_per_device
  86. new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
  87. new_tensor = Tensor(new_tensor_numpy)
  88. lst.append(new_tensor)
  89. if scaling_sens:
  90. lst.append(Tensor(scaling_sens, mstype.float32))
  91. return tuple(lst)
  92. def _get_gradients_mean():
  93. """Get if using gradients_mean."""
  94. return auto_parallel_context().get_gradients_mean()
  95. def _get_device_num():
  96. """Get the device num."""
  97. parallel_mode = auto_parallel_context().get_parallel_mode()
  98. if parallel_mode == "stand_alone":
  99. device_num = 1
  100. return device_num
  101. if auto_parallel_context().get_device_num_is_set() is False:
  102. device_num = get_group_size()
  103. else:
  104. device_num = auto_parallel_context().get_device_num()
  105. return device_num
  106. def _get_global_rank():
  107. """Get the global rank."""
  108. parallel_mode = auto_parallel_context().get_parallel_mode()
  109. if parallel_mode == "stand_alone":
  110. global_rank = 0
  111. return global_rank
  112. if auto_parallel_context().get_global_rank_is_set() is False:
  113. global_rank = get_rank()
  114. else:
  115. global_rank = auto_parallel_context().get_global_rank()
  116. return global_rank
  117. def _get_parameter_broadcast():
  118. """Get the parameter broadcast."""
  119. parallel_mode = auto_parallel_context().get_parallel_mode()
  120. parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
  121. if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed is None:
  122. logger.warning("You are suggested to use mindspore.common.set_seed() to share"
  123. " parameters among devices.")
  124. return parameter_broadcast
  125. def _device_number_check(parallel_mode, device_number):
  126. """
  127. Check device num.
  128. Args:
  129. parallel_mode (str): The parallel mode.
  130. device_number (int): The device number.
  131. """
  132. if parallel_mode == "stand_alone" and device_number != 1:
  133. raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
  134. "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))
  135. def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
  136. """
  137. Check parameter broadcast.
  138. Note:
  139. If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Using the same
  140. random seed to make sure parameters on multiple devices are the same.
  141. Args:
  142. parallel_mode (str): The parallel mode.
  143. parameter_broadcast (bool): The parameter broadcast.
  144. Raises:
  145. ValueError: If parameter is broadcasted
  146. but the parallel mode is "stand_alone" or "semi_auto_parallel" or "auto_parallel").
  147. """
  148. if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
  149. raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
  150. "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
  151. .format(parallel_mode, parameter_broadcast))
  152. def _get_python_op(op_name, op_path, instance_name, arglist):
  153. """Get python operator."""
  154. module = __import__(op_path, fromlist=["None"])
  155. cls = getattr(module, op_name)
  156. op = cls(*arglist)
  157. op.set_prim_instance_name(instance_name)
  158. return op
def _reset_op_id():
    """Reset op id.

    Thin wrapper that delegates to ``reset_op_id`` from
    ``mindspore._c_expression`` (the C++ backend binding).
    """
    reset_op_id()