|
|
|
@@ -177,8 +177,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F |
|
|
|
|
|
|
|
# in automatic model parallel scenario, some parameters were spliteds to all the devices, |
|
|
|
# which should be combined before saving |
|
|
|
if integrated_save and key in save_obj.parameter_layout_dict: |
|
|
|
param_data = _get_merged_param_data(save_obj, key, param_data) |
|
|
|
if key in save_obj.parameter_layout_dict: |
|
|
|
param_data = _get_merged_param_data(save_obj, key, param_data, integrated_save) |
|
|
|
|
|
|
|
each_param["data"] = param_data |
|
|
|
param_list.append(each_param) |
|
|
|
@@ -426,18 +426,20 @@ def _save_graph(network, file_name): |
|
|
|
os.chmod(file_name, stat.S_IRUSR) |
|
|
|
|
|
|
|
|
|
|
|
def _get_merged_param_data(net, param_name, param_data): |
|
|
|
def _get_merged_param_data(net, param_name, param_data, integrated_save): |
|
|
|
""" |
|
|
|
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map. |
|
|
|
|
|
|
|
Args: |
|
|
|
net (Cell): MindSpore network. |
|
|
|
param_name(str): The parameter name, which to be combined. |
|
|
|
param_data(Tensor):The parameter data on the local device, |
|
|
|
It was a slice of the whole parameter data. |
|
|
|
param_name (str): The parameter name, which to be combined. |
|
|
|
param_data (Tensor): The parameter data on the local device, which was a slice of the whole parameter data. |
|
|
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. |
|
|
|
Returns: |
|
|
|
Tensor, the combined tensor which with the whole data value. |
|
|
|
""" |
|
|
|
from mindspore.parallel._cell_wrapper import get_allgather_cell |
|
|
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight |
|
|
|
layout = net.parameter_layout_dict[param_name] |
|
|
|
if len(layout) < 6: |
|
|
|
logger.info("layout dict does not contain the key %s", param_name) |
|
|
|
@@ -448,24 +450,26 @@ def _get_merged_param_data(net, param_name, param_data): |
|
|
|
field_size = layout[3] |
|
|
|
uniform_split = layout[4] |
|
|
|
opt_shard_group = layout[5] |
|
|
|
if uniform_split == 0: |
|
|
|
raise RuntimeError("Save checkpoint only support uniform split tensor now.") |
|
|
|
|
|
|
|
from mindspore.parallel._cell_wrapper import get_allgather_cell |
|
|
|
from mindspore.parallel._tensor import _reshape_param_data, _reshape_param_data_with_weight |
|
|
|
# while any dim is not equal to -1, means param is split and needs to be merged |
|
|
|
# pipeline parallel need to be supported here later |
|
|
|
for dim in tensor_map: |
|
|
|
if dim != -1: |
|
|
|
if opt_shard_group: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, True) |
|
|
|
else: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, False) |
|
|
|
if integrated_save: |
|
|
|
if uniform_split == 0: |
|
|
|
raise RuntimeError("Integrated save checkpoint only support uniform split tensor now.") |
|
|
|
# while any dim is not equal to -1, means param is split and needs to be merged |
|
|
|
# pipeline parallel need to be supported here later |
|
|
|
for dim in tensor_map: |
|
|
|
if dim != -1: |
|
|
|
if opt_shard_group: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, True) |
|
|
|
else: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, False) |
|
|
|
param_data = allgather_net(param_data) |
|
|
|
if field_size: |
|
|
|
return _reshape_param_data_with_weight(param_data, dev_mat, field_size) |
|
|
|
return _reshape_param_data(param_data, dev_mat, tensor_map) |
|
|
|
if opt_shard_group: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, False) |
|
|
|
param_data = allgather_net(param_data) |
|
|
|
if field_size: |
|
|
|
return _reshape_param_data_with_weight(param_data, dev_mat, field_size) |
|
|
|
return _reshape_param_data(param_data, dev_mat, tensor_map) |
|
|
|
if opt_shard_group: |
|
|
|
elif opt_shard_group: |
|
|
|
allgather_net = get_allgather_cell(opt_shard_group, False) |
|
|
|
param_data = allgather_net(param_data) |
|
|
|
return param_data |
|
|
|
|