Browse Source

optimize verb description

tags/v1.1.0
caozhou 5 years ago
parent
commit
b46ffced99
1 changed files with 8 additions and 8 deletions
  1. +8
    -8
      mindspore/train/serialization.py

+ 8
- 8
mindspore/train/serialization.py View File

@@ -108,7 +108,7 @@ def _update_param(param, new_param):




def _exec_save(ckpt_file_name, data_list): def _exec_save(ckpt_file_name, data_list):
"""Execute save checkpoint into file process."""
"""Execute the process of saving checkpoint into file."""


try: try:
with _ckpt_mutex: with _ckpt_mutex:
@@ -163,7 +163,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
integrated_save = Validator.check_bool(integrated_save) integrated_save = Validator.check_bool(integrated_save)
async_save = Validator.check_bool(async_save) async_save = Validator.check_bool(async_save)


logger.info("Execute save checkpoint process.")
logger.info("Execute the process of saving checkpoint.")


if isinstance(save_obj, nn.Cell): if isinstance(save_obj, nn.Cell):
save_obj.init_parameters_data() save_obj.init_parameters_data()
@@ -209,7 +209,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
else: else:
_exec_save(ckpt_file_name, data_list) _exec_save(ckpt_file_name, data_list)


logger.info("Save checkpoint process finish.")
logger.info("Saving checkpoint process finished.")




def _check_param_prefix(filter_prefix, param_name): def _check_param_prefix(filter_prefix, param_name):
@@ -268,7 +268,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
f"but got {str(type(prefix))} at index {index}.") f"but got {str(type(prefix))} at index {index}.")


logger.info("Execute load checkpoint process.")
logger.info("Execute the process of loading checkpoint.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()


try: try:
@@ -312,7 +312,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
param_value = param_data.reshape(param_dim) param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)


logger.info("Load checkpoint process finish.")
logger.info("Loading checkpoint process finished.")


except BaseException as e: except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
@@ -356,7 +356,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
raise TypeError(msg) raise TypeError(msg)


strict_load = Validator.check_bool(strict_load) strict_load = Validator.check_bool(strict_load)
logger.info("Execute load parameter into net process.")
logger.info("Execute the process of loading parameter into net.")
net.init_parameters_data() net.init_parameters_data()
param_not_load = [] param_not_load = []
for _, param in net.parameters_and_names(): for _, param in net.parameters_and_names():
@@ -377,7 +377,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
for param_name in param_not_load: for param_name in param_not_load:
logger.debug("%s", param_name) logger.debug("%s", param_name)


logger.info("Load parameter into net finish.")
logger.info("Loading parameter into net finished.")
if param_not_load: if param_not_load:
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load))) logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
return param_not_load return param_not_load
@@ -416,7 +416,7 @@ def _save_graph(network, file_name):
network (Cell): Obtain a pipeline through network for saving graph. network (Cell): Obtain a pipeline through network for saving graph.
file_name (str): Graph file name into which the graph will be saved. file_name (str): Graph file name into which the graph will be saved.
""" """
logger.info("Execute save the graph process.")
logger.info("Execute the process of saving graph.")


graph_proto = network.get_func_graph_proto() graph_proto = network.get_func_graph_proto()
if graph_proto: if graph_proto:


Loading…
Cancel
Save