Browse Source

load distributed ckpt for predict

tags/v1.1.0
caozhou 5 years ago
parent
commit
90cee01072
1 changed files with 204 additions and 4 deletions
  1. +204
    -4
      mindspore/train/serialization.py

+ 204
- 4
mindspore/train/serialization.py View File

@@ -18,12 +18,12 @@ import stat
import math
from threading import Thread, Lock
import numpy as np

import mindspore.nn as nn
import mindspore.context as context
from mindspore import log as logger
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.print_pb2 import Print
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap, ParallelLayouts
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
@@ -31,7 +31,8 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export
import mindspore.context as context
from mindspore.parallel._tensor import _load_tensor
from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices


tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -711,7 +712,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
param_split_shape = list(layout.param_split_shape[0].dim)
field_size = int(layout.field)
except BaseException as e:
raise ValueError(f"{e.__str__()}. please make sure that strategy matches the node_strategy.proto.")
raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")

device_count = 1
for dim in dev_mat:
@@ -897,3 +898,202 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

return merged_parameter


def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=None):
"""
Load checkpoint into net for distributed predication.

Args:
network (Cell): Network for distributed predication.
checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and
value is a list that the first four elements are [dev_matrix, tensor_map, param_split_shape, field].
If None, it means that the predication process just uses single device. Default: None.

Raises:
TypeError: The type of inputs do not match the requirements.
ValueError: Failed to load checkpoint into net.
"""
network = Validator.check_isinstance("network", network, nn.Cell)

for index, filename in enumerate(checkpoint_filenames):
if not isinstance(filename, str) or not os.path.exists(filename) \
or filename[-5:] != ".ckpt" or os.path.getsize(filename) == 0:
raise ValueError(f"Please make sure that the {filename} at index {index} is a valid checkpoint file.")

if not _check_predict_strategy(predict_strategy):
raise ValueError(f"Please make sure that the key of predict_strategy is str, "
f"and the value is a list that the first four elements are "
f"dev_matrix (list[int]), tensor_map (list[int]), "
f"param_split_shape (list[int]) and field_size (zero).")

train_strategy_filename = context.get_auto_parallel_context("strategy_ckpt_load_file")
_train_strategy = build_searched_strategy(train_strategy_filename)
train_strategy = _convert_to_list(_train_strategy)

train_dev_count = 1
for dim in train_strategy[list(train_strategy.keys())[0]][0]:
train_dev_count *= dim
if train_dev_count != len(checkpoint_filenames):
raise ValueError(
f"The length of checkpoint_filenames should be equal to the device count of training process. "
f"The length is {len(checkpoint_filenames)} but the device count is {train_dev_count}.")

rank_list = _infer_rank_list(train_strategy, predict_strategy)

param_dict = {}
for _, param in network.parameters_and_names():
sliced_params = []
if param.name not in rank_list.keys():
continue
param_rank = rank_list[param.name]
for rank in param_rank:
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
sliced_params.append(sliced_param)
if len(sliced_params) == 1:
split_param = sliced_params[0]
else:
param_unique_strategy = _remove_repeated_slices(train_strategy[param.name])
_param_unique_strategy = _convert_to_layout(param.name, param_unique_strategy)
split_param = _merge_and_split(sliced_params, _param_unique_strategy, predict_strategy)
param_dict[param.name] = split_param

load_param_into_net(network, param_dict)


def _check_predict_strategy(predict_strategy):
"""Check predict strategy."""
def _check_int_list(arg):
if not isinstance(arg, list):
return False
for item in arg:
if not isinstance(item, int):
return False
return True

if predict_strategy is None:
return True

predict_strategy = Validator.check_isinstance("predict_strategy", predict_strategy, dict)
for key in predict_strategy.keys():
if not isinstance(key, str) or not isinstance(predict_strategy[key], list) or len(predict_strategy[key]) < 4:
return False
dev_matrix, tensor_map, param_split_shape, field_size = predict_strategy[key][:4]
if not _check_int_list(dev_matrix) or not _check_int_list(tensor_map) or \
not (_check_int_list(param_split_shape) or not param_split_shape) or \
not (isinstance(field_size, int) and field_size == 0):
return False
return True


def _convert_to_list(strategy):
"""Convert ParallelLayouts object to specified list."""
train_map = {}
for param_name in strategy.keys():
try:
layout = strategy.get(param_name)
dev_mat = list(layout.dev_matrix[0].dim)
tensor_map = list(layout.tensor_map[0].dim)
param_split_shape = list(layout.param_split_shape[0].dim)
field_size = int(layout.field)
train_map[param_name] = [dev_mat, tensor_map, param_split_shape, field_size]
except BaseException as e:
raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")
return train_map


def _convert_to_layout(param_name, tensor_layout):
"""Convert list to ParallelLayouts object."""
strategy = {}
try:
layout = ParallelLayouts()
layout.field = tensor_layout[3]

dev_matrix = layout.dev_matrix.add()
for item in tensor_layout[0]:
dev_matrix.dim.append(item)

tensor_map = layout.tensor_map.add()
for item in tensor_layout[1]:
tensor_map.dim.append(item)

param_split_shape = layout.param_split_shape.add()
for item in tensor_layout[2]:
param_split_shape.dim.append(item)
except BaseException as e:
raise ValueError("Convert failed. " + e.__str__())

strategy[param_name] = layout
return strategy


def _merge_and_split(sliced_params, train_strategy, predict_strategy):
"""Merge sliced parameter and split it according to the predict strategy."""
merged_param = merge_sliced_parameter(sliced_params, train_strategy)
if predict_strategy is None:
return merged_param
param_name = merged_param.name
tensor_layout = predict_strategy[param_name]
split_tensor = _load_tensor(merged_param.data, tensor_layout[0], tensor_layout[1])
requires_grad = merged_param.requires_grad
layerwise_parallel = merged_param.layerwise_parallel
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
return split_param


def _load_single_param(ckpt_file_name, param_name):
"""Load a parameter from checkpoint."""
logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint()

try:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())

parameter = None
try:
param_data_list = []
for element_id, element in enumerate(checkpoint_list.value):
if element.tag != param_name:
continue
data = element.tensor.tensor_content
data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type]
ms_type = tensor_to_ms_type[data_type]
element_data = np.frombuffer(data, np_type)
param_data_list.append(element_data)
if (element_id == len(checkpoint_list.value) - 1) or \
(element.tag != checkpoint_list.value[element_id + 1].tag):
param_data = np.concatenate((param_data_list), axis=0)
param_data_list.clear()
dims = element.tensor.dims
if dims == [0]:
if 'Float' in data_type:
param_data = float(param_data[0])
elif 'Int' in data_type:
param_data = int(param_data[0])
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
elif dims == [1]:
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
else:
param_dim = []
for dim in dims:
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
break
logger.info("Loading checkpoint files process is finished.")

except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())

if parameter is None:
raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
f"please check parameter name or checkpoint file.")
return parameter

Loading…
Cancel
Save