|
|
|
@@ -14,7 +14,7 @@ |
|
|
|
""" test_training """ |
|
|
|
import os |
|
|
|
from mindspore import Model, context |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor |
|
|
|
|
|
|
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel |
|
|
|
from src.callbacks import LossCallBack |
|
|
|
@@ -75,7 +75,7 @@ def test_train(configure): |
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, |
|
|
|
keep_checkpoint_max=5) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) |
|
|
|
model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb]) |
|
|
|
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|