@@ -374,9 +374,6 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
-        if _get_parallel_mode() in ["hybrid_parallel"]:
-            obj.parameter_layout_dict = self._build_parameter_layout(obj)
-
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
@@ -449,38 +446,6 @@ class _Executor:
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
 
-    def _build_parameter_layout(self, obj):
-        """
-        Build parameter layout, for layerwise_parallel parameter.
-
-        Args:
-            obj (Function or Cell): The function or cell instance need to be compiled.
-
-        Returns:
-            Dictionary, parameter layout info.
-        """
-        parameter_layout_dict = {}
-        layerwise_parallel_parameters = []
-        for key in obj.parameters_dict():
-            if obj.parameters_dict()[key].layerwise_parallel is True:
-                layerwise_parallel_parameters.append(key)
-
-        if not layerwise_parallel_parameters:
-            return parameter_layout_dict
-
-        from ..communication.management import get_group_size
-        group_size = [get_group_size()]
-        for key in layerwise_parallel_parameters:
-            tensor_map = [0]
-            shape = obj.parameters_dict()[key].data.shape()
-            for x in range(len(shape)):  # dim 0 set 0, others set -1
-                if x:
-                    tensor_map.append(-1)
-            layout = [group_size, tensor_map]
-            parameter_layout_dict[key] = layout
-
-        return parameter_layout_dict
-
     def del_net_res(self, net_id):
         self._executor.del_net_res(net_id)
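For orientation, the `_build_parameter_layout` helper removed above produced a dictionary keyed by parameter name, mapping each layerwise-parallel parameter to a `[device_arrangement, tensor_map]` pair in which only dimension 0 is sharded across the communication group. A minimal sketch of that structure, assuming plain shape tuples and a placeholder group size instead of the real `Cell` and communication APIs:

```python
# Hedged sketch of the layout format the removed helper returned.
# `param_shapes` and `group_size` are placeholder inputs, not MindSpore calls.
def build_layerwise_layout(param_shapes, group_size):
    """param_shapes: {name: shape} for parameters flagged layerwise_parallel."""
    layout = {}
    for name, shape in param_shapes.items():
        # dimension 0 is split across the group, the remaining dimensions are not (-1)
        tensor_map = [0] + [-1] * (len(shape) - 1)
        layout[name] = [[group_size], tensor_map]
    return layout

# A (1024, 256) weight on 8 devices maps to [[8], [0, -1]]
print(build_layerwise_layout({"fc.weight": (1024, 256)}, 8))
```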
@@ -24,7 +24,7 @@ import mindspore.context as context
 from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
 from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative
+from mindspore._checkparam import check_int_non_negative, check_bool
 from mindspore.common.tensor import Tensor
 from .summary.summary_record import _cache_summary_tensor_data
@@ -150,6 +150,8 @@ class CheckpointConfig:
         keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
         keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
             Can't be used with keep_checkpoint_max at the same time.
+        integrated_save (bool): Whether to apply integrated save in the automatic model parallel scenario. Default: True.
+            Integrated save is only supported in the automatic parallel scenario, not in manual parallel.
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -163,7 +165,8 @@ class CheckpointConfig:
                  save_checkpoint_steps=1,
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
-                 keep_checkpoint_per_n_minutes=0):
+                 keep_checkpoint_per_n_minutes=0,
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -191,6 +194,8 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
+        self._integrated_save = check_bool(integrated_save)
+
     @property
     def save_checkpoint_steps(self):
         """Get the value of _save_checkpoint_steps."""
@@ -211,6 +216,11 @@ class CheckpointConfig:
         """Get the value of _keep_checkpoint_per_n_minutes."""
         return self._keep_checkpoint_per_n_minutes
 
+    @property
+    def integrated_save(self):
+        """Get the value of _integrated_save."""
+        return self._integrated_save
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
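Taken together, the CheckpointConfig changes above expose `integrated_save` to users; the value is read back by ModelCheckpoint in the next hunk. A hedged usage sketch (the prefix and directory are placeholder values, and the keyword arguments follow the signatures shown in this diff):

```python
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint

# Default behavior: merge parameter slices before writing the checkpoint.
merged_cfg = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)

# Opt out of merging in automatic parallel: each device saves only its own slice.
sliced_cfg = CheckpointConfig(save_checkpoint_steps=100,
                              keep_checkpoint_max=5,
                              integrated_save=False)

ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./ckpt", config=sliced_cfg)
# model.train(epoch_num, dataset, callbacks=[ckpt_cb])  # attach as usual
```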
@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
                 _set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpoint_file_name):
+def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.
 
     Args:
         train_network (Network): The train network for training.
         ckpoint_file_name (str): The name of checkpoint file.
+        integrated_save (bool): Whether to apply integrated save in the automatic model parallel scenario.
     """
     param_dict = {}
@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
         else:
             param_data = Tensor(value.data)
 
-        # in model parallel scenario, some parameters were spliteds to all the devices,
+        # in the automatic model parallel scenario, some parameters are split across all the devices,
         # which should be combined before saving
-        if key in train_network.parameter_layout_dict:
+        if integrated_save and key in train_network.parameter_layout_dict:
            param_data = _get_merged_param_data(train_network, key, param_data)
 
         each_param["data"] = param_data
@@ -308,10 +308,10 @@ def test_RunContext():
 def test_Checkpoint_Config():
     """Test CheckpointConfig all None or 0."""
     with pytest.raises(ValueError):
-        CheckpointConfig(0, 0, 0, 0)
+        CheckpointConfig(0, 0, 0, 0, True)
 
     with pytest.raises(ValueError):
-        CheckpointConfig(0, None, 0, 0)
+        CheckpointConfig(0, None, 0, 0, True)
 
 
 def test_step_end_save_graph():