|
|
|
@@ -638,9 +638,12 @@ def _save_mindir(net, file_name, *inputs): |
|
|
|
else: |
|
|
|
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.") |
|
|
|
# save parameter |
|
|
|
file_prefix = file_name.split("/")[-1] |
|
|
|
if file_prefix.endswith(".mindir"): |
|
|
|
file_prefix = file_prefix[:-7] |
|
|
|
current_path = os.path.abspath(file_name) |
|
|
|
dirname = os.path.dirname(current_path) |
|
|
|
data_path = dirname + "/variables" |
|
|
|
data_path = dirname + "/" + file_prefix + "_variables" |
|
|
|
if os.path.exists(data_path): |
|
|
|
shutil.rmtree(data_path) |
|
|
|
os.makedirs(data_path, exist_ok=True) |
|
|
|
@@ -675,7 +678,7 @@ def _save_mindir(net, file_name, *inputs): |
|
|
|
|
|
|
|
# save graph |
|
|
|
del model.graph.parameter[:] |
|
|
|
graph_file_name = file_name + "_graph.mindir" |
|
|
|
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir" |
|
|
|
with open(graph_file_name, 'wb') as f: |
|
|
|
os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR) |
|
|
|
f.write(model.SerializeToString()) |
|
|
|
@@ -1147,7 +1150,6 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy): |
|
|
|
|
|
|
|
def _load_single_param(ckpt_file_name, param_name): |
|
|
|
"""Load a parameter from checkpoint.""" |
|
|
|
logger.info("Execute the process of loading checkpoint files.") |
|
|
|
checkpoint_list = Checkpoint() |
|
|
|
|
|
|
|
try: |
|
|
|
@@ -1155,7 +1157,8 @@ def _load_single_param(ckpt_file_name, param_name): |
|
|
|
pb_content = f.read() |
|
|
|
checkpoint_list.ParseFromString(pb_content) |
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) |
|
|
|
logger.error("Failed to read the checkpoint file `%s` during load single parameter," |
|
|
|
" please check the correct of the file.", ckpt_file_name) |
|
|
|
raise ValueError(e.__str__()) |
|
|
|
|
|
|
|
parameter = None |
|
|
|
@@ -1189,8 +1192,7 @@ def _load_single_param(ckpt_file_name, param_name): |
|
|
|
param_dim.append(dim) |
|
|
|
param_value = param_data.reshape(param_dim) |
|
|
|
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag) |
|
|
|
break |
|
|
|
logger.info("Loading checkpoint files process is finished.") |
|
|
|
break |
|
|
|
|
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) |
|
|
|
|