Browse Source

Pre Merge pull request !16055 from changzherui/mod_dis_ckpt_r1.2

pull/16055/MERGE
changzherui Gitee 4 years ago
parent
commit
0fc0d615fb
1 changed files with 57 additions and 82 deletions
  1. +57
    -82
      mindspore/train/serialization.py

+ 57
- 82
mindspore/train/serialization.py View File

@@ -19,6 +19,7 @@ import stat
import math import math
import shutil import shutil
from threading import Thread, Lock from threading import Thread, Lock
from collections import defaultdict
import numpy as np import numpy as np


import mindspore.nn as nn import mindspore.nn as nn
@@ -1062,38 +1063,27 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
return merged_parameter return merged_parameter




def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None, train_strategy_filename=None):
""" """
Load checkpoint into net for distributed predication. Load checkpoint into net for distributed predication.


Args: Args:
network (Cell): Network for distributed predication. network (Cell): Network for distributed predication.
checkpoint_filenames (list(str)): The name of Checkpoint files
in order of rank id.
predict_strategy (Optional(dict)): Strategy of predication process, whose key
is parameter name, and value is a list or a tuple that the first four
elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device.
Default: None.
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device. Default: None.


Raises: Raises:
TypeError: The type of inputs do not match the requirements. TypeError: The type of inputs do not match the requirements.
ValueError: Failed to load checkpoint into net. ValueError: Failed to load checkpoint into net.
""" """
network = Validator.check_isinstance("network", network, nn.Cell) network = Validator.check_isinstance("network", network, nn.Cell)
_check_checkpoint_file(checkpoint_filenames)
_check_predict_strategy(predict_strategy)


for index, filename in enumerate(checkpoint_filenames):
if not isinstance(filename, str) or not os.path.exists(filename) \
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")

if not _check_predict_strategy(predict_strategy):
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")

train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
if train_strategy_filename is None:
train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
_train_strategy = build_searched_strategy(train_strategy_filename) _train_strategy = build_searched_strategy(train_strategy_filename)
train_strategy = _convert_to_list(_train_strategy) train_strategy = _convert_to_list(_train_strategy)


@@ -1107,6 +1097,12 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=


rank_list = _infer_rank_list(train_strategy, predict_strategy) rank_list = _infer_rank_list(train_strategy, predict_strategy)


param_total_dict = defaultdict(dict)
for file_index, file_name in enumerate(checkpoint_filenames):
ckpt_dict = load_checkpoint(file_name)
for param_name, param in ckpt_dict.items():
param_total_dict[param_name][file_index] = param

param_dict = {} param_dict = {}
for _, param in network.parameters_and_names(): for _, param in network.parameters_and_names():
sliced_params = [] sliced_params = []
@@ -1114,8 +1110,31 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
continue continue
param_rank = rank_list[param.name][0] param_rank = rank_list[param.name][0]
skip_merge_split = rank_list[param.name][1] skip_merge_split = rank_list[param.name][1]
shard_stride = train_strategy[param.name][4]
if train_strategy[param.name][5]:
shard_size = len(checkpoint_filenames) / shard_stride / train_strategy[param.name][5]
else:
shard_size = 0
for rank in param_rank: for rank in param_rank:
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
param_total_list = list(range(0, len(checkpoint_filenames)))
if shard_size > 0:
shard_total_list = [param_total_list[i:i + shard_size] for i in
range(0, len(checkpoint_filenames), shard_size)]
param_total_list = shard_total_list[rank // shard_size]
if shard_stride > 0:
param_stride = []
# merge pre parameter
param_index = param_total_list[0:param_total_list.index(rank) + 1][::-1][::shard_stride]
param_index.extend(param_total_list[param_total_list.index(rank):][::shard_stride])
param_index = list(set(param_index))
param_index.sort()
for rank_num in param_index:
param_stride.append(param_total_dict[param.name][rank_num].data.asnumpy())

sliced_param = Parameter(Tensor(np.concatenate(param_stride)), name=param.name)
else:
sliced_param = param_total_dict[param.name][rank]

sliced_params.append(sliced_param) sliced_params.append(sliced_param)
if skip_merge_split: if skip_merge_split:
split_param = sliced_params[0] split_param = sliced_params[0]
@@ -1139,19 +1158,31 @@ def _check_predict_strategy(predict_strategy):
return True return True


if predict_strategy is None: if predict_strategy is None:
return True
return
flag = True
predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict) predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
for key in predict_strategy.keys(): for key in predict_strategy.keys():
if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \ if not isinstance(key, str) or not isinstance(predict_strategy[key], (list, tuple)) \
or len(predict_strategy[key]) < 4: or len(predict_strategy[key]) < 4:
return False
flag = False
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4] dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \ if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
not (_check_int_list(param_split_shape) or not param_split_shape) or \ not (_check_int_list(param_split_shape) or not param_split_shape) or \
not (isinstance(field_size, int) and field_size == 0): not (isinstance(field_size, int) and field_size == 0):
return False
return True
flag = False
if not flag:
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list or a tuple that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")


