|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
"""Model and parameters serialization.""" |
|
|
|
import os |
|
|
|
import stat |
|
|
|
import math |
|
|
|
from threading import Thread, Lock |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
@@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin |
|
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} |
|
|
|
|
|
|
|
_ckpt_mutex = Lock() |
|
|
|
SLICE_SIZE = 512 * 1024 * 1024 |
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par): |
|
|
|
""" |
|
|
|
@@ -105,26 +108,38 @@ def _update_param(param, new_param): |
|
|
|
|
|
|
|
def _exec_save(ckpt_file_name, data_list):
    """Execute save checkpoint into file process.

    Every parameter is serialized as one or more ``Checkpoint`` messages that
    are appended to the file one after another. Data larger than
    ``SLICE_SIZE`` bytes is split with ``np.array_split`` so that no single
    message has to hold the whole parameter; all slices of a parameter carry
    the same tag, dims and tensor type.

    Args:
        ckpt_file_name (str): Path of the checkpoint file. An existing file at
            this path is removed and rewritten.
        data_list (dict): Maps a parameter name to an indexable record that is
            read as ``[dims, tensor_type, data]``. ``data`` must provide
            ``nbytes`` and ``tobytes()`` and be accepted by
            ``np.array_split`` (presumably a numpy array - confirm against
            the caller).

    Raises:
        RuntimeError: If anything fails while serializing or writing; the
            original exception is attached as ``__cause__``.
    """
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                # A finished save leaves the file read-only (see the chmod at
                # the end). Restore write permission first so the removal also
                # works on platforms that refuse to delete read-only files.
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    dims, tensor_type, data = value[0], value[1], value[2]

                    if data.nbytes > SLICE_SIZE:
                        slice_count = math.ceil(data.nbytes / SLICE_SIZE)
                        param_slice_list = np.array_split(data, slice_count)
                    else:
                        param_slice_list = [data]

                    for param_slice in param_slice_list:
                        # One message per slice keeps the amount of data that
                        # is serialized at once bounded.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(dims)
                        param_tensor.tensor_type = tensor_type
                        # tobytes() replaces the deprecated tostring() alias.
                        param_tensor.tensor_content = param_slice.tobytes()

                        f.write(checkpoint_list.SerializeToString())

            # Mark the finished checkpoint read-only while the lock is still
            # held, so a concurrent save cannot remove the file in between.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): |
|
|
|
""" |
|
|
|
Saves checkpoint info to a specified file. |
|
|
|
@@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None): |
|
|
|
|
|
|
|
parameter_dict = {} |
|
|
|
try: |
|
|
|
element_id = 0 |
|
|
|
param_data_list = [] |
|
|
|
for element in checkpoint_list.value: |
|
|
|
data = element.tensor.tensor_content |
|
|
|
data_type = element.tensor.tensor_type |
|
|
|
np_type = tensor_to_np_type[data_type] |
|
|
|
ms_type = tensor_to_ms_type[data_type] |
|
|
|
param_data = np.fromstring(data, np_type) |
|
|
|
dims = element.tensor.dims |
|
|
|
|
|
|
|
if dims == [0]: |
|
|
|
if 'Float' in data_type: |
|
|
|
param_data = float(param_data[0]) |
|
|
|
elif 'Int' in data_type: |
|
|
|
param_data = int(param_data[0]) |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
elif dims == [1]: |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
else: |
|
|
|
param_dim = [] |
|
|
|
for dim in dims: |
|
|
|
param_dim.append(dim) |
|
|
|
param_value = param_data.reshape(param_dim) |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) |
|
|
|
element_data = np.frombuffer(data, np_type) |
|
|
|
param_data_list.append(element_data) |
|
|
|
if (element_id == len(checkpoint_list.value) - 1) or \ |
|
|
|
(element.tag != checkpoint_list.value[element_id + 1].tag): |
|
|
|
param_data = np.concatenate((param_data_list), axis=0) |
|
|
|
param_data_list.clear() |
|
|
|
dims = element.tensor.dims |
|
|
|
|
|
|
|
if dims == [0]: |
|
|
|
if 'Float' in data_type: |
|
|
|
param_data = float(param_data[0]) |
|
|
|
elif 'Int' in data_type: |
|
|
|
param_data = int(param_data[0]) |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
elif dims == [1]: |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) |
|
|
|
else: |
|
|
|
param_dim = [] |
|
|
|
for dim in dims: |
|
|
|
param_dim.append(dim) |
|
|
|
param_value = param_data.reshape(param_dim) |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) |
|
|
|
|
|
|
|
element_id += 1 |
|
|
|
|
|
|
|
logger.info("Load checkpoint process finish.") |
|
|
|
|
|
|
|
|