|
|
|
@@ -78,7 +78,11 @@ class LossNet(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
def test_model_checkpoint_prefix_invalid(): |
|
|
|
"""Test ModelCheckpoint prefix invalid.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test ModelCheckpoint prefix invalid |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
with pytest.raises(ValueError): |
|
|
|
ModelCheckpoint(123) |
|
|
|
ModelCheckpoint(directory="./") |
|
|
|
@@ -89,7 +93,11 @@ def test_model_checkpoint_prefix_invalid(): |
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_sink_mode(): |
|
|
|
"""Test loss monitor sink mode.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test loss monitor sink mode |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
cb_params.cur_epoch_num = 4 |
|
|
|
cb_params.epoch_num = 4 |
|
|
|
@@ -109,7 +117,11 @@ def test_loss_monitor_sink_mode(): |
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_normal_mode(): |
|
|
|
"""Test loss monitor normal(non-sink) mode.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test loss monitor normal(non-sink) mode |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
loss_cb = LossMonitor(1) |
|
|
|
@@ -126,6 +138,40 @@ def test_loss_monitor_normal_mode(): |
|
|
|
loss_cb.end(run_context) |
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_args(): |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test loss monitor illegal args |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
with pytest.raises(ValueError): |
|
|
|
LossMonitor(per_print_times=-1) |
|
|
|
with pytest.raises(ValueError): |
|
|
|
LossMonitor(has_trained_epoch=-100) |
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_has_trained_epoch(): |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test loss monitor has_trained_epoch args |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
loss_cb = LossMonitor(has_trained_epoch=10) |
|
|
|
cb_params.cur_epoch_num = 4 |
|
|
|
cb_params.cur_step_num = 1 |
|
|
|
cb_params.batch_num = 1 |
|
|
|
cb_params.net_outputs = Tensor(2.0) |
|
|
|
cb_params.epoch_num = 4 |
|
|
|
loss_cb.begin(run_context) |
|
|
|
loss_cb.epoch_begin(run_context) |
|
|
|
loss_cb.step_begin(run_context) |
|
|
|
loss_cb.step_end(run_context) |
|
|
|
loss_cb.epoch_end(run_context) |
|
|
|
loss_cb.end(run_context) |
|
|
|
|
|
|
|
|
|
|
|
def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): |
|
|
|
""" |
|
|
|
Feature: Save checkpoint and check if there is a file with the same name. |
|
|
|
@@ -159,7 +205,11 @@ def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_cb_for_save_op(): |
|
|
|
"""Test checkpoint cb for save op.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint cb for save op |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
parameter_list = [] |
|
|
|
one_param = {} |
|
|
|
one_param['name'] = "conv1.weight" |
|
|
|
@@ -169,7 +219,11 @@ def test_checkpoint_cb_for_save_op(): |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_cb_for_save_op_update_net(): |
|
|
|
"""Test checkpoint cb for save op.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint cb for save op |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
parameter_list = [] |
|
|
|
one_param = {} |
|
|
|
one_param['name'] = "conv.weight" |
|
|
|
@@ -182,7 +236,11 @@ def test_checkpoint_cb_for_save_op_update_net(): |
|
|
|
|
|
|
|
|
|
|
|
def test_internal_callback_param(): |
|
|
|
"""Test Internal CallbackParam.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test Internal CallbackParam |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
cb_params.member1 = 1 |
|
|
|
cb_params.member2 = "abc" |
|
|
|
@@ -191,7 +249,11 @@ def test_internal_callback_param(): |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_save_ckpt_steps(): |
|
|
|
"""Test checkpoint save ckpt steps.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint save ckpt steps |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=0, |
|
|
|
@@ -220,7 +282,11 @@ def test_checkpoint_save_ckpt_steps(): |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_save_ckpt_seconds(): |
|
|
|
"""Test checkpoint save ckpt seconds.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint save ckpt seconds |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=100, |
|
|
|
@@ -249,7 +315,11 @@ def test_checkpoint_save_ckpt_seconds(): |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_save_ckpt_with_encryption(): |
|
|
|
"""Test checkpoint save ckpt with encryption.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint save ckpt with encryption |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=0, |
|
|
|
@@ -286,7 +356,11 @@ def test_checkpoint_save_ckpt_with_encryption(): |
|
|
|
|
|
|
|
|
|
|
|
def test_CallbackManager(): |
|
|
|
"""TestCallbackManager.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test CallbackManager |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
ck_obj = ModelCheckpoint() |
|
|
|
loss_cb_1 = LossMonitor(1) |
|
|
|
|
|
|
|
@@ -304,6 +378,11 @@ def test_CallbackManager(): |
|
|
|
|
|
|
|
|
|
|
|
def test_CallbackManager_exit_called(): |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test CallbackManager exit called |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: |
|
|
|
cb1, cb2 = Callback(), Callback() |
|
|
|
with _CallbackManager([cb1, cb2]): |
|
|
|
@@ -314,6 +393,11 @@ def test_CallbackManager_exit_called(): |
|
|
|
|
|
|
|
|
|
|
|
def test_CallbackManager_exit_called_when_raises(): |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test when CallbackManager exit called |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
with mock.patch.object(Callback, '__exit__', return_value=None) as mock_exit: |
|
|
|
cb1, cb2 = Callback(), Callback() |
|
|
|
with pytest.raises(ValueError): |
|
|
|
@@ -325,6 +409,11 @@ def test_CallbackManager_exit_called_when_raises(): |
|
|
|
|
|
|
|
|
|
|
|
def test_CallbackManager_begin_called(): |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test CallbackManager called begin |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
context = dict() |
|
|
|
with mock.patch.object(Callback, 'begin', return_value=None) as mock_begin: |
|
|
|
cb1, cb2 = Callback(), Callback() |
|
|
|
@@ -336,7 +425,11 @@ def test_CallbackManager_begin_called(): |
|
|
|
|
|
|
|
|
|
|
|
def test_RunContext(): |
|
|
|
"""Test RunContext.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test RunContext init |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
context_err = 666 |
|
|
|
with pytest.raises(TypeError): |
|
|
|
RunContext(context_err) |
|
|
|
@@ -356,7 +449,11 @@ def test_RunContext(): |
|
|
|
|
|
|
|
|
|
|
|
def test_Checkpoint_Config(): |
|
|
|
"""Test CheckpointConfig all None or 0.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test checkpoint config error args |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
with pytest.raises(ValueError): |
|
|
|
CheckpointConfig(0, 0, 0, 0, True) |
|
|
|
|
|
|
|
@@ -365,7 +462,11 @@ def test_Checkpoint_Config(): |
|
|
|
|
|
|
|
|
|
|
|
def test_step_end_save_graph(): |
|
|
|
"""Test save checkpoint.""" |
|
|
|
""" |
|
|
|
Feature: callback |
|
|
|
Description: Test save graph at step end |
|
|
|
Expectation: run success |
|
|
|
""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=0, |
|
|
|
|