def _check_checkpoint_file(checkpoint_filenames):
    """
    Check that every entry of `checkpoint_filenames` is a usable checkpoint file.

    An entry is accepted only if it is a str ending with ".ckpt" that names an
    existing, non-empty regular file.

    Args:
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.

    Raises:
        ValueError: An entry is not a valid checkpoint file. The message reports the
            offending entry and its index.
    """
    for index, filename in enumerate(checkpoint_filenames):
        # `isfile` rather than `exists`: a directory that happens to be named "*.ckpt"
        # exists and usually has a non-zero size, but it can never be parsed as a checkpoint.
        is_valid = isinstance(filename, str) and filename.endswith(".ckpt") \
            and os.path.isfile(filename) and os.path.getsize(filename) > 0
        if not is_valid:
            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")




def _convert_to_list(strategy): def _convert_to_list(strategy):
@@ -1207,59 +1238,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
layerwise_parallel = merged_param.layerwise_parallel layerwise_parallel = merged_param.layerwise_parallel
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
return split_param return split_param


def _load_single_param(ckpt_file_name, param_name):
    """
    Load a single parameter from a checkpoint file.

    The whole file is read and parsed, then the first run of consecutive elements
    whose tag equals `param_name` is concatenated and turned into a Parameter.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to read.
        param_name (str): Tag of the parameter to load.

    Returns:
        Parameter, built from the stored data and named after the matching tag.

    Raises:
        ValueError: The file cannot be read or parsed, or it holds no element
            tagged `param_name`.
        RuntimeError: The stored data cannot be converted into a Parameter.
    """
    checkpoint_list = Checkpoint()

    # `Exception`, not `BaseException`: KeyboardInterrupt / SystemExit must propagate
    # instead of being reported as a corrupt checkpoint.
    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except Exception as e:
        logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
                     " please check the correct of the file.", ckpt_file_name)
        raise ValueError(str(e)) from e

    parameter = None
    try:
        elements = checkpoint_list.value
        last_id = len(elements) - 1
        param_data_list = []
        for element_id, element in enumerate(elements):
            if element.tag != param_name:
                continue
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            param_data_list.append(np.frombuffer(element.tensor.tensor_content, np_type))
            # A parameter may be stored as several consecutive elements with the same tag;
            # keep collecting until the run ends.
            if element_id != last_id and element.tag == elements[element_id + 1].tag:
                continue

            param_data = np.concatenate(param_data_list, axis=0)
            dims = element.tensor.dims
            if dims == [0]:
                # dims [0]: hand a Python scalar to Tensor for float / int types.
                if 'Float' in data_type:
                    param_data = float(param_data[0])
                elif 'Int' in data_type:
                    param_data = int(param_data[0])
            elif dims == [1]:
                # dims [1]: the flat buffer is used as is.
                pass
            else:
                param_data = param_data.reshape(list(dims))
            parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
            break

    except Exception as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(str(e)) from e

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter

Loading…
Cancel
Save