
_utils.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Utils of auto parallel"""
  16. from mindspore._c_expression import reset_op_id
  17. from mindspore.communication.management import get_group_size, get_rank
  18. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  19. def _get_parallel_mode():
  20. return auto_parallel_context().get_parallel_mode()
  21. def _get_mirror_mean():
  22. return auto_parallel_context().get_mirror_mean()


def _get_device_num():
    """Get the device num."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        device_num = 1
        return device_num

    if auto_parallel_context().get_device_num_is_set() is False:
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()
    return device_num
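
# A minimal usage sketch. `set_auto_parallel_context` is the public entry
# point that backs `auto_parallel_context()`; the values below are illustrative:
#
#     from mindspore import context
#     context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=8)
#     _get_device_num()   # -> 8; if device_num were not set, get_group_size() is used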


def _get_global_rank():
    """Get the global rank."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        global_rank = 0
        return global_rank

    if auto_parallel_context().get_global_rank_is_set() is False:
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()
    return global_rank


def _get_parameter_broadcast():
    """Get the parameter broadcast."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        parameter_broadcast = False
        return parameter_broadcast

    if auto_parallel_context().get_parameter_broadcast_is_set() is True:
        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
        parameter_broadcast = True
    else:
        parameter_broadcast = False
    return parameter_broadcast
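
# Summary of how _get_parameter_broadcast() resolves, derived from the code above:
#
#     parallel_mode                        user setting   result
#     -----------------------------------  -------------  -----------------
#     stand_alone                          ignored        False
#     any other mode                       set            the user's value
#     data_parallel / hybrid_parallel      not set        True
#     semi_auto_parallel / auto_parallel   not set        False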


def _device_number_check(parallel_mode, device_number):
    """
    Check device num.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode: {1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check parameter broadcast.

    Note:
        If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Use the
        same random seed to make sure parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast.

    Raises:
        ValueError: If parameter_broadcast is true but the parallel mode is "stand_alone",
            "semi_auto_parallel" or "auto_parallel".
    """
    if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast: {1}"
                         .format(parallel_mode, parameter_broadcast))
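
# Sketch of the cases the check above distinguishes:
#
#     _parameter_broadcast_check("auto_parallel", True)   # raises ValueError
#     _parameter_broadcast_check("data_parallel", True)   # passes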


PARAMETER_CLONED_INDEX = 0


class _CloneInfo:
    """
    The clone info of parameter.

    Attributes:
        be_cloned (bool): Whether the parameter has been cloned by others.
        cloned (bool): Whether the parameter is cloned from another parameter.
        be_cloned_index (list): If the parameter is cloned, one index is recorded per clone.
        cloned_index (int): If the parameter is cloned from another parameter, it has a unique index.
    """
    def __init__(self):
        self.be_cloned = False
        self.cloned = False
        self.be_cloned_index = []
        self.cloned_index = None


def _set_clone_info(clone_from, clone_to):
    """
    Set the clone info.

    Args:
        clone_from (_CloneInfo): The clone info of the parameter being cloned.
        clone_to (_CloneInfo): The clone info of the cloned parameter.
    """
    global PARAMETER_CLONED_INDEX
    clone_to.be_cloned = False
    clone_to.cloned = True
    clone_to.be_cloned_index = []
    clone_to.cloned_index = PARAMETER_CLONED_INDEX

    clone_from.be_cloned = True
    clone_from.be_cloned_index.append(PARAMETER_CLONED_INDEX)

    PARAMETER_CLONED_INDEX += 1
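
# Sketch of the bookkeeping, assuming no clones have been recorded yet in this
# process (so PARAMETER_CLONED_INDEX starts at 0):
#
#     origin, copy = _CloneInfo(), _CloneInfo()
#     _set_clone_info(origin, copy)
#     origin.be_cloned, origin.be_cloned_index   # True, [0]
#     copy.cloned, copy.cloned_index             # True, 0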


def _get_python_op(op_name, op_path, instance_name, arglist):
    """Get python operator."""
    module = __import__(op_path, fromlist=["None"])
    cls = getattr(module, op_name)
    op = cls(*arglist)
    op.set_prim_instance_name(instance_name)
    return op
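
# Sketch of dynamic instantiation; the operator name and module path below are
# illustrative (TensorAdd was exported from mindspore.ops.operations at the time):
#
#     add = _get_python_op("TensorAdd", "mindspore.ops.operations", "add_0", ())
#     # equivalent to: op = TensorAdd(); op.set_prim_instance_name("add_0")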


def _reset_op_id():
    """Reset op id."""
    reset_op_id()