
!9466 support disable integrated save for parallel optimizer

From: @gong_zi_yan
Reviewed-by: @caozhou_huawei,@yao_yf,@stsuteng,@kisnwang
Signed-off-by: @stsuteng
Tag: v1.1.0
Committer: mindspore-ci-bot
Commit: e5fd738100
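
The change keeps the save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=...) interface shown in the hunk header and only forwards the flag into the merge helper, so a caller opts out of the merge per call. A minimal usage sketch, assuming a parallel-trained Cell named net (file names are illustrative, not from the commit):

    from mindspore.train.serialization import save_checkpoint

    # net: an already-built Cell trained under auto/semi-auto parallel (construction omitted).

    # Default: gather and merge every model-parallel slice before writing, so the
    # checkpoint holds the whole tensors (integrated save).
    save_checkpoint(net, "merged.ckpt", integrated_save=True)

    # With this change the flag reaches _get_merged_param_data, so it also takes
    # effect when the parallel optimizer is enabled: the model-parallel merge is
    # skipped and each rank saves its own slice (optimizer shards are still
    # gathered back over opt_shard_group).
    save_checkpoint(net, "rank_slice.ckpt", integrated_save=False)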
1 changed file with 26 additions and 22 deletions

mindspore/train/serialization.py (+26, -22)

@@ -177,8 +177,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F

             # in automatic model parallel scenario, some parameters were spliteds to all the devices,
             # which should be combined before saving
-            if integrated_save and key in save_obj.parameter_layout_dict:
-                param_data = _get_merged_param_data(save_obj, key, param_data)
+            if key in save_obj.parameter_layout_dict:
+                param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save)
 
             each_param["data"] = param_data
             param_list.append(each_param)
@@ -426,18 +426,20 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IRUSR)
 
 
-def _get_merged_param_data(net, param_name, param_data):
+def _get_merged_param_data(net, param_name, param_data, integrated_save):
     """
     Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
 
     Args:
         net (Cell): MindSpore network.
-        param_name(str): The parameter name, which to be combined.
-        param_data(Tensor):The parameter data on the local device,
-            It was a slice of the whole parameter data.
+        param_name (str): The parameter name, which to be combined.
+        param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data.
+        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
     Returns:
         Tensor, the combined tensor which with the whole data value.
     """
+    from mindspore.parallel._cell_wrapper import get_allgather_cell
+    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
     layout = net.parameter_layout_dict[param_name]
     if len(layout) < 6:
         logger.info("layout dict does not contain the key %s", param_name)
@@ -448,24 +450,26 @@ def _get_merged_param_data(net, param_name, param_data):
     field_size = layout[3]
     uniform_split = layout[4]
     opt_shard_group = layout[5]
-    if uniform_split == 0:
-        raise RuntimeError("Save checkpoint only support uniform split tensor now.")
-
-    from mindspore.parallel._cell_wrapper import get_allgather_cell
-    from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight
-    # while any dim is not equal to -1, means param is split and needs to be merged
-    # pipeline parallel need to be supported here later
-    for dim in tensor_map:
-        if dim != -1:
-            if opt_shard_group:
-                allgather_net = get_allgather_cell(opt_shard_group, True)
-            else:
-                allgather_net = get_allgather_cell(opt_shard_group, False)
-            param_data = allgather_net(param_data)
-            if field_size:
-                return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
-            return _reshape_param_data(param_data, dev_mat, tensor_map)
-    if opt_shard_group:
+    if integrated_save:
+        if uniform_split == 0:
+            raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.")
+        # while any dim is not equal to -1, means param is split and needs to be merged
+        # pipeline parallel need to be supported here later
+        for dim in tensor_map:
+            if dim != -1:
+                if opt_shard_group:
+                    allgather_net = get_allgather_cell(opt_shard_group, True)
+                else:
+                    allgather_net = get_allgather_cell(opt_shard_group, False)
+                param_data = allgather_net(param_data)
+                if field_size:
+                    return _reshape_param_data_with_weight(param_data, dev_mat, field_size)
+                return _reshape_param_data(param_data, dev_mat, tensor_map)
+        if opt_shard_group:
+            allgather_net = get_allgather_cell(opt_shard_group, False)
+            param_data = allgather_net(param_data)
+    elif opt_shard_group:
         allgather_net = get_allgather_cell(opt_shard_group, False)
         param_data = allgather_net(param_data)
     return param_data
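
Before this commit, passing integrated_save=False made save_checkpoint skip _get_merged_param_data entirely, so a parameter sharded by the parallel optimizer was written as its raw optimizer-shard fragment. With the flag forwarded into the function, the optimizer-shard AllGather still runs and only the model-parallel merge is skipped. The helper below is an illustration of the resulting control flow, not code from the commit:

    # Illustration only: mirrors the branch structure this change adds to
    # _get_merged_param_data.
    def merge_strategy(integrated_save, opt_shard_group, is_split):
        """is_split means some dim in tensor_map != -1."""
        if integrated_save:
            if is_split:
                # AllGather over the parallel group (opt_shard_group included when set),
                # then reshape back to the full tensor.
                return "merge model-parallel slices into the whole tensor"
            if opt_shard_group:
                return "AllGather over the optimizer-shard group only"
            return "parameter already whole, saved as is"
        if opt_shard_group:
            # The new elif branch: optimizer shards are gathered back,
            # but model-parallel slices stay split on each rank.
            return "AllGather over the optimizer-shard group; keep the model-parallel slice"
        return "save the local slice as is"

    print(merge_strategy(integrated_save=False, opt_shard_group="group_0", is_split=True))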

