# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utilities for auto parallel."""
import numpy as np
from mindspore import context, log as logger
from mindspore.context import ParallelMode
from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
from mindspore.common import dtype as mstype
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.common.seed import get_seed


def _get_parallel_mode():
    """Get the parallel mode from the auto parallel context."""
    return auto_parallel_context().get_parallel_mode()


def _get_full_batch():
    """Get whether to use full_batch, as stored in the auto parallel context."""
    return auto_parallel_context().get_full_batch()


def _get_pipeline_stages():
    """Get the number of pipeline stages from the auto parallel context."""
    return auto_parallel_context().get_pipeline_stages()


def _check_full_batch():
    """
    Check that full_batch is only enabled under semi_auto_parallel or auto_parallel.

    Raises:
        RuntimeError: If full_batch is enabled under neither semi_auto_parallel nor auto_parallel.
    """
    parallel_mode = _get_parallel_mode()
    full_batch = _get_full_batch()
    if (parallel_mode not in ("semi_auto_parallel", "auto_parallel")) and full_batch:
        raise RuntimeError("full_batch could only be used under semi_auto_parallel or auto_parallel.")


def _need_to_full():
    """
    Check whether inputs must be converted to full shape or full tensor.

    Returns:
        bool, True when running under semi_auto_parallel or auto_parallel and full_batch is disabled.
    """
    parallel_mode = _get_parallel_mode()
    full_batch = _get_full_batch()
    need = ((parallel_mode in ("semi_auto_parallel", "auto_parallel"))
            and (not full_batch))
    return need


def _to_full_shapes(shapes, device_num):
    """
    Expand the batch dimension according to device_num, adapting to the minddata graph solution.

    Args:
        shapes (iterable): Per-input shapes (each an iterable of ints).
        device_num (int): Factor applied to the first (batch) dimension of every shape.

    Returns:
        list[tuple], the new shapes; the first dimension is multiplied by device_num and the
        remaining dimensions are unchanged. An empty shape stays empty.
    """
    new_shapes = []
    for shape in shapes:
        new_shape = ()
        for i, item in enumerate(shape):
            if i == 0:
                new_shape += (item * device_num,)
            else:
                new_shape += (item,)
        new_shapes.append(new_shape)
    return new_shapes


def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
    """
    Convert numpy to tensor, expanding the batch dimension according to device_num.

    This adapts to the solution that feeds data from the host. Each input is copied into a
    zero-filled array whose first dimension is device_num times larger. The local data occupies
    rows [global_rank * batch, (global_rank + 1) * batch); all other rows stay zero.

    Args:
        elem: A Tensor / numpy array, or a tuple/list of them.
        device_num (int): Number of devices the batch dimension is expanded by.
        global_rank (int): Rank of the current device; selects where the local data is written.
        scaling_sens: Optional value appended as an extra float32 Tensor when it is truthy.

    Returns:
        tuple, the expanded tensors (followed by the scaling_sens tensor if appended).

    Raises:
        ValueError: If global_rank >= device_num, or if an element is neither a numpy array nor a Tensor.
    """
    lst = []
    if not isinstance(elem, (tuple, list)):
        elem = [elem]
    if global_rank >= device_num:
        raise ValueError("The global rank must be smaller than device number, the global rank is {}, "
                         "the device num is {}".format(global_rank, device_num))

    for data in elem:
        if isinstance(data, np.ndarray):
            data = Tensor(data)
        if not isinstance(data, Tensor):
            raise ValueError("elements in tensors must be Tensor")
        shape_ = data.shape
        type_ = data.dtype
        new_shape = ()
        batchsize_per_device = 1
        for i, item in enumerate(shape_):
            if i == 0:
                new_shape += (item * device_num,)
                batchsize_per_device = item
            else:
                new_shape += (item,)
        new_tensor_numpy = np.zeros(new_shape, dtype_to_nptype(type_))
        # Write the local shard at this rank's offset along the batch dimension.
        start = global_rank * batchsize_per_device
        new_tensor_numpy[start: start + batchsize_per_device] = data.asnumpy()
        new_tensor = Tensor(new_tensor_numpy)
        lst.append(new_tensor)
    if scaling_sens:
        lst.append(Tensor(scaling_sens, mstype.float32))
    return tuple(lst)


def _get_gradients_mean():
    """Get whether gradients_mean is used, as stored in the auto parallel context."""
    return auto_parallel_context().get_gradients_mean()


