| @@ -163,7 +163,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| integrated_save = Validator.check_bool(integrated_save) | |||
| async_save = Validator.check_bool(async_save) | |||
| logger.info("Execute the process of saving checkpoint.") | |||
| logger.info("Execute the process of saving checkpoint files.") | |||
| if isinstance(save_obj, nn.Cell): | |||
| save_obj.init_parameters_data() | |||
| @@ -209,7 +209,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||
| else: | |||
| _exec_save(ckpt_file_name, data_list) | |||
| logger.info("Saving checkpoint process finished.") | |||
| logger.info("Saving checkpoint process is finished.") | |||
| def _check_param_prefix(filter_prefix, param_name): | |||
| @@ -268,7 +268,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||
| raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " | |||
| f"but got {str(type(prefix))} at index {index}.") | |||
| logger.info("Execute the process of loading checkpoint.") | |||
| logger.info("Execute the process of loading checkpoint files.") | |||
| checkpoint_list = Checkpoint() | |||
| try: | |||
| @@ -312,7 +312,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||
| param_value = param_data.reshape(param_dim) | |||
| parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) | |||
| logger.info("Loading checkpoint process finished.") | |||
| logger.info("Loading checkpoint files process is finished.") | |||
| except BaseException as e: | |||
| logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) | |||
| @@ -357,7 +357,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| raise TypeError(msg) | |||
| strict_load = Validator.check_bool(strict_load) | |||
| logger.info("Execute the process of loading parameter into net.") | |||
| logger.info("Execute the process of loading parameters into net.") | |||
| net.init_parameters_data() | |||
| param_not_load = [] | |||
| for _, param in net.parameters_and_names(): | |||
| @@ -378,7 +378,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| for param_name in param_not_load: | |||
| logger.debug("%s", param_name) | |||
| logger.info("Loading parameter into net finished.") | |||
| logger.info("Loading parameters into net is finished.") | |||
| if param_not_load: | |||
| logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load))) | |||
| return param_not_load | |||