You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

op_parallel_config.py 5.7 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Parallel Config for the Parallel Training
  17. This is an experimental interface that is subject to change and/or deletion.
  18. """
  19. from mindspore._checkparam import Validator
  20. from mindspore import context
  21. import mindspore.communication.management as D
  22. from mindspore.context import ParallelMode
  23. from mindspore.parallel._utils import _get_parallel_mode
  24. from mindspore import log as logger
  25. __all__ = [
  26. "OpParallelConfig"
  27. ]
  28. class _Config:
  29. r""" A basic class of the configure"""
  30. def __str__(self):
  31. info = "[ParallelConfig]" + '\n'
  32. for k, v in self.__dict__.items():
  33. var_info = "{}:{}\n".format(k, v)
  34. info += var_info
  35. return info
  36. class OpParallelConfig(_Config):
  37. r"""
  38. OpParallelConfig for the setting data parallel and model parallel.
  39. Args:
  40. data_parallel (int): The data parallel way. Default: 1
  41. model_parallel (int): The model parallel way. Default: 1
  42. Supported Platforms:
  43. ``Ascend`` ``GPU``
  44. Examples:
  45. >>> from mindspore.parallel.nn import OpParallelConfig
  46. >>> config=OpParallelConfig(data_parallel=1, model_parallel=1)
  47. """
  48. def __init__(self, data_parallel=1, model_parallel=1):
  49. Validator.check_positive_int(data_parallel, "data_parallel")
  50. Validator.check_positive_int(model_parallel, "model_parallel")
  51. self.data_parallel = data_parallel
  52. self.model_parallel = model_parallel
  53. @property
  54. def data_parallel(self):
  55. return self._data_parallel
  56. @data_parallel.setter
  57. def data_parallel(self, value):
  58. Validator.check_positive_int(value, "data_parallel")
  59. self._data_parallel = value
  60. @property
  61. def model_parallel(self):
  62. return self._model_parallel
  63. @model_parallel.setter
  64. def model_parallel(self, value):
  65. Validator.check_positive_int(value, "model_parallel")
  66. self._model_parallel = value
  67. class _PipeLineConfig(_Config):
  68. r"""
  69. PPConfig for the setting data parallel, model parallel
  70. Args:
  71. pipeline_stage (int): The number of the pipeline stages. Default: 1
  72. micro_batch_num (int): The model parallel way. Default: 1
  73. Supported Platforms:
  74. ``Ascend`` ``GPU``
  75. Examples:
  76. >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
  77. """
  78. def __init__(self, pipeline_stage=1, micro_batch_num=1):
  79. Validator.check_positive_int(pipeline_stage, "pipeline_stage")
  80. Validator.check_positive_int(micro_batch_num, "micro_batch_num")
  81. self.pipeline_stage = pipeline_stage
  82. self.micro_batch_num = micro_batch_num
  83. @property
  84. def pipeline_stage(self):
  85. return self._pipeline_stage
  86. @pipeline_stage.setter
  87. def pipeline_stage(self, value):
  88. Validator.check_positive_int(value, "pipeline_stage")
  89. self._pipeline_stage = value
  90. context.set_auto_parallel_context(pipeline_stages=value)
  91. @property
  92. def micro_batch_num(self):
  93. return self._micro_batch_num
  94. @micro_batch_num.setter
  95. def micro_batch_num(self, value):
  96. Validator.check_positive_int(value, "micro_batch_num")
  97. self._micro_batch_num = value
  98. # In case the user doesn't pass a config as args.
  99. default_dpmp_config = OpParallelConfig()
  100. def _check_config(config):
  101. """
  102. Check if micro_batch_num >= pipeline_stage
  103. """
  104. # the config pipeline_stage is same with context.pipeline_stage
  105. pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
  106. if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
  107. raise ValueError(
  108. f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
  109. f"{config.pipeline_stage}"
  110. f" in the config.")
  111. # make sure the following is in auto parallel mode
  112. is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
  113. if not is_auto_parallel:
  114. return
  115. device_num = D.get_group_size()
  116. optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")
  117. if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
  118. raise ValueError(f"The product of the data parallel {config.data_parallel}, "
  119. f"model parallel {config.model_parallel} "
  120. f"pipeline stages {pipeline_stage} "
  121. f"should be less than device_num {device_num}.")
  122. # the config optimizer_shard is same with context.optimizer_shard
  123. if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
  124. logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
  125. f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
  126. f"optimizer_shard to make them consistent.")