| @@ -374,9 +374,6 @@ class _Executor: | |||
| obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) | |||
| obj.load_parameter_slice(params) | |||
| if _get_parallel_mode() in ["hybrid_parallel"]: | |||
| obj.parameter_layout_dict = self._build_parameter_layout(obj) | |||
| # the following GE init process is not needed when use vm or ms backend | |||
| if enable_ge: | |||
| # decide whether to sink based on whether the inputs is virtual or not | |||
| @@ -449,38 +446,6 @@ class _Executor: | |||
| return self._exec_pip(obj, *args, phase=phase_real) | |||
| raise KeyError('{} graph is not exist.'.format(phase_real)) | |||
| def _build_parameter_layout(self, obj): | |||
| """ | |||
| Build parameter layout, for layerwise_parallel parameter. | |||
| Args: | |||
| obj (Function or Cell): The function or cell instance need to be compiled. | |||
| Returns: | |||
| Dictionary, parameter layout info. | |||
| """ | |||
| parameter_layout_dict = {} | |||
| layerwise_parallel_parameters = [] | |||
| for key in obj.parameters_dict(): | |||
| if obj.parameters_dict()[key].layerwise_parallel is True: | |||
| layerwise_parallel_parameters.append(key) | |||
| if not layerwise_parallel_parameters: | |||
| return parameter_layout_dict | |||
| from ..communication.management import get_group_size | |||
| group_size = [get_group_size()] | |||
| for key in layerwise_parallel_parameters: | |||
| tensor_map = [0] | |||
| shape = obj.parameters_dict()[key].data.shape() | |||
| for x in range(len(shape)): # dim 0 set 0, others set -1 | |||
| if x: | |||
| tensor_map.append(-1) | |||
| layout = [group_size, tensor_map] | |||
| parameter_layout_dict[key] = layout | |||
| return parameter_layout_dict | |||
| def del_net_res(self, net_id): | |||
| self._executor.del_net_res(net_id) | |||
| @@ -24,7 +24,7 @@ import mindspore.context as context | |||
| from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import check_int_non_negative | |||
| from mindspore._checkparam import check_int_non_negative, check_bool | |||
| from mindspore.common.tensor import Tensor | |||
| from .summary.summary_record import _cache_summary_tensor_data | |||
| @@ -150,6 +150,8 @@ class CheckpointConfig: | |||
| keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5. | |||
| keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0. | |||
| Can't be used with keep_checkpoint_max at the same time. | |||
| integrated_save (bool): Whether to intergrated save in automatic model parall scene. Default: True. | |||
| Integrated save function is only supported in automatic parall scene, not supported in manual parallel. | |||
| Raises: | |||
| ValueError: If the input_param is None or 0. | |||
| @@ -163,7 +165,8 @@ class CheckpointConfig: | |||
| save_checkpoint_steps=1, | |||
| save_checkpoint_seconds=0, | |||
| keep_checkpoint_max=5, | |||
| keep_checkpoint_per_n_minutes=0): | |||
| keep_checkpoint_per_n_minutes=0, | |||
| integrated_save=True): | |||
| if not save_checkpoint_steps and not save_checkpoint_seconds and \ | |||
| not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: | |||
| @@ -191,6 +194,8 @@ class CheckpointConfig: | |||
| if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: | |||
| self._keep_checkpoint_max = 1 | |||
| self._integrated_save = check_bool(integrated_save) | |||
| @property | |||
| def save_checkpoint_steps(self): | |||
| """Get the value of _save_checkpoint_steps.""" | |||
| @@ -211,6 +216,11 @@ class CheckpointConfig: | |||
| """Get the value of _keep_checkpoint_per_n_minutes.""" | |||
| return self._keep_checkpoint_per_n_minutes | |||
| @property | |||
| def integrated_save(self): | |||
| """Get the value of _integrated_save.""" | |||
| return self._integrated_save | |||
| def get_checkpoint_policy(self): | |||
| """Get the policy of checkpoint.""" | |||
| checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, | |||
| @@ -619,7 +629,7 @@ class ModelCheckpoint(Callback): | |||
| _set_cur_net(cb_params.train_network) | |||
| cb_params.train_network.exec_checkpoint_graph() | |||
| _exec_save_checkpoint(cb_params.train_network, gen_file) | |||
| _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) | |||
| if os.path.exists(gen_file): | |||
| shutil.move(gen_file, cur_file) | |||
| @@ -279,13 +279,14 @@ def _save_graph(network, file_name): | |||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| def _exec_save_checkpoint(train_network, ckpoint_file_name): | |||
| def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True): | |||
| """ | |||
| Saves checkpoint for 'ms' backend. | |||
| Args: | |||
| train_network (Network): The train network for training. | |||
| ckpoint_file_name (str): The name of checkpoint file. | |||
| integrated_save (bool): Whether to intergrated save in automatic model parallel scene. | |||
| """ | |||
| param_dict = {} | |||
| @@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name): | |||
| else: | |||
| param_data = Tensor(value.data) | |||
| # in model parallel scenario, some parameters were spliteds to all the devices, | |||
| # in automatic model parallel scenario, some parameters were spliteds to all the devices, | |||
| # which should be combined before saving | |||
| if key in train_network.parameter_layout_dict: | |||
| if integrated_save and key in train_network.parameter_layout_dict: | |||
| param_data = _get_merged_param_data(train_network, key, param_data) | |||
| each_param["data"] = param_data | |||
| @@ -308,10 +308,10 @@ def test_RunContext(): | |||
| def test_Checkpoint_Config(): | |||
| """Test CheckpointConfig all None or 0.""" | |||
| with pytest.raises(ValueError): | |||
| CheckpointConfig(0, 0, 0, 0) | |||
| CheckpointConfig(0, 0, 0, 0, True) | |||
| with pytest.raises(ValueError): | |||
| CheckpointConfig(0, None, 0, 0) | |||
| CheckpointConfig(0, None, 0, 0, True) | |||
| def test_step_end_save_graph(): | |||