|
|
|
@@ -82,23 +82,26 @@ def _special_process_par(par, new_par): |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def _update_param(param, new_param): |
|
|
|
def _update_param(param, new_param, strict_load): |
|
|
|
"""Updates param's data from new_param's data.""" |
|
|
|
|
|
|
|
if isinstance(param.data, Tensor) and isinstance(new_param.data, Tensor): |
|
|
|
if param.data.dtype != new_param.data.dtype: |
|
|
|
logger.error("Failed to combine the net and the parameters for param %s.", param.name) |
|
|
|
msg = ("Net parameters {} type({}) different from parameter_dict's({})" |
|
|
|
.format(param.name, param.data.dtype, new_param.data.dtype)) |
|
|
|
raise RuntimeError(msg) |
|
|
|
|
|
|
|
if param.data.shape != new_param.data.shape: |
|
|
|
if not _special_process_par(param, new_param): |
|
|
|
logger.error("Failed to combine the net and the parameters for param %s.", param.name) |
|
|
|
msg = ("Net parameters {} shape({}) different from parameter_dict's({})" |
|
|
|
.format(param.name, param.data.shape, new_param.data.shape)) |
|
|
|
raise RuntimeError(msg) |
|
|
|
return |
|
|
|
|
|
|
|
if param.data.dtype != new_param.data.dtype: |
|
|
|
if _type_convert(param, new_param, strict_load): |
|
|
|
new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype) |
|
|
|
param.set_data(new_tensor) |
|
|
|
return |
|
|
|
|
|
|
|
logger.error("Failed to combine the net and the parameters for param %s.", param.name) |
|
|
|
msg = ("Net parameters {} type({}) different from parameter_dict's({})" |
|
|
|
.format(param.name, param.data.dtype, new_param.data.dtype)) |
|
|
|
raise RuntimeError(msg) |
|
|
|
|
|
|
|
param.set_data(new_param.data) |
|
|
|
return |
|
|
|
@@ -121,11 +124,21 @@ def _update_param(param, new_param): |
|
|
|
param.set_data(type(param.data)(new_param.data)) |
|
|
|
|
|
|
|
|
|
|
|
def _type_convert(param, new_param, strict_load): |
|
|
|
"""Whether to convert parameter's type during load checkpoint into network.""" |
|
|
|
float_type = (mstype.float16, mstype.float32, mstype.float64) |
|
|
|
int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64) |
|
|
|
if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or |
|
|
|
{param.data.dtype, new_param.data.dtype}.issubset(int_type)): |
|
|
|
logger.warning("ckpt_dict parameter: {}'s type is {}, convert to {} in the network." |
|
|
|
.format(new_param.name, new_param.data.dtype, param.data.dtype)) |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): |
|
|
|
"""Execute the process of saving checkpoint into file.""" |
|
|
|
|
|
|
|
try: |
|
|
|
MAX_BLOCK_SIZE = 1024*1024*512 |
|
|
|
with _ckpt_mutex: |
|
|
|
if os.path.exists(ckpt_file_name): |
|
|
|
os.remove(ckpt_file_name) |
|
|
|
@@ -155,10 +168,10 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM"): |
|
|
|
f.write(checkpoint_list.SerializeToString()) |
|
|
|
else: |
|
|
|
plain_data += checkpoint_list.SerializeToString() |
|
|
|
while len(plain_data) >= MAX_BLOCK_SIZE: |
|
|
|
cipher_data += _encrypt(plain_data[0: MAX_BLOCK_SIZE], MAX_BLOCK_SIZE, enc_key, |
|
|
|
while len(plain_data) >= SLICE_SIZE * 1024: |
|
|
|
cipher_data += _encrypt(plain_data[0: SLICE_SIZE*1024], SLICE_SIZE*1024, enc_key, |
|
|
|
len(enc_key), enc_mode) |
|
|
|
plain_data = plain_data[MAX_BLOCK_SIZE:] |
|
|
|
plain_data = plain_data[SLICE_SIZE*1024:] |
|
|
|
|
|
|
|
if enc_key is not None: |
|
|
|
if plain_data: |
|
|
|
@@ -310,7 +323,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N |
|
|
|
ckpt_file_name (str): Checkpoint file name. |
|
|
|
net (Cell): Cell network. Default: None |
|
|
|
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter |
|
|
|
in the param_dict into net with the same suffix. Default: False |
|
|
|
in the param_dict into net with the same suffix and load |
|
|
|
parameter with different accuracy. Default: False. |
|
|
|
filter_prefix (Union[str, list[str], tuple[str]]): Parameters starting with the filter_prefix |
|
|
|
will not be loaded. Default: None. |
|
|
|
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption |
|
|
|
@@ -469,12 +483,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
logger.error("Failed to combine the net and the parameters.") |
|
|
|
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) |
|
|
|
raise TypeError(msg) |
|
|
|
_update_param(param, new_param) |
|
|
|
_update_param(param, new_param, strict_load) |
|
|
|
else: |
|
|
|
param_not_load.append(param.name) |
|
|
|
|
|
|
|
if param_not_load and not strict_load: |
|
|
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load) |
|
|
|
_load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load) |
|
|
|
|
|
|
|
logger.debug("Params not matched(in net but not in parameter_dict):") |
|
|
|
for param_name in param_not_load: |
|
|
|
@@ -486,7 +500,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): |
|
|
|
return param_not_load |
|
|
|
|
|
|
|
|
|
|
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): |
|
|
|
def _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load): |
|
|
|
"""When some net parameter did not load, try to continue load.""" |
|
|
|
prefix_name = "" |
|
|
|
longest_name = param_not_load[0] |
|
|
|
@@ -507,7 +521,7 @@ def _load_dismatch_prefix_params(net, parameter_dict, param_not_load): |
|
|
|
new_param_name = prefix_name + param.name |
|
|
|
if param.name in param_not_load and new_param_name in parameter_dict: |
|
|
|
new_param = parameter_dict[new_param_name] |
|
|
|
_update_param(param, new_param) |
|
|
|
_update_param(param, new_param, strict_load) |
|
|
|
param_not_load.remove(param.name) |
|
|
|
|
|
|
|
|
|
|
|
|