|
|
|
@@ -889,7 +889,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): |
|
|
|
raise ValueError(f"The sliced_parameters length should be equal to device_count. " |
|
|
|
f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.") |
|
|
|
|
|
|
|
merged_tensor = None |
|
|
|
if not param_split_shape: |
|
|
|
if not is_even: |
|
|
|
raise ValueError("The shape of every parameter in sliced_parameters should be the same " |
|
|
|
@@ -1052,7 +1051,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): |
|
|
|
layerwise_parallel = sliced_parameters[0].layerwise_parallel |
|
|
|
requires_grad = sliced_parameters[0].requires_grad |
|
|
|
sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters] |
|
|
|
merged_parameter = None |
|
|
|
|
|
|
|
if not strategy: |
|
|
|
merged_tensor = Tensor(np.concatenate(sliced_data)) |
|
|
|
@@ -1121,7 +1119,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= |
|
|
|
param_rank = rank_list[param.name][0] |
|
|
|
skip_merge_split = rank_list[param.name][1] |
|
|
|
for rank in param_rank: |
|
|
|
sliced_param = _load_single_param(checkpoint_filenames[rank], param.name) |
|
|
|
sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name] |
|
|
|
sliced_params.append(sliced_param) |
|
|
|
if skip_merge_split: |
|
|
|
split_param = sliced_params[0] |
|
|
|
@@ -1213,59 +1211,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): |
|
|
|
layerwise_parallel = merged_param.layerwise_parallel |
|
|
|
split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel) |
|
|
|
return split_param |
|
|
|
|
|
|
|
|
|
|
|
def _load_single_param(ckpt_file_name, param_name): |
|
|
|
"""Load a parameter from checkpoint.""" |
|
|
|
checkpoint_list = Checkpoint() |
|
|
|
|
|
|
|
try: |
|
|
|
with open(ckpt_file_name, "rb") as f: |
|
|
|
pb_content = f.read() |
|
|
|
checkpoint_list.ParseFromString(pb_content) |
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to read the checkpoint file `%s` during load single parameter," |
|
|
|
" please check the correct of the file.", ckpt_file_name) |
|
|
|
raise ValueError(e.__str__()) |
|
|
|
|
|
|
|
parameter = None |
|
|
|
try: |
|
|
|
param_data_list = [] |
|
|
|
for element_id, element in enumerate(checkpoint_list.value): |
|
|
|
if element.tag != param_name: |
|
|
|
continue |
|
|
|
data = element.tensor.tensor_content |
|
|
|
data_type = element.tensor.tensor_type |
|
|
|
np_type = tensor_to_np_type[data_type] |
|
|
|
ms_type = tensor_to_ms_type[data_type] |
|
|
|
element_data = np.frombuffer(data, np_type) |
|
|
|
param_data_list.append(element_data) |
|
|
|
if (element_id == len(checkpoint_list.value) - 1) or \ |
|
|
|
(element.tag != checkpoint_list.value[element_id + 1].tag): |
|
|
|
param_data = np.concatenate((param_data_list), axis=0) |
|
|
|
param_data_list.clear() |
|
|
|
dims = element.tensor.dims |
|
|
|
if dims == [0]: |
|
|
|
if 'Float' in data_type: |
|
|
|
param_data = float(param_data[0]) |
|
|
|
elif 'Int' in data_type: |
|
|
|
param_data = int(param_data[0]) |
|
|
|
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
elif dims == [1]: |
|
|
|
parameter = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
else: |
|
|
|
param_dim = [] |
|
|
|
for dim in dims: |
|
|
|
param_dim.append(dim) |
|
|
|
param_value = param_data.reshape(param_dim) |
|
|
|
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag) |
|
|
|
break |
|
|
|
|
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) |
|
|
|
raise RuntimeError(e.__str__()) |
|
|
|
|
|
|
|
if parameter is None: |
|
|
|
raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, " |
|
|
|
f"please check parameter name or checkpoint file.") |
|
|
|
return parameter |