def _get_device_num():
    """
    Get the device num.

    Returns:
        int, 1 in stand_alone mode. Otherwise the device num configured in the auto parallel
        context, or the communication group size when no device num has been set.
    """
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        device_num = 1
        return device_num

    if auto_parallel_context().get_device_num_is_set() is False:
        device_num = get_group_size()
    else:
        device_num = auto_parallel_context().get_device_num()
    return device_num


def _get_global_rank():
    """
    Get the global rank.

    Returns:
        int, 0 in stand_alone mode. Otherwise the global rank configured in the auto parallel
        context, or the rank from the communication layer when no global rank has been set.
    """
    parallel_mode = auto_parallel_context().get_parallel_mode()
    if parallel_mode == "stand_alone":
        global_rank = 0
        return global_rank

    if auto_parallel_context().get_global_rank_is_set() is False:
        global_rank = get_rank()
    else:
        global_rank = auto_parallel_context().get_global_rank()
    return global_rank


def _get_parameter_broadcast():
    """
    Get the parameter broadcast setting.

    In data_parallel or hybrid_parallel mode, a warning is logged when parameter broadcast is
    disabled and no global seed has been set. Without broadcast, a shared seed is what keeps
    parameters identical among devices.

    Returns:
        bool, the parameter_broadcast value from the auto parallel context.
    """
    parallel_mode = auto_parallel_context().get_parallel_mode()
    parameter_broadcast = auto_parallel_context().get_parameter_broadcast()

    # get_seed must be *called*: comparing the function object itself to None is always False,
    # which previously made this warning unreachable.
    if parallel_mode in ("data_parallel", "hybrid_parallel") and parameter_broadcast is False and get_seed() is None:
        logger.warning("You are suggested to use mindspore.common.set_seed() to share"
                       " parameters among devices.")
    return parameter_broadcast


def _device_number_check(parallel_mode, device_number):
    """
    Check device num.

    Args:
        parallel_mode (str): The parallel mode.
        device_number (int): The device number.

    Raises:
        ValueError: If parallel_mode is "stand_alone" and device_number is not 1.
    """
    if parallel_mode == "stand_alone" and device_number != 1:
        raise ValueError("If parallel_mode is stand_alone, device_number must be 1, "
                         "device_number: {0}, parallel_mode:{1}".format(device_number, parallel_mode))


def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
    """
    Check parameter broadcast.

    Note:
        If parallel mode is semi_auto_parallel or auto_parallel, parameter broadcast is not supported. Using the same
        random seed to make sure parameters on multiple devices are the same.

    Args:
        parallel_mode (str): The parallel mode.
        parameter_broadcast (bool): The parameter broadcast.

    Raises:
        ValueError: If parameter is broadcasted but the parallel mode is "stand_alone", "semi_auto_parallel"
            or "auto_parallel".
    """
    if parameter_broadcast is True and parallel_mode in ("stand_alone", "semi_auto_parallel", "auto_parallel"):
        raise ValueError("stand_alone, semi_auto_parallel and auto_parallel "
                         "do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
                         .format(parallel_mode, parameter_broadcast))


def _get_python_op(op_name, op_path, instance_name, arglist):
    """
    Get a python operator.

    Args:
        op_name (str): Attribute name of the operator inside the module.
        op_path (str): Dotted module path to import the operator from.
        instance_name (str): Name passed to ``set_prim_instance_name`` on the result.
        arglist (list): Constructor arguments; ignored for ``mindspore.ops.functional``, whose
            attribute is used as-is instead of being instantiated.

    Returns:
        The operator instance (or the functional attribute itself).
    """
    # A non-empty fromlist makes __import__ return the leaf module rather than the top-level package.
    module = __import__(op_path, fromlist=["None"])
    cls = getattr(module, op_name)
    if op_path != "mindspore.ops.functional":
        op = cls(*arglist)
    else:
        op = cls
    op.set_prim_instance_name(instance_name)
    return op


def _reset_op_id():
    """Reset op id."""
    reset_op_id()


def _parallel_predict_check():
    """
    Validate parallel model prediction.

    Raises:
        RuntimeError: Under semi_auto_parallel or auto_parallel, if full_batch is disabled or the
            parallel optimizer is enabled.
    """
    if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        if not context.get_auto_parallel_context("full_batch"):
            raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
        if context.get_auto_parallel_context("enable_parallel_optimizer"):
            # The trailing space keeps the two concatenated literals from running together.
            raise RuntimeError('Model prediction does not support parallel optimizer. Please set '
                               '"enable_parallel_optimizer" with False.')


