Browse Source

Support to config whether to save integeated checkpoint, in auto model parallel scene

tags/v0.3.0-alpha
WeibiaoYu chang zherui 5 years ago
parent
commit
5f1fedaae7
4 changed files with 19 additions and 43 deletions
  1. +0
    -35
      mindspore/common/api.py
  2. +13
    -3
      mindspore/train/callback.py
  3. +4
    -3
      mindspore/train/serialization.py
  4. +2
    -2
      tests/ut/python/utils/test_callback.py

+ 0
- 35
mindspore/common/api.py View File

@@ -374,9 +374,6 @@ class _Executor:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
obj.load_parameter_slice(params)

if _get_parallel_mode() in ["hybrid_parallel"]:
obj.parameter_layout_dict = self._build_parameter_layout(obj)

# the following GE init process is not needed when use vm or ms backend
if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not
@@ -449,38 +446,6 @@ class _Executor:
return self._exec_pip(obj, *args, phase=phase_real)
raise KeyError('{} graph is not exist.'.format(phase_real))

def _build_parameter_layout(self, obj):
"""
Build parameter layout, for layerwise_parallel parameter.

Args:
obj (Function or Cell): The function or cell instance need to be compiled.

Returns:
Dictionary, parameter layout info.
"""
parameter_layout_dict = {}
layerwise_parallel_parameters = []
for key in obj.parameters_dict():
if obj.parameters_dict()[key].layerwise_parallel is True:
layerwise_parallel_parameters.append(key)

if not layerwise_parallel_parameters:
return parameter_layout_dict

from ..communication.management import get_group_size
group_size = [get_group_size()]
for key in layerwise_parallel_parameters:
tensor_map = [0]
shape = obj.parameters_dict()[key].data.shape()
for x in range(len(shape)): # dim 0 set 0, others set -1
if x:
tensor_map.append(-1)
layout = [group_size, tensor_map]
parameter_layout_dict[key] = layout

return parameter_layout_dict

def del_net_res(self, net_id):
self._executor.del_net_res(net_id)



+ 13
- 3
mindspore/train/callback.py View File

@@ -24,7 +24,7 @@ import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
from mindspore.train._utils import _make_directory
from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative
from mindspore._checkparam import check_int_non_negative, check_bool
from mindspore.common.tensor import Tensor
from .summary.summary_record import _cache_summary_tensor_data

@@ -150,6 +150,8 @@ class CheckpointConfig:
keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parall scene. Default: True.
Integrated save function is only supported in automatic parall scene, not supported in manual parallel.

Raises:
ValueError: If the input_param is None or 0.
@@ -163,7 +165,8 @@ class CheckpointConfig:
save_checkpoint_steps=1,
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0):
keep_checkpoint_per_n_minutes=0,
integrated_save=True):

if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -191,6 +194,8 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1

self._integrated_save = check_bool(integrated_save)

@property
def save_checkpoint_steps(self):
"""Get the value of _save_checkpoint_steps."""
@@ -211,6 +216,11 @@ class CheckpointConfig:
"""Get the value of _keep_checkpoint_per_n_minutes."""
return self._keep_checkpoint_per_n_minutes

@property
def integrated_save(self):
"""Get the value of _integrated_save."""
return self._integrated_save

def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
_set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()

_exec_save_checkpoint(cb_params.train_network, gen_file)
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)

if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)


+ 4
- 3
mindspore/train/serialization.py View File

@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


def _exec_save_checkpoint(train_network, ckpoint_file_name):
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
"""
Saves checkpoint for 'ms' backend.

Args:
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
"""

param_dict = {}
@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
else:
param_data = Tensor(value.data)

# in model parallel scenario, some parameters were spliteds to all the devices,
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if key in train_network.parameter_layout_dict:
if integrated_save and key in train_network.parameter_layout_dict:
param_data = _get_merged_param_data(train_network, key, param_data)

each_param["data"] = param_data


+ 2
- 2
tests/ut/python/utils/test_callback.py View File

@@ -308,10 +308,10 @@ def test_RunContext():
def test_Checkpoint_Config():
"""Test CheckpointConfig all None or 0."""
with pytest.raises(ValueError):
CheckpointConfig(0, 0, 0, 0)
CheckpointConfig(0, 0, 0, 0, True)

with pytest.raises(ValueError):
CheckpointConfig(0, None, 0, 0)
CheckpointConfig(0, None, 0, 0, True)


def test_step_end_save_graph():


Loading…
Cancel
Save