|
|
|
@@ -88,33 +88,6 @@ def test_model_checkpoint_prefix_invalid(): |
|
|
|
ModelCheckpoint(prefix="ckpt_2", directory="./test_files") |
|
|
|
|
|
|
|
|
|
|
|
def test_save_checkpoint(): |
|
|
|
"""Test save checkpoint.""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=0, |
|
|
|
keep_checkpoint_max=5, |
|
|
|
keep_checkpoint_per_n_minutes=0) |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
net = Net() |
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) |
|
|
|
network_ = WithLossCell(net, loss) |
|
|
|
_train_network = TrainOneStepCell(network_, optim) |
|
|
|
cb_params.train_network = _train_network |
|
|
|
cb_params.epoch_num = 10 |
|
|
|
cb_params.cur_epoch_num = 5 |
|
|
|
cb_params.cur_step_num = 0 |
|
|
|
cb_params.batch_num = 32 |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config) |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
ckpoint_cb.begin(run_context) |
|
|
|
ckpoint_cb.step_end(run_context) |
|
|
|
if os.path.exists('./test_files/test_ckpt-model.pkl'): |
|
|
|
os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE) |
|
|
|
os.remove('./test_files/test_ckpt-model.pkl') |
|
|
|
|
|
|
|
|
|
|
|
def test_loss_monitor_sink_mode(): |
|
|
|
"""Test loss monitor sink mode.""" |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
@@ -153,8 +126,35 @@ def test_loss_monitor_normal_mode(): |
|
|
|
loss_cb.end(run_context) |
|
|
|
|
|
|
|
|
|
|
|
def test_chg_ckpt_file_name_if_same_exist(): |
|
|
|
"""Test chg ckpt file name if same exist.""" |
|
|
|
def test_save_ckpt_and_test_chg_ckpt_file_name_if_same_exist(): |
|
|
|
""" |
|
|
|
Feature: Save checkpoint and check if there is a file with the same name. |
|
|
|
Description: Save checkpoint and check if there is a file with the same name. |
|
|
|
Expectation: Checkpoint is saved and checking is successful. |
|
|
|
""" |
|
|
|
train_config = CheckpointConfig( |
|
|
|
save_checkpoint_steps=16, |
|
|
|
save_checkpoint_seconds=0, |
|
|
|
keep_checkpoint_max=5, |
|
|
|
keep_checkpoint_per_n_minutes=0) |
|
|
|
cb_params = _InternalCallbackParam() |
|
|
|
net = Net() |
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits() |
|
|
|
optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) |
|
|
|
network_ = WithLossCell(net, loss) |
|
|
|
_train_network = TrainOneStepCell(network_, optim) |
|
|
|
cb_params.train_network = _train_network |
|
|
|
cb_params.epoch_num = 10 |
|
|
|
cb_params.cur_epoch_num = 5 |
|
|
|
cb_params.cur_step_num = 0 |
|
|
|
cb_params.batch_num = 32 |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="test_ckpt", directory='./test_files', config=train_config) |
|
|
|
run_context = RunContext(cb_params) |
|
|
|
ckpoint_cb.begin(run_context) |
|
|
|
ckpoint_cb.step_end(run_context) |
|
|
|
if os.path.exists('./test_files/test_ckpt-model.pkl'): |
|
|
|
os.chmod('./test_files/test_ckpt-model.pkl', stat.S_IWRITE) |
|
|
|
os.remove('./test_files/test_ckpt-model.pkl') |
|
|
|
_chg_ckpt_file_name_if_same_exist(directory="./test_files", prefix="ckpt") |
|
|
|
|
|
|
|
|
|
|
|
|