Browse Source

modify ckpt type error

feature/build-system-rewrite
changzherui 4 years ago
parent
commit
05dca07cf6
3 changed files with 127 additions and 24 deletions
  1. +12
    -10
      mindspore/python/mindspore/train/serialization.py
  2. +114
    -13
      tests/ut/python/utils/test_callback.py
  3. +1
    -1
      tests/ut/python/utils/test_serialize.py

+ 12
- 10
mindspore/python/mindspore/train/serialization.py View File

@@ -242,8 +242,8 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
raise TypeError("For 'save_checkpoint', the argument 'save_obj' should be nn.Cell or list, "
"but got {}.".format(type(save_obj)))
if not isinstance(ckpt_file_name, str):
raise ValueError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
raise TypeError("The argument {} for checkpoint file name is invalid, 'ckpt_file_name' must be "
"string, but got {}.".format(ckpt_file_name, type(ckpt_file_name)))
if not ckpt_file_name.endswith('.ckpt'):
ckpt_file_name += ".ckpt"
integrated_save = Validator.check_bool(integrated_save)
@@ -507,8 +507,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
"""Check function load_checkpoint's parameter."""
if not isinstance(ckpt_file_name, str):
raise ValueError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
"but got {}.".format(type(ckpt_file_name)))
raise TypeError("For 'load_checkpoint', the argument 'ckpt_file_name' must be string, "
"but got {}.".format(type(ckpt_file_name)))

if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("For 'load_checkpoint', the checkpoint file should end with '.ckpt', please "
@@ -543,7 +543,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):

Args:
net (Cell): The network where the parameters will be loaded.
parameter_dict (dict): The dictionary generated by load checkpoint file.
parameter_dict (dict): The dictionary generated by load checkpoint file,
it is a dictionary consisting of key: parameters's name, value: parameter.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
into net when parameter name's suffix in checkpoint file is the same as the
parameter in the network. When the types are inconsistent perform type conversion
@@ -575,6 +576,12 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
"but got {}.".format(type(parameter_dict)))
raise TypeError(msg)
for key, value in parameter_dict.items():
if not isinstance(key, str) or not isinstance(value, Parameter):
logger.critical("Load parameters into net failed.")
msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
"'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
raise TypeError(msg)

strict_load = Validator.check_bool(strict_load)
logger.info("Execute the process of loading parameters into net.")
@@ -583,11 +590,6 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
for _, param in net.parameters_and_names():
if param.name in parameter_dict:
new_param = copy.deepcopy(parameter_dict[param.name])
if not isinstance(new_param, Parameter):
logger.critical("Failed to combine the net and the parameters.")
msg = ("For 'load_param_into_net', the element in the argument 'parameter_dict' should be a "
"'Parameter', but got {}.".format(type(new_param)))
raise TypeError(msg)
_update_param(param, new_param, strict_load)
else:
param_not_load.append(param.name)


+ 114
- 13
tests/ut/python/utils/test_callback.py View File

@@ -78,7 +78,11 @@ class LossNet(nn.Cell):


def test_model_checkpoint_prefix_invalid():
"""Test ModelCheckpoint prefix invalid."""
"""
Feature: callback
Description: Test ModelCheckpoint prefix invalid
Expectation: run success
"""
with pytest.raises(ValueError):
ModelCheckpoint(123)
ModelCheckpoint(directory="./")
@@ -89,7 +93,11 @@ def test_model_checkpoint_prefix_invalid():


def test_loss_monitor_sink_mode():
"""Test loss monitor sink mode."""
"""
Feature: callback
Description: Test loss monitor sink mode
Expectation: run success
"""
cb_params = _InternalCallbackParam()
cb_params.cur_epoch_num = 4
cb_params.epoch_num = 4
@@ -109,7 +117,11 @@ def test_loss_monitor_sink_mode():


def test_loss_monitor_normal_mode():
"""Test loss monitor normal(non-sink) mode."""
"""
Feature: callback
Description: Test loss monitor normal(non-sink) mode
Expectation: run success
"""
cb_params = _InternalCallbackParam()
run_context = RunContext(cb_params)
loss_cb = LossMonitor(1)
@@ -126,6 +138,40 @@ def test_loss_monitor_normal_mode():
loss_cb.end(run_context)


def test_loss_monitor_args():
"""
Feature: callback
Description: Test loss monitor illegal args
Expectation: run success
"""
with pytest.raises(ValueError):
LossMonitor(per_print_times=-1)
with pytest.raises(ValueError):
LossMonitor(has_trained_epoch=-100)


def test_loss_monitor_has_trained_epoch():
"""
Feature: callback
Description: Test loss monitor has_trained_epoch args
Expectation: run success
"""
cb_params = _InternalCallbackParam()
run_context = RunContext(cb_params)
loss_cb = LossMonitor(has_trained_epoch=10)
cb_params.cur_epoch_num = 4
cb_params.cur_step_num = 1
cb_params.batch_num = 1
cb_params.net_outputs = Tensor(2.0)
cb_params.epoch_num = 4
loss_cb.begin(run_context)
loss_cb.epoch_begin(run_context)
loss_cb.step_begin(run_context)
loss_cb.step_end(run_context)
loss_cb.epoch_end(run_context)
loss_cb.end(run_context)


