@@ -18,16 +18,21 @@
import types
from collections import OrderedDict
from functools import wraps
from mindspore import context
from mindspore import log as logger
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor, PynativeExecutor_
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
from ..parallel._ps_context import _is_role_pserver
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _check_full_batch, _to_full_tensor, \
    _get_parameter_broadcast

# Store the compiled pipeline cache for ms_function.
ms_compile_cache = {}

BROADCAST_PHASE = "_broadcast_"


def _convert_function_arguments(fn, *args):
    """
@@ -362,6 +367,27 @@ class _Executor:
    def _build_data_graph(self, obj, phase):
        self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())

    def _get_auto_split_param_names(self, parameter_layout_dict):
        """Collect the names of the parameters that auto-parallel splits across devices."""
        auto_split_param_names = []
        for param_name, layout in parameter_layout_dict.items():
            # layout[1] is the tensor map: -1 marks a replicated dimension, any
            # other index marks a split one, so the parameter differs per device.
            if any(dim != -1 for dim in layout[1]):
                auto_split_param_names.append(param_name)
        return auto_split_param_names
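    # Illustration only (not library code): the filtering rule above, applied to a
    # hypothetical layout dict. value[1] plays the role of the tensor map, where -1
    # marks a replicated dimension and any other index marks a split one:
    #
    #     layout = {
    #         "fc1.weight": ([4, 2], [1, -1]),  # split on one dimension -> excluded
    #         "fc1.bias":   ([4, 2], [-1]),     # fully replicated -> broadcastable
    #     }
    #     split_names = [name for name, value in layout.items()
    #                    if any(dim != -1 for dim in value[1])]
    #     assert split_names == ["fc1.weight"]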
    def _build_broadcast_graph(self, obj, broadcast_params, broadcast_phase):
        """Build and run the graph that broadcasts parameters from device 0."""
        from mindspore.nn.wrap.cell_wrapper import _BroadCastCell
        _broadcast_net = _BroadCastCell(broadcast_params)
        _broadcast_net.phase = broadcast_phase
        broadcasted_params = _broadcast_net()
        # Write the values received from device 0 back into the local parameters.
        parameters_broadcast_dict = obj.parameters_broadcast_dict()
        for param_name, param in zip(parameters_broadcast_dict, broadcasted_params):
            parameters_broadcast_dict[param_name].set_data(param)
    def _set_dataset_mode(self, args_list):
        """Set the dataset mode."""
        # Decide whether to sink data based on whether the inputs are virtual or args_list is ().
@@ -444,6 +470,15 @@ class _Executor:
            _exec_init_graph(obj, init_phase)
        elif not enable_ge and "export" in phase:
            self._build_data_graph(obj, phase)
        elif BROADCAST_PHASE not in phase and _get_parameter_broadcast():
            auto_split_param_names = []
            if auto_parallel_mode:
                auto_split_param_names = self._get_auto_split_param_names(obj.parameter_layout_dict)
            # Parameters split by auto-parallel already differ per device by design,
            # so only the replicated parameters are broadcast from device 0.
            broadcast_params = [param for param_name, param in obj.parameters_broadcast_dict().items()
                                if param_name not in auto_split_param_names]
            broadcast_phase = "broadcast_subgraph" + "." + str(obj.create_time)
            self._build_broadcast_graph(obj, broadcast_params, broadcast_phase)
            self.compile_cache[phase] = broadcast_phase

        return phase, True
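The broadcast branch above only runs when `_get_parameter_broadcast()` reports the feature as enabled, which user code does through the public context API. A minimal usage sketch, assuming a configured multi-device job (the device target and cluster setup here are illustrative):

from mindspore import context
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # requires a working distributed environment
context.set_auto_parallel_context(parallel_mode=context.ParallelMode.DATA_PARALLEL,
                                  parameter_broadcast=True)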
@@ -377,9 +377,7 @@ def set_auto_parallel_context(**kwargs):
            - recursive_programming: Recursive programming search mode.
            - dynamic_programming: Dynamic programming search mode.

        parameter_broadcast (bool): Whether to broadcast parameters before training. Default: False.
        strategy_ckpt_load_file (str): The path to load the parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save the parallel strategy checkpoint. Default: ''
        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
@@ -25,6 +25,40 @@ from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(param):
    """
    Acquire the parameter's datatype.

    Args:
        param (Tensor): The parameter before the operation.

    Returns:
        mstype, the datatype of the parameter.
    """
    return F.dtype(param)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, param):
    """
    Cast the parameter to the given datatype.

    Args:
        datatype (mstype): The destination datatype of the parameter.
        param (Tensor): The parameter before the operation.

    Returns:
        Tensor, the parameter after the cast.
    """
    return F.cast(param, datatype)
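For orientation, a self-contained sketch of the MultitypeFuncGraph + C.Map + F.partial dispatch pattern these two registrations rely on. `CastAll` is illustrative and not part of the module, and the snippet assumes a MindSpore 1.x install:

import numpy as np
import mindspore.nn as nn
import mindspore.ops.composite as C
import mindspore.ops.functional as F
from mindspore import Tensor, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

cast_datatype = C.MultitypeFuncGraph("cast_datatype")


@cast_datatype.register("TypeType", "Tensor")
def _cast_one(datatype, x):
    """Registered overload: cast a single tensor to the given datatype."""
    return F.cast(x, datatype)


class CastAll(nn.Cell):
    """Cast every tensor in a tuple to float32 with one elementwise Map call."""
    def __init__(self):
        super(CastAll, self).__init__()
        self.map_ = C.Map()

    def construct(self, tensors):
        # F.partial binds the leading datatype argument; Map supplies each element.
        return self.map_(F.partial(cast_datatype, mstype.float32), tensors)


net = CastAll()
out = net((Tensor(np.ones(2), mstype.float16), Tensor(np.ones(2), mstype.int32)))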
class WithLossCell(Cell):
    r"""
@@ -175,6 +209,7 @@ class TrainOneStepCell(Cell):
        >>> loss_net = MyWithLossCell(net, loss_fn)
        >>> train_net = nn.TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
@@ -314,7 +349,6 @@ class WithEvalCell(Cell):
        self._loss_fn = loss_fn
        self.add_cast_fp32 = add_cast_fp32

    def construct(self, data, label):
        outputs = self._network(data)
        if self.add_cast_fp32:
@@ -354,3 +388,25 @@ class ParameterUpdate(Cell):
    def construct(self, x):
        F.assign(self._param, x)
        return x
class _BroadCastCell(Cell):
    """
    Broadcast the parameters from device 0 to the other devices.

    Args:
        params (list): The parameters of the network.
    """

    def __init__(self, params):
        super(_BroadCastCell, self).__init__()
        self.map_ = C.Map()
        self.params = tuple(params)
        self.broadcast = P.Broadcast(0)

    def construct(self):
        # Record each parameter's datatype, broadcast in float32, then cast back.
        datatypes = self.map_(F.partial(_get_datatype), self.params)
        params = self.map_(F.partial(_cast_datatype, mstype.float32), self.params)
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params
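A hedged usage sketch mirroring what `_build_broadcast_graph` does in `api.py` above; `MyNet` is hypothetical, and a real run needs a multi-device job with communication initialized:

from mindspore.communication.management import init
from mindspore.nn.wrap.cell_wrapper import _BroadCastCell

init()  # assumes a configured distributed environment
net = MyNet()  # hypothetical user network
broadcast_dict = net.parameters_broadcast_dict()
broadcast_net = _BroadCastCell(list(broadcast_dict.values()))
synced = broadcast_net()
for name, value in zip(broadcast_dict, synced):
    broadcast_dict[name].set_data(value)  # overwrite local values with device 0's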
@@ -207,9 +207,6 @@ class _AutoParallelContext:
            parameter_broadcast (bool): Parameter broadcast or not.
        """
        self.check_context_handle()
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
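With the old GE-only guard removed, the flag round-trips through the public context API; a minimal check, assuming a standard install:

from mindspore import context

context.set_auto_parallel_context(parameter_broadcast=True)
assert context.get_auto_parallel_context("parameter_broadcast") is True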