Browse Source

load distributed ckpt for predict

tags/v1.1.0
caozhou 5 years ago
parent
commit
90cee01072
1 changed files with 204 additions and 4 deletions
  1. +204
    -4
      mindspore/train/serialization.py

+ 204
- 4
mindspore/train/serialization.py View File

@@ -18,12 +18,12 @@ import stat
import math
from threading import Thread, Lock
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import log as logger
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.print_pb2 import Print
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
@@ -31,7 +31,8 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export from mindspore.compression.export import quant_export
import mindspore.context as context
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices




tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -711,7 +712,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
param_split_shape = list(layout.param_split_shape[0].dim)
field_size = int(layout.field)
except BaseException as e:
raise ValueError(f"{e.__str__()}. please make sure that strategy matches the node_strategy.proto.")
raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")


device_count = 1
for dim in dev_mat:
@@ -897,3 +898,202 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

return merged_parameter


def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
    """
    Load checkpoint into net for distributed predication.

    Args:
        network (Cell): Network for distributed predication.
        checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
        predict_strategy (dict): Strategy of predication process, whose key is parameter name, and
            value is a list that the first four elements are [dev_matrix, tensor_map, param_split_shape, field].
            If None, it means that the predication process just uses single device. Default: None.

    Raises:
        TypeError: The type of inputs do not match the requirements.
        ValueError: Failed to load checkpoint into net.
    """
    network = Validator.check_isinstance("network", network, nn.Cell)

    # Every checkpoint file must be a non-empty, existing ".ckpt" file.
    for index, filename in enumerate(checkpoint_filenames):
        if not isinstance(filename, str) or not os.path.exists(filename) \
                or not filename.endswith(".ckpt") or os.path.getsize(filename) == 0:
            raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")

    if not _check_predict_strategy(predict_strategy):
        raise ValueError(f"Please make sure that the key of predict_strategy is str, "
                         f"and the value is a list that the first four elements are "
                         f"dev_matrix (list[int]), tensor_map (list[int]), "
                         f"param_split_shape (list[int]) and field_size (zero).")

    # The training strategy file records how each parameter was sliced at train time.
    train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
    _train_strategy = build_searched_strategy(train_strategy_filename)
    train_strategy = _convert_to_list(_train_strategy)

    # The training device count is the product of the first parameter's
    # device-matrix dimensions; one checkpoint file per training rank is expected.
    train_dev_count = 1
    for dim in train_strategy[list(train_strategy.keys())[0]][0]:
        train_dev_count *= dim
    if train_dev_count != len(checkpoint_filenames):
        raise ValueError(
            f"The length of checkpoint_filenames should be equal to the device count of training process. "
            f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")

    # For every parameter: which training ranks hold the slices needed here.
    rank_list = _infer_rank_list(train_strategy, predict_strategy)

    param_dict = {}
    for _, param in network.parameters_and_names():
        if param.name not in rank_list:
            continue
        sliced_params = [_load_single_param(checkpoint_filenames[rank], param.name)
                         for rank in rank_list[param.name]]
        if len(sliced_params) == 1:
            split_param = sliced_params[0]
        else:
            # Drop duplicated slices, then merge the remaining ones and
            # re-slice the result according to the predict strategy.
            param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
            _param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
            split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
        param_dict[param.name] = split_param

    load_param_into_net(network, param_dict)


def _check_predict_strategy(predict_strategy):
"""Check predict strategy."""
def _check_int_list(arg):
if not isinstance(arg, list):
return False
for item in arg:
if not isinstance(item, int):
return False
return True

if predict_strategy is None:
return True

predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
for key in predict_strategy.keys():
if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4:
return False
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
not (_check_int_list(param_split_shape) or not param_split_shape) or \
not (isinstance(field_size, int) and field_size == 0):
return False
return True


def _convert_to_list(strategy):
"""Convert ParallelLayouts object to specified list."""
train_map = {}
for param_name in strategy.keys():
try:
layout = strategy.get(param_name)
dev_mat = list(layout.dev_matrix[0].dim)
tensor_map = list(layout.tensor_map[0].dim)
param_split_shape = list(layout.param_split_shape[0].dim)
field_size = int(layout.field)
train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
except BaseException as e:
raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
return train_map


def _convert_to_layout(param_name, tensor_layout):
    """Convert list to ParallelLayouts object."""
    try:
        layout = ParallelLayouts()
        # tensor_layout is [dev_matrix, tensor_map, param_split_shape, field].
        layout.field = tensor_layout[3]
        layout.dev_matrix.add().dim.extend(tensor_layout[0])
        layout.tensor_map.add().dim.extend(tensor_layout[1])
        layout.param_split_shape.add().dim.extend(tensor_layout[2])
    except BaseException as e:
        raise ValueError("Convert failed. " + e.__str__())

    return {param_name: layout}


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
    """Merge sliced parameter and split it according to the predict strategy."""
    merged_param = merge_sliced_parameter(sliced_params, train_strategy)
    # With no predict strategy the fully merged parameter is used as-is.
    if predict_strategy is None:
        return merged_param
    name = merged_param.name
    layout = predict_strategy[name]
    # dev_matrix (layout[0]) and tensor_map (layout[1]) drive the re-slicing.
    sliced_tensor = _load_tensor(merged_param.data, layout[0], layout[1])
    return Parameter(sliced_tensor, name,
                     merged_param.requires_grad, merged_param.layerwise_parallel)


def _load_single_param(ckpt_file_name, param_name):
    """Load a single parameter, selected by name, from a checkpoint file.

    Args:
        ckpt_file_name (str): Path of the checkpoint (protobuf) file.
        param_name (str): Name (tag) of the parameter to extract.

    Returns:
        Parameter, the parameter rebuilt from the checkpoint data.

    Raises:
        ValueError: The file cannot be read/parsed, or no element with
            `param_name` exists in it.
        RuntimeError: The matched checkpoint data cannot be converted
            into a Parameter.
    """
    logger.info("Execute the process of loading checkpoint files.")
    checkpoint_list = Checkpoint()

    try:
        with open(ckpt_file_name, "rb") as f:
            pb_content = f.read()
        checkpoint_list.ParseFromString(pb_content)
    except BaseException as e:
        logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
        raise ValueError(e.__str__())

    parameter = None
    try:
        # A large parameter may be stored as several consecutive elements
        # sharing the same tag; accumulate their raw buffers here.
        param_data_list = []
        for element_id, element in enumerate(checkpoint_list.value):
            if element.tag != param_name:
                continue
            data = element.tensor.tensor_content
            data_type = element.tensor.tensor_type
            np_type = tensor_to_np_type[data_type]
            ms_type = tensor_to_ms_type[data_type]
            element_data = np.frombuffer(data, np_type)
            param_data_list.append(element_data)
            # Build the Parameter once the last element of this tag's run is
            # reached (end of list, or next element has a different tag).
            if (element_id == len(checkpoint_list.value) - 1) or \
                    (element.tag != checkpoint_list.value[element_id + 1].tag):
                param_data = np.concatenate((param_data_list), axis=0)
                param_data_list.clear()
                dims = element.tensor.dims
                if dims == [0]:
                    # dims == [0] appears to mark a scalar value — unwrap it
                    # to a Python float/int before wrapping in a Tensor.
                    if 'Float' in data_type:
                        param_data = float(param_data[0])
                    elif 'Int' in data_type:
                        param_data = int(param_data[0])
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                elif dims == [1]:
                    # 1-D parameter: the concatenated buffer is already flat.
                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
                else:
                    # Multi-dimensional parameter: reshape the flat buffer.
                    param_dim = []
                    for dim in dims:
                        param_dim.append(dim)
                    param_value = param_data.reshape(param_dim)
                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
                # Only one parameter is wanted; stop scanning.
                break
        logger.info("Loading checkpoint files process is finished.")

    except BaseException as e:
        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
        raise RuntimeError(e.__str__())

    if parameter is None:
        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
                         f"please check parameter name or checkpoint file.")
    return parameter

Loading…
Cancel
Save