def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():
"""
Feature: Save checkpoint and check if there is a file with the same name.
@@ -159,7 +205,11 @@ def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist():


def test_checkpoint_cb_for_save_op():
"""Test checkpoint cb for save op."""
"""
Feature: callback
Description: Test checkpoint cb for save op
Expectation: run success
"""
parameter_list = []
one_param = {}
one_param['name'] = "conv1.weight"
@@ -169,7 +219,11 @@ def test_checkpoint_cb_for_save_op():


def test_checkpoint_cb_for_save_op_update_net():
"""Test checkpoint cb for save op."""
"""
Feature: callback
Description: Test checkpoint cb for save op
Expectation: run success
"""
parameter_list = []
one_param = {}
one_param['name'] = "conv.weight"
@@ -182,7 +236,11 @@ def test_checkpoint_cb_for_save_op_update_net():


def test_internal_callback_param():
"""Test Internal CallbackParam."""
"""
Feature: callback
Description: Test Internal CallbackParam
Expectation: run success
"""
cb_params = _InternalCallbackParam()
cb_params.member1 = 1
cb_params.member2 = "abc"
@@ -191,7 +249,11 @@ def test_internal_callback_param():


def test_checkpoint_save_ckpt_steps():
"""Test checkpoint save ckpt steps."""
"""
Feature: callback
Description: Test checkpoint save ckpt steps
Expectation: run success
"""
train_config = CheckpointConfig(
save_checkpoint_steps=16,
save_checkpoint_seconds=0,
@@ -220,7 +282,11 @@ def test_checkpoint_save_ckpt_steps():


def test_checkpoint_save_ckpt_seconds():
"""Test checkpoint save ckpt seconds."""
"""
Feature: callback
Description: Test checkpoint save ckpt seconds
Expectation: run success
"""
train_config = CheckpointConfig(
save_checkpoint_steps=16,
save_checkpoint_seconds=100,
@@ -249,7 +315,11 @@ def test_checkpoint_save_ckpt_seconds():


def test_checkpoint_save_ckpt_with_encryption():
"""Test checkpoint save ckpt with encryption."""
"""
Feature: callback
Description: Test checkpoint save ckpt with encryption
Expectation: run success
"""
train_config = CheckpointConfig(
save_checkpoint_steps=16,
save_checkpoint_seconds=0,
@@ -286,7 +356,11 @@ def test_checkpoint_save_ckpt_with_encryption():


def test_CallbackManager():
"""TestCallbackManager."""
"""
Feature: callback
Description: Test CallbackManager
Expectation: run success
"""
ck_obj = ModelCheckpoint()
loss_cb_1 = LossMonitor(1)

@@ -304,6 +378,11 @@ def test_CallbackManager():


def test_CallbackManager_exit_called():
"""
Feature: callback
Description: Test CallbackManager exit called
Expectation: run success
"""
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
cb1, cb2 = Callback(), Callback()
with _CallbackManager([cb1, cb2]):
@@ -314,6 +393,11 @@ def test_CallbackManager_exit_called():


def test_CallbackManager_exit_called_when_raises():
"""
Feature: callback
Description: Test when CallbackManager exit called
Expectation: run success
"""
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit:
cb1, cb2 = Callback(), Callback()
with pytest.raises(ValueError):
@@ -325,6 +409,11 @@ def test_CallbackManager_exit_called_when_raises():


def test_CallbackManager_begin_called():
"""
Feature: callback
Description: Test CallbackManager called begin
Expectation: run success
"""
context = dict()
with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin:
cb1, cb2 = Callback(), Callback()
@@ -336,7 +425,11 @@ def test_CallbackManager_begin_called():


def test_RunContext():
"""Test RunContext."""
"""
Feature: callback
Description: Test RunContext init
Expectation: run success
"""
context_err = 666
with pytest.raises(TypeError):
RunContext(context_err)
@@ -356,7 +449,11 @@ def test_RunContext():


def test_Checkpoint_Config():
"""Test CheckpointConfig all None or 0."""
"""
Feature: callback
Description: Test checkpoint config error args
Expectation: run success
"""
with pytest.raises(ValueError):
CheckpointConfig(0, 0, 0, 0, True)

@@ -365,7 +462,11 @@ def test_Checkpoint_Config():


def test_step_end_save_graph():
"""Test save checkpoint."""
"""
Feature: callback
Description: Test save graph at step end
Expectation: run success
"""
train_config = CheckpointConfig(
save_checkpoint_steps=16,
save_checkpoint_seconds=0,


+ 1
- 1
tests/ut/python/utils/test_serialize.py View File

@@ -128,7 +128,7 @@ def test_load_checkpoint_error_filename():
"""
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
ckpt_file_name = 1
with pytest.raises(ValueError):
with pytest.raises(TypeError):
load_checkpoint(ckpt_file_name)




Loading…
Cancel
Save