def _check_similar_layout(tensor_layout1, tensor_layout2):
    """
    Check whether two tensor layouts are similar.

    Layouts are indexed as (device matrix, tensor map). They are similar when their tensor maps
    are equal and every device-matrix dimension referenced by the tensor map has the same size
    in both. Tensor-map entries index the device matrix from the right; -1 means "not split".

    Returns:
        bool, True if the layouts are similar.
    """
    if tensor_layout1[1] != tensor_layout2[1]:
        return False
    for i in tensor_layout1[1]:
        if i == -1:
            continue
        if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
            return False
    return True


def _check_same_layout(tensor_layout1, tensor_layout2):
    """
    Check whether two tensor layouts are the same.

    Returns:
        bool, True if both the device matrix ([0]) and the tensor map ([1]) are equal.
    """
    return tensor_layout1[0] == tensor_layout2[0] and tensor_layout1[1] == tensor_layout2[1]


def _remove_repeated_slices(tensor_layout):
    """
    Generate a tensor layout without repeated slices.

    Device-matrix dimensions that the tensor map does not reference only hold repeated copies of
    the same slice, so their size is set to 1. The input layout is not modified.

    Args:
        tensor_layout: Layout indexed as (device matrix, tensor map, ...); device matrix is
            assumed to be a mutable list (it is item-assigned after slicing) — TODO confirm.

    Returns:
        A deep copy of tensor_layout with the reduced device matrix at index 0.
    """
    import copy
    new_tensor_layout = copy.deepcopy(tensor_layout)
    dev_mat = tensor_layout[0][:]
    tensor_map = tensor_layout[1]
    # Tensor-map entries index the device matrix from the right, hence -1-dim.
    for dim in range(len(dev_mat)):
        if dim not in tensor_map:
            dev_mat[-1-dim] = 1
    new_tensor_layout[0] = dev_mat
    return new_tensor_layout


def _infer_rank_list(train_map, predict_map=None):
    """
    Infer the checkpoint slices (ranks) to be loaded for each parameter.

    Args:
        train_map (dict): Parameter name -> training layout (device matrix, tensor map, ...).
        predict_map (dict, optional): Parameter name -> prediction layout. Default: None.

    Returns:
        dict, parameter name -> (rank_list, flag). rank_list holds one training rank per
        distinct slice. flag is True only in the shortcut cases below, where a single rank's
        slice is selected; presumably this means the slice can be used without re-merging —
        NOTE(review): confirm against the caller. Parameters missing from a non-empty
        predict_map are skipped with a warning.
    """
    ret = {}
    for param_name in train_map:
        train_layout = train_map[param_name]
        train_dev_mat = train_layout[0]
        dev_num = np.array(train_dev_mat).prod()
        new_train_layout = _remove_repeated_slices(train_layout)
        # Lay ranks out on the training device matrix, then keep index 0 along every dimension
        # that only repeats slices, so exactly one rank per distinct slice remains.
        array = np.arange(dev_num).reshape(train_dev_mat)
        index = ()
        for i in new_train_layout[0]:
            if i == 1:
                index = index + (0,)
            else:
                index = index + (slice(None),)
        rank_list = array[index].flatten()
        if not predict_map:
            ret[param_name] = (rank_list, False)
            continue
        if param_name not in predict_map:
            logger.warning("predict_map does not contain %s", param_name)
            continue
        predict_layout = predict_map[param_name]
        dev_num = np.array(predict_layout[0]).prod()
        # optimization pass
        if _check_same_layout(train_layout, predict_layout):
            # Identical layouts: this device loads the slice saved by its own rank.
            dev_rank = _get_global_rank()
            ret[param_name] = ([dev_rank], True)
            continue
        if _check_similar_layout(train_layout, predict_layout):
            if len(rank_list) == 1:
                ret[param_name] = (rank_list, True)
            elif len(rank_list) == dev_num:
                # One distinct slice per predict device: pick this device's entry.
                dev_rank = _get_global_rank()
                ret[param_name] = ([rank_list[dev_rank]], True)
            else:
                ret[param_name] = (rank_list, False)
        else:
            ret[param_name] = (rank_list, False)
    return ret