You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

_utils.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Utils of auto parallel"""
  16. import numpy as np
  17. from mindspore import log as logger
  18. from mindspore._c_expression import reset_op_id
  19. from mindspore.common.tensor import Tensor
  20. from mindspore.common.dtype import dtype_to_nptype
  21. from mindspore.common import dtype as mstype
  22. from mindspore.communication.management import get_group_size, get_rank
  23. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  def _get_parallel_mode():
      """Return the parallel mode currently configured in the auto-parallel context."""
      ctx = auto_parallel_context()
      return ctx.get_parallel_mode()
  def _get_full_batch():
      """Return the full_batch setting stored in the auto-parallel context."""
      ctx = auto_parallel_context()
      return ctx.get_full_batch()
  def _check_full_batch():
      """
      Validate that full_batch is only enabled in a parallel mode that supports it.

      full_batch may only be used under semi_auto_parallel or auto_parallel.

      Raises:
          RuntimeError: If full_batch is enabled while the parallel mode is neither
              semi_auto_parallel nor auto_parallel.
      """
      mode = _get_parallel_mode()
      full_batch_enabled = _get_full_batch()
      mode_supports_full_batch = mode in ("semi_auto_parallel", "auto_parallel")
      if full_batch_enabled and not mode_supports_full_batch:
          raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")
  def _need_to_full():
      """Return whether inputs must be converted to full shapes / full tensors.

      This is True only in semi_auto_parallel or auto_parallel mode with full_batch off.
      """
      mode = _get_parallel_mode()
      full_batch = _get_full_batch()
      if full_batch:
          return False
      return mode in ("semi_auto_parallel", "auto_parallel")
  def _to_full_shapes(shapes, device_num):
      """Scale the leading (batch) dimension of every shape by ``device_num``.

      Used to adapt per-device shapes to the MindData graph solution.

      Args:
          shapes: Iterable of shapes (tuples or lists of ints).
          device_num (int): Factor applied to the first dimension of each shape.

      Returns:
          list[tuple]: One tuple per input shape. The first entry is multiplied by
          ``device_num``; all other entries are copied unchanged. An empty shape stays
          ``()``.
      """
      def _expand(shape):
          # Only axis 0 is scaled; every other axis passes through.
          return tuple(dim * device_num if axis == 0 else dim for axis, dim in enumerate(shape))

      return [_expand(shape) for shape in shapes]
  def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
      """Convert host data to full-batch Tensors when feeding data from the host.

      Each input is copied into a zero-filled array whose first (batch) dimension is
      ``device_num`` times the input's. The data is written into the slice
      ``[global_rank * batch, (global_rank + 1) * batch)``; every other row stays zero.

      Args:
          elem: A numpy array or Tensor, or a tuple/list of them.
          device_num (int): Number of devices; the batch dimension is multiplied by it.
          global_rank (int): Rank whose batch slice receives the data.
          scaling_sens: Optional value appended, as a float32 Tensor, after the data
              tensors. Only ``None`` means "not given".

      Returns:
          tuple: One expanded Tensor per input element, followed by the scaling_sens
          Tensor when scaling_sens is not None.

      Raises:
          ValueError: If global_rank is negative or not smaller than device_num, or if
              an element is neither a numpy array nor a Tensor.
      """
      if not isinstance(elem, (tuple, list)):
          elem = [elem]
      if global_rank >= device_num:
          raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                           "the device num is {}".format(global_rank, device_num))
      if global_rank < 0:
          # A negative rank gives a negative slice start, which numpy counts from the
          # end of the batch axis. The data would then land in the wrong rows or fail
          # to broadcast.
          raise ValueError("The global rank must be non-negative, the global rank is {}".format(global_rank))
      lst = []
      for data in elem:
          if isinstance(data, np.ndarray):
              data = Tensor(data)
          if not isinstance(data, Tensor):
              raise ValueError("elements in tensors must be Tensor")
          shape_ = data.shape
          # Only the leading (batch) axis is scaled, exactly as in _to_full_shapes.
          new_shape = _to_full_shapes([shape_], device_num)[0]
          batchsize_per_device = shape_[0] if shape_ else 1
          new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(data.dtype))
          start = global_rank * batchsize_per_device
          new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
          lst.append(Tensor(new_tensor_numpy))
      if scaling_sens is not None:
          # Use an explicit None check rather than truthiness: a scaling value of 0 must
          # still be appended, and an array value could otherwise raise on bool().
          lst.append(Tensor(scaling_sens, mstype.float32))
      return tuple(lst)
  def _get_gradients_mean():
      """Return the gradients_mean flag stored in the auto-parallel context."""
      ctx = auto_parallel_context()
      return ctx.get_gradients_mean()
  def _get_device_num():
      """Return the number of devices for the current parallel configuration.

      Stand-alone mode always reports 1. In any other mode, the device number set in the
      auto-parallel context is returned. If no device number has been set (the
      "is set" flag is False), the value comes from get_group_size() instead.
      """
      ctx = auto_parallel_context()
      if ctx.get_parallel_mode() == "stand_alone":
          return 1
      if ctx.get_device_num_is_set() is not False:
          return ctx.get_device_num()
      return get_group_size()
  def _get_global_rank():
      """Return the global rank of this process.

      Stand-alone mode always reports rank 0. In any other mode, the rank set in the
      auto-parallel context is returned. If no rank has been set (the "is set" flag is
      False), the value comes from get_rank() instead.
      """
      ctx = auto_parallel_context()
      if ctx.get_parallel_mode() == "stand_alone":
          return 0
      if ctx.get_global_rank_is_set() is not False:
          return ctx.get_global_rank()
      return get_rank()
  def _get_parameter_broadcast():
      """Return the parameter_broadcast setting from the auto-parallel context.

      In data_parallel or hybrid_parallel mode with broadcast explicitly False, a warning
      is logged recommending mindspore.common.set_seed() instead.
      """
      ctx = auto_parallel_context()
      mode = ctx.get_parallel_mode()
      broadcast = ctx.get_parameter_broadcast()
      should_hint_seed = broadcast is False and mode in ("data_parallel", "hybrid_parallel")
      if should_hint_seed:
          logger.warning("You are suggested to use mindspore.common.set_seed() to share parameters among devices.")
      return broadcast
  def _device_number_check(parallel_mode, device_number):
      """
      Validate that the device number is compatible with the parallel mode.

      Args:
          parallel_mode (str): The parallel mode.
          device_number (int): The device number.

      Raises:
          ValueError: If parallel_mode is "stand_alone" but device_number is not 1.
      """
      if parallel_mode != "stand_alone" or device_number == 1:
          return
      raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                       "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))
  def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
      """
      Reject parameter broadcast in parallel modes that do not support it.

      Note:
          stand_alone, semi_auto_parallel and auto_parallel do not support parameter
          broadcast. Under semi_auto_parallel or auto_parallel, use the same random seed
          on every device to keep parameters identical across devices.

      Args:
          parallel_mode (str): The parallel mode.
          parameter_broadcast (bool): Whether parameters are broadcast.

      Raises:
          ValueError: If parameter_broadcast is True while parallel_mode is
              "stand_alone", "semi_auto_parallel" or "auto_parallel".
      """
      unsupported_modes = ("stand_alone", "semi_auto_parallel", "auto_parallel")
      if parameter_broadcast is not True or parallel_mode not in unsupported_modes:
          return
      raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                       "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
                       .format(parallel_mode, parameter_broadcast))
  def _get_python_op(op_name, op_path, instance_name, arglist):
      """Instantiate a Python operator class located by module path and class name.

      Args:
          op_name (str): Name of the operator class inside the module.
          op_path (str): Absolute dotted path of the module to import.
          instance_name (str): Value passed to the new operator's
              ``set_prim_instance_name``.
          arglist: Positional arguments forwarded to the operator's constructor.

      Returns:
          The constructed operator instance.
      """
      import importlib
      # import_module returns the leaf module of a dotted path directly, so no
      # __import__ fromlist workaround is needed to get past the top-level package.
      module = importlib.import_module(op_path)
      cls = getattr(module, op_name)
      op = cls(*arglist)
      op.set_prim_instance_name(instance_name)
      return op
  def _reset_op_id():
      """Reset operator ids by delegating to the backend ``reset_op_id`` binding.

      Any value returned by the backend call is discarded.
      """
      reset_op_id()
      return None