| @@ -14,7 +14,7 @@ | |||||
| """ test_training """ | """ test_training """ | ||||
| import os | import os | ||||
| from mindspore import Model, context | from mindspore import Model, context | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | ||||
| from src.callbacks import LossCallBack | from src.callbacks import LossCallBack | ||||
| @@ -75,7 +75,7 @@ def test_train(configure): | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, | ckptconfig = CheckpointConfig(save_checkpoint_steps=1, | ||||
| keep_checkpoint_max=5) | keep_checkpoint_max=5) | ||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) | ||||
| model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb]) | |||||
| model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb]) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||