- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Parallel configuration for parallel training.
- This is an experimental interface that is subject to change and/or deletion.
- """
- from mindspore._checkparam import Validator
- from mindspore import context
- import mindspore.communication.management as D
- from mindspore.context import ParallelMode
- from mindspore.parallel._utils import _get_parallel_mode
- from mindspore import log as logger
-
- __all__ = [
- "OpParallelConfig"
- ]
-
-
- class _Config:
- r""" A basic class of the configure"""
-
- def __str__(self):
- info = "[ParallelConfig]" + '\n'
- for k, v in self.__dict__.items():
- var_info = "{}:{}\n".format(k, v)
- info += var_info
- return info
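-
- # Illustration (a sketch, not part of the original module): __str__ emits one
- # "key:value" line per instance attribute, so a subclass whose fields are set
- # through its property setters prints e.g.
- #   [ParallelConfig]
- #   _data_parallel:1
- #   _model_parallel:1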
-
-
- class OpParallelConfig(_Config):
- r"""
- OpParallelConfig for setting data parallelism and model parallelism.
-
- Args:
- data_parallel (int): The degree of data parallelism. Default: 1.
- model_parallel (int): The degree of model parallelism. Default: 1.
-
- Supported Platforms:
- ``Ascend`` ``GPU``
-
- Examples:
- >>> from mindspore.parallel.nn import OpParallelConfig
- >>> config = OpParallelConfig(data_parallel=1, model_parallel=1)
- """
-
- def __init__(self, data_parallel=1, model_parallel=1):
- Validator.check_positive_int(data_parallel, "data_parallel")
- Validator.check_positive_int(model_parallel, "model_parallel")
- self.data_parallel = data_parallel
- self.model_parallel = model_parallel
-
- @property
- def data_parallel(self):
- return self._data_parallel
-
- @data_parallel.setter
- def data_parallel(self, value):
- Validator.check_positive_int(value, "data_parallel")
- self._data_parallel = value
-
- @property
- def model_parallel(self):
- return self._model_parallel
-
- @model_parallel.setter
- def model_parallel(self, value):
- Validator.check_positive_int(value, "model_parallel")
- self._model_parallel = value
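-
- # A minimal usage sketch (assumes a semi auto parallel setup with enough devices;
- # nothing here is specific to this module beyond OpParallelConfig itself):
- #   config = OpParallelConfig(data_parallel=4, model_parallel=2)
- #   config.model_parallel = 1  # setters re-validate: value must be a positive int
- #   print(config)              # dumps all fields via _Config.__str__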
-
-
- class _PipeLineConfig(_Config):
- r"""
- _PipeLineConfig for setting the pipeline stage number and the micro batch number.
-
- Args:
- pipeline_stage (int): The number of pipeline stages. Default: 1.
- micro_batch_num (int): The number of micro batches per step in pipeline parallel training. Default: 1.
-
- Supported Platforms:
- ``Ascend`` ``GPU``
-
- Examples:
- >>> config = _PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
- """
-
- def __init__(self, pipeline_stage=1, micro_batch_num=1):
- Validator.check_positive_int(pipeline_stage, "pipeline_stage")
- Validator.check_positive_int(micro_batch_num, "micro_batch_num")
- self.pipeline_stage = pipeline_stage
- self.micro_batch_num = micro_batch_num
-
- @property
- def pipeline_stage(self):
- return self._pipeline_stage
-
- @pipeline_stage.setter
- def pipeline_stage(self, value):
- Validator.check_positive_int(value, "pipeline_stage")
- self._pipeline_stage = value
- context.set_auto_parallel_context(pipeline_stages=value)
-
- @property
- def micro_batch_num(self):
- return self._micro_batch_num
-
- @micro_batch_num.setter
- def micro_batch_num(self, value):
- Validator.check_positive_int(value, "micro_batch_num")
- self._micro_batch_num = value
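-
- # Unlike OpParallelConfig, assigning pipeline_stage also writes the value into the
- # global auto parallel context. A sketch of the observable effect:
- #   pp_config = _PipeLineConfig(pipeline_stage=2, micro_batch_num=4)
- #   assert context.get_auto_parallel_context("pipeline_stages") == 2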
-
-
- # Default configuration, used in case the caller does not pass a config explicitly.
- default_dpmp_config = OpParallelConfig()
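-
- # Typical fallback pattern in a consumer (a sketch; `MyCell` is hypothetical):
- #   class MyCell:
- #       def __init__(self, parallel_config=default_dpmp_config):
- #           self.parallel_config = parallel_config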
-
-
- def _check_config(config):
- """
- Check that the config is consistent with the auto parallel context and that
- micro_batch_num >= pipeline_stage.
- """
- # the config's pipeline_stage must match pipeline_stages in the auto parallel context
- pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
- if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
- raise ValueError(
- f"The pipeline_stage {pipeline_stage} in auto_parallel_context is not equal to the "
- f"pipeline_stage {config.pipeline_stage} in the config.")
-
- # the checks below only apply in (semi) auto parallel mode
- is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
- if not is_auto_parallel:
- return
-
- device_num = D.get_group_size()
- optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")
-
- # dp * mp * pipeline_stage <= device_num
- if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
- raise ValueError(f"The product of the data parallel {config.data_parallel}, "
- f"model parallel {config.model_parallel} "
- f"and pipeline stages {pipeline_stage} "
- f"should be less than or equal to the device_num {device_num}.")
-
- # the config's optimizer_shard should match enable_parallel_optimizer in the auto parallel context
- if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
- logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
- f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
- f"optimizer_shard to make them consistent.")
-
- # micro_batch_num >= pipeline_stage
- if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
- and config.micro_batch_num < config.pipeline_stage:
- raise ValueError(
- f"The micro_batch_num {config.micro_batch_num} should be greater than or equal to "
- f"the pipeline_stage {config.pipeline_stage}.")