|
|
|
@@ -164,7 +164,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F |
|
|
|
integrated_save = Validator.check_bool(integrated_save) |
|
|
|
async_save = Validator.check_bool(async_save) |
|
|
|
|
|
|
|
logger.info("Execute the process of saving checkpoint.") |
|
|
|
logger.info("Execute the process of saving checkpoint files.") |
|
|
|
|
|
|
|
if isinstance(save_obj, nn.Cell): |
|
|
|
save_obj.init_parameters_data() |
|
|
|
@@ -210,7 +210,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F |
|
|
|
else: |
|
|
|
_exec_save(ckpt_file_name, data_list) |
|
|
|
|
|
|
|
logger.info("Saving checkpoint process finished.") |
|
|
|
logger.info("Saving checkpoint process is finished.") |
|
|
|
|
|
|
|
|
|
|
|
def _check_param_prefix(filter_prefix, param_name): |
|
|
|
@@ -269,7 +269,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N |
|
|
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], " |
|
|
|
f"but got {str(type(prefix))} at index {index}.") |
|
|
|
|
|
|
|
logger.info("Execute the process of loading checkpoint.") |
|
|
|
logger.info("Execute the process of loading checkpoint files.") |
|
|
|
checkpoint_list = Checkpoint() |
|
|
|
|
|
|
|
try: |
|
|
|
@@ -313,7 +313,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N |
|
|
|
param_value = param_data.reshape(param_dim) |
|
|
|
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) |
|
|
|
|
|
|
|
logger.info("Loading checkpoint process finished.") |
|
|
|
logger.info("Loading checkpoint files process is finished.") |
|
|
|
|
|
|
|
except BaseException as e: |
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) |
|
|
|
@@ -358,7 +358,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
raise TypeError(msg) |
|
|
|
|
|
|
|
strict_load = Validator.check_bool(strict_load) |
|
|
|
logger.info("Execute the process of loading parameter into net.") |
|
|
|
logger.info("Execute the process of loading parameters into net.") |
|
|
|
net.init_parameters_data() |
|
|
|
param_not_load = [] |
|
|
|
for _, param in net.parameters_and_names(): |
|
|
|
@@ -379,7 +379,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
for param_name in param_not_load: |
|
|
|
logger.debug("%s", param_name) |
|
|
|
|
|
|
|
logger.info("Loading parameter into net finished.") |
|
|
|
logger.info("Loading parameters into net is finished.") |
|
|
|
if param_not_load: |
|
|
|
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load))) |
|
|
|
return param_not_load |
|
|
|
|