|
|
@@ -15,6 +15,7 @@ |
|
|
"""Model and parameters serialization.""" |
|
|
"""Model and parameters serialization.""" |
|
|
import os |
|
|
import os |
|
|
import stat |
|
|
import stat |
|
|
|
|
|
from threading import Thread, Lock |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
import mindspore.nn as nn |
|
|
import mindspore.nn as nn |
|
|
@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin |
|
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, |
|
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, |
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} |
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} |
|
|
|
|
|
|
|
|
|
|
|
# Module-level lock shared by the checkpoint-saving code: both `save_checkpoint` and
# `_exec_save` acquire it, so a save running in a background thread and a later save
# call cannot execute their locked sections at the same time.
_ckpt_mutex = Lock()
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par): |
|
|
def _special_process_par(par, new_par): |
|
|
""" |
|
|
""" |
|
|
@@ -101,7 +103,29 @@ def _update_param(param, new_param): |
|
|
param.set_parameter_data(type(param.data)(new_param.data)) |
|
|
param.set_parameter_data(type(param.data)(new_param.data)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name): |
|
|
|
|
|
|
|
|
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    Args:
        ckpt_file_name (str): Checkpoint file name.
        data_list (dict): Maps each parameter name to a sequence whose first three items are
                          the dims (iterable of int), the tensor type string and the flattened
                          numpy array holding the parameter data.

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    checkpoint_list = Checkpoint()

    try:
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # `tostring()` is a deprecated alias of `tobytes()`; the produced bytes are the same.
                param_tensor.tensor_content = value[2].tobytes()

            # A previous save leaves the file read-only (see the chmod below), so restore
            # write permission first; otherwise saving to the same name again fails.
            if os.path.exists(ckpt_file_name):
                os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)

            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            # Mark the finished checkpoint read-only once the file handle is closed.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except Exception as e:
        # Only ordinary failures are wrapped; KeyboardInterrupt/SystemExit propagate unchanged.
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
    """
    Saves checkpoint info to a specified file.

    Args:
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
        ckpt_file_name (str): Checkpoint file name.
        async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False

    Raises:
        RuntimeError: Failed to save the Checkpoint file. When `async_save` is True the file is
                      written by a background thread that is started but not joined here, so the
                      error is raised inside that thread instead of being raised to the caller.
    """
    logger.info("Execute save checkpoint process.")

    # name -> [dims, tensor_type, flattened numpy data]; this is the layout `_exec_save` consumes.
    data_list = {}
    with _ckpt_mutex:
        for param in parameter_list:
            key = param["name"]
            tensor = param["data"]
            if isinstance(tensor, Parameter):
                tensor.init_data()

            # A scalar (shape `()`) is recorded with the single dim 0.
            shape = tensor.shape
            dims = [0] if shape == () else list(shape)

            data_list[key] = [dims, str(tensor.dtype), tensor.asnumpy().reshape(-1)]

    if async_save:
        thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
        thr.start()
    else:
        _exec_save(ckpt_file_name, data_list)
    logger.info("Save checkpoint process finish.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -305,7 +329,7 @@ def _save_graph(network, file_name): |
|
|
os.chmod(file_name, stat.S_IRUSR) |
|
|
os.chmod(file_name, stat.S_IRUSR) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): |
|
|
|
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False): |
|
|
""" |
|
|
""" |
|
|
Saves checkpoint for 'ms' backend. |
|
|
Saves checkpoint for 'ms' backend. |
|
|
|
|
|
|
|
|
@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): |
|
|
train_network (Network): The train network for training. |
|
|
train_network (Network): The train network for training. |
|
|
ckpt_file_name (str): The name of checkpoint file. |
|
|
ckpt_file_name (str): The name of checkpoint file. |
|
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. |
|
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene. |
|
|
|
|
|
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
param_dict = {} |
|
|
param_dict = {} |
|
|
@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): |
|
|
each_param["data"] = param_data |
|
|
each_param["data"] = param_data |
|
|
param_list.append(each_param) |
|
|
param_list.append(each_param) |
|
|
|
|
|
|
|
|
save_checkpoint(param_list, ckpt_file_name) |
|
|
|
|
|
|
|
|
save_checkpoint(param_list, ckpt_file_name, async_save) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_merged_param_data(net, param_name, param_data): |
|
|
def _get_merged_param_data(net, param_name, param_data): |
|
|
|