Browse Source

Pre Merge pull request !16055 from changzherui/mod_dis_ckpt_r1.2

pull/16055/MERGE
changzherui Gitee 4 years ago
parent
commit
0fc0d615fb
1 changed files with 57 additions and 82 deletions
  1. +57
    -82
      mindspore/train/serialization.py

+ 57
- 82
mindspore/train/serialization.py View File

@@ -19,6 +19,7 @@ import stat
import math
import shutil
from threading import Thread, Lock
from collections import defaultdict
import numpy as np

import mindspore.nn as nn
@@ -1062,38 +1063,27 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
return merged_parameter


def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None):
"""
Load checkpoint into net for distributed predication.

Args:
network (Cell): Network for distributed predication.
checkpoint_filenames (list(str)): The name of Checkpoint files
in order of rank id.
predict_strategy (Optional(dict)): Strategy of predication process, whose key
is parameter name, and value is a list or a tuple that the first four
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device.
Default: None.
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device. Default: None.

Raises:
TypeError: The type of inputs do not match the requirements.
ValueError: Failed to load checkpoint into net.
"""
network = Validator.check_isinstance("network", network, nn.Cell)
_check_checkpoint_file(checkpoint_filenames)
_check_predict_strategy(predict_strategy)

for index, filename in enumerate(checkpoint_filenames):
if not isinstance(filename, str) or not os.path.exists(filename) \
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")

if not _check_predict_strategy(predict_strategy):
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")

train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
if train_strategy_filename is None:
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
_train_strategy = build_searched_strategy(train_strategy_filename)
train_strategy = _convert_to_list(_train_strategy)

@@ -1107,6 +1097,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=

rank_list = _infer_rank_list(train_strategy, predict_strategy)

param_total_dict = defaultdict(dict)
for file_index, file_name in enumerate(checkpoint_filenames):
ckpt_dict = load_checkpoint(file_name)
for param_name, param in ckpt_dict.items():
param_total_dict[param_name][file_index] = param

param_dict = {}
for _, param in network.parameters_and_names():
sliced_params = []
@@ -1114,8 +1110,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
continue
param_rank = rank_list[param.name][0]
skip_merge_split = rank_list[param.name][1]
shard_stride = train_strategy[param.name][4]
if train_strategy[param.name][5]:
shard_size = len(checkpoint_filenames) / shard_stride / train_strategy[param.name][5]
else:
shard_size = 0
for rank in param_rank:
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
param_total_list = list(range(0, len(checkpoint_filenames)))
if shard_size > 0:
shard_total_list = [param_total_list[i:i + shard_size] for i in
range(0, len(checkpoint_filenames), shard_size)]
param_total_list = shard_total_list[rank // shard_size]
if shard_stride > 0:
param_stride = []
# merge pre parameter
param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
param_index = list(set(param_index))
param_index.sort()
for rank_num in param_index:
param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
else:
sliced_param = param_total_dict[param.name][rank]

sliced_params.append(sliced_param)
if skip_merge_split:
split_param = sliced_params[0]
@@ -1139,19 +1158,31 @@ def _check_predict_strategy(predict_strategy):
return True

if predict_strategy is None:
return True
return
flag = True
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
for key in predict_strategy.keys():
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
or len(predict_strategy[key]) < 4:
return False
flag = False
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
not (_check_int_list(param_split_shape) or not param_split_shape) or \
not (isinstance(field_size, int) and field_size == 0):
return False
return True
flag = False
if not flag:
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")


def _check_checkpoint_file(checkpoint_filenames):
"""Check checkpoint file name."""
for index, filename in enumerate(checkpoint_filenames):
if not isinstance(filename, str) or not os.path.exists(filename) \
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")


def _convert_to_list(strategy):
@@ -1207,59 +1238,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
layerwise_parallel = merged_param.layerwise_parallel
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
return split_param


def _load_single_param(ckpt_file_name, param_name):
"""Load a parameter from checkpoint."""
checkpoint_list = Checkpoint()

try:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
" please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())

parameter = None
try:
param_data_list = []
for element_id, element in enumerate(checkpoint_list.value):
if element.tag != param_name:
continue
data = element.tensor.tensor_content
data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type]
ms_type = tensor_to_ms_type[data_type]
element_data = np.frombuffer(data, np_type)
param_data_list.append(element_data)
if (element_id == len(checkpoint_list.value) - 1) or \
(element.tag != checkpoint_list.value[element_id + 1].tag):
param_data = np.concatenate((param_data_list), axis=0)
param_data_list.clear()
dims = element.tensor.dims
if dims == [0]:
if 'Float' in data_type:
param_data = float(param_data[0])
elif 'Int' in data_type:
param_data = int(param_data[0])
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
elif dims == [1]:
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
else:
param_dim = []
for dim in dims:
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
break

except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())

if parameter is None:
raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
f"please check parameter name or checkpoint file.")
return parameter

Loading…
Cancel
Save