
_utils.py 5.8 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utils of auto parallel"""
from mindspore._c_expression import reset_op_id
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context


def _get_parallel_mode():
    """Get the parallel mode."""
    return auto_parallel_context().get_parallel_mode()


def _get_full_batch():
    """Get whether to use full_batch."""
    return auto_parallel_context().get_full_batch()


def _need_to_full():
    """Check whether the input needs to be converted to a full-shape tensor."""
    parallel_mode = _get_parallel_mode()
    full_batch = _get_full_batch()
    need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel"))
            and (not full_batch))
    return need


def _get_mirror_mean():
    """Get whether mirror_mean is used."""
    return auto_parallel_context().get_mirror_mean()


def _get_device_num():
    """Get the device number."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        device_num = 1
        return device_num

    if auto_parallel_context().get_device_num_is_set() is False:
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()

    return device_num


def _get_global_rank():
    """Get the global rank."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        global_rank = 0
        return global_rank

    if auto_parallel_context().get_global_rank_is_set() is False:
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()

    return global_rank


def _get_parameter_broadcast():
    """Get the parameter broadcast setting."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        parameter_broadcast = False
        return parameter_broadcast

    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
        parameter_broadcast = True
    else:
        parameter_broadcast = False

    return parameter_broadcast


def _device_number_check(parallel_mode, device_number):
    """
    Check the device number.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check parameter broadcast.

    Note:
        If the parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Use the
        same random seed to make sure the parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast setting.

    Raises:
        ValueError: If the parameter is broadcast but the parallel mode is "stand_alone",
            "semi_auto_parallel" or "auto_parallel".
    """
    if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
                         .format(parallel_mode, parameter_broadcast))


PARAMETER_CLONED_INDEX = 0


class _CloneInfo():
    """
    The clone info of a parameter.

    Attributes:
        be_cloned (bool): Whether the parameter has been cloned.
        cloned (bool): Whether the parameter is cloned from another parameter.
        be_cloned_index (list): If the parameter has been cloned, one index is recorded per clone.
        cloned_index (int): If the parameter is cloned from another parameter, it has a unique index.
    """
    def __init__(self):
        self.be_cloned = False
        self.cloned = False
        self.be_cloned_index = []
        self.cloned_index = None


def _set_clone_info(clone_from, clone_to):
    """
    Set the clone info.

    Args:
        clone_from (_CloneInfo): The clone info of the parameter being cloned.
        clone_to (_CloneInfo): The clone info of the cloned parameter.
    """
    global PARAMETER_CLONED_INDEX
    clone_to.be_cloned = False
    clone_to.cloned = True
    clone_to.be_cloned_index = []
    clone_to.cloned_index = PARAMETER_CLONED_INDEX

    clone_from.be_cloned = True
    clone_from.be_cloned_index.append(PARAMETER_CLONED_INDEX)

    PARAMETER_CLONED_INDEX = PARAMETER_CLONED_INDEX + 1


def _get_python_op(op_name, op_path, instance_name, arglist):
    """Get a Python operator instance by importing its class dynamically."""
    module = __import__(op_path, fromlist=["None"])
    cls = getattr(module, op_name)
    op = cls(*arglist)
    op.set_prim_instance_name(instance_name)
    return op


def _reset_op_id():
    """Reset the op id."""
    reset_op_id()