# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """Utils of auto parallel"""
-
- from mindspore._c_expression import reset_op_id
- from mindspore.communication.management import get_group_size, get_rank
- from mindspore.parallel._auto_parallel_context import auto_parallel_context


def _get_parallel_mode():
    """Get the current parallel mode from the auto parallel context."""
    return auto_parallel_context().get_parallel_mode()


def _get_mirror_mean():
    """Get the mirror mean flag, which indicates whether gradients are averaged across devices."""
    return auto_parallel_context().get_mirror_mean()


def _get_device_num():
    """Get the device number."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        # Stand-alone execution always uses a single device.
        device_num = 1
        return device_num

    if not auto_parallel_context().get_device_num_is_set():
        # Fall back to the size of the communication group.
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()
    return device_num
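
# Illustrative usage sketch (kept as comments so importing this module stays
# side-effect free; assumes a fresh auto parallel context in a launched job):
# an explicitly configured device number takes precedence over the
# communication group size, and stand-alone mode always reports 1.
#
#     >>> auto_parallel_context().set_parallel_mode("data_parallel")
#     >>> auto_parallel_context().set_device_num(8)
#     >>> _get_device_num()          # explicit setting wins over get_group_size()
#     8
#     >>> auto_parallel_context().set_parallel_mode("stand_alone")
#     >>> _get_device_num()          # stand-alone ignores the setting
#     1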


def _get_global_rank():
    """Get the global rank."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        # A stand-alone process is always rank 0.
        global_rank = 0
        return global_rank

    if not auto_parallel_context().get_global_rank_is_set():
        # Fall back to the rank reported by the communication framework.
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()
    return global_rank
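
# Illustrative usage sketch (comments only; assumes a fresh context): the
# global rank follows the same resolution order as the device number above.
#
#     >>> auto_parallel_context().set_parallel_mode("data_parallel")
#     >>> auto_parallel_context().set_global_rank(3)
#     >>> _get_global_rank()         # explicit setting wins over get_rank()
#     3
#     >>> auto_parallel_context().set_parallel_mode("stand_alone")
#     >>> _get_global_rank()         # stand-alone is always rank 0
#     0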


def _get_parameter_broadcast():
    """Get the parameter broadcast flag."""
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        parameter_broadcast = False
        return parameter_broadcast

    if auto_parallel_context().get_parameter_broadcast_is_set():
        parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
    elif parallel_mode in ("data_parallel", "hybrid_parallel"):
        # Broadcast parameters by default in these modes.
        parameter_broadcast = True
    else:
        parameter_broadcast = False

    return parameter_broadcast
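
# Illustrative sketch of the defaulting rules (assuming parameter broadcast
# was not set explicitly): data_parallel and hybrid_parallel default to True,
# every other mode, including auto_parallel, defaults to False.
#
#     >>> auto_parallel_context().set_parallel_mode("data_parallel")
#     >>> _get_parameter_broadcast()
#     True
#     >>> auto_parallel_context().set_parallel_mode("auto_parallel")
#     >>> _get_parameter_broadcast()
#     False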


def _device_number_check(parallel_mode, device_number):
    """
    Check the device number.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.

    Raises:
        ValueError: If `parallel_mode` is "stand_alone" but `device_number` is not 1.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode: {1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check the parameter broadcast setting.

    Note:
        Parameter broadcast is not supported when the parallel mode is semi_auto_parallel or auto_parallel.
        Use the same random seed on every device to make sure parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast flag.

    Raises:
        ValueError: If `parameter_broadcast` is True but `parallel_mode` is
            "stand_alone", "semi_auto_parallel" or "auto_parallel".
    """
    if parameter_broadcast and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast: {1}"
                         .format(parallel_mode, parameter_broadcast))
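
# Illustrative sketch: broadcast is only legal in data_parallel and
# hybrid_parallel; requesting it in stand-alone or the (semi_)auto_parallel
# modes raises immediately.
#
#     >>> _parameter_broadcast_check("data_parallel", True)      # ok
#     >>> _parameter_broadcast_check("auto_parallel", False)     # ok
#     >>> _parameter_broadcast_check("auto_parallel", True)      # ValueError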


# Module-level counter that assigns a unique index to every clone operation.
PARAMETER_CLONED_INDEX = 0


class _CloneInfo:
    """
    The clone info of a parameter.

    Attributes:
        be_cloned (bool): Whether the parameter has been cloned by others.
        cloned (bool): Whether the parameter is cloned from another parameter.
        be_cloned_index (list): If the parameter is cloned, one index is recorded per clone.
        cloned_index (int): If the parameter is cloned from another parameter, it has a unique index.
    """
    def __init__(self):
        self.be_cloned = False
        self.cloned = False
        self.be_cloned_index = []
        self.cloned_index = None


def _set_clone_info(clone_from, clone_to):
    """
    Set the clone info.

    Args:
        clone_from (_CloneInfo): The clone info of the source parameter being cloned.
        clone_to (_CloneInfo): The clone info of the newly cloned parameter.
    """
    global PARAMETER_CLONED_INDEX
    clone_to.be_cloned = False
    clone_to.cloned = True
    clone_to.be_cloned_index = []
    clone_to.cloned_index = PARAMETER_CLONED_INDEX

    clone_from.be_cloned = True
    clone_from.be_cloned_index.append(PARAMETER_CLONED_INDEX)

    PARAMETER_CLONED_INDEX += 1
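
# Illustrative sketch (assuming the module-level counter is still 0): cloning
# a parameter twice records one global index per clone on the source and a
# unique index on each clone.
#
#     >>> src, c1, c2 = _CloneInfo(), _CloneInfo(), _CloneInfo()
#     >>> _set_clone_info(src, c1)
#     >>> _set_clone_info(src, c2)
#     >>> src.be_cloned, src.be_cloned_index
#     (True, [0, 1])
#     >>> c1.cloned_index, c2.cloned_index
#     (0, 1)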


def _get_python_op(op_name, op_path, instance_name, arglist):
    """Instantiate a Python operator by name from the given module path."""
    # A non-empty fromlist makes __import__ return the leaf module of
    # op_path instead of the top-level package.
    module = __import__(op_path, fromlist=["None"])
    cls = getattr(module, op_name)
    op = cls(*arglist)
    op.set_prim_instance_name(instance_name)
    return op
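
# Illustrative sketch (the operator path and arguments below are assumptions
# for exposition, not fixed by this module): building an AllReduce primitive
# from its import path and tagging it with an instance name.
#
#     >>> allreduce = _get_python_op("AllReduce", "mindspore.ops.operations",
#     ...                            "grad_mirror", ("sum",))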


def _reset_op_id():
    """Reset op id."""
    reset_op_id()