|
|
|
@@ -15,7 +15,7 @@ |
|
|
|
import os |
|
|
|
|
|
|
|
from mindspore import Model, context |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor |
|
|
|
|
|
|
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel |
|
|
|
from src.callbacks import LossCallBack, EvalCallBack |
|
|
|
@@ -23,10 +23,11 @@ from src.datasets import create_dataset |
|
|
|
from src.metrics import AUCMetric |
|
|
|
from src.config import WideDeepConfig |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci") |
|
|
|
|
|
|
|
|
|
|
|
def get_WideDeep_net(config): |
|
|
|
""" |
|
|
|
Get network of wide&deep model. |
|
|
|
""" |
|
|
|
WideDeep_net = WideDeepModel(config) |
|
|
|
|
|
|
|
loss_net = NetWithLossClass(WideDeep_net, config) |
|
|
|
@@ -87,11 +88,13 @@ def test_train_eval(config): |
|
|
|
|
|
|
|
out = model.eval(ds_eval) |
|
|
|
print("=====" * 5 + "model.eval() initialized: {}".format(out)) |
|
|
|
model.train(epochs, ds_train, callbacks=[eval_callback, callback, ckpoint_cb]) |
|
|
|
model.train(epochs, ds_train, |
|
|
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
wide_deep_config = WideDeepConfig() |
|
|
|
wide_deep_config.argparse_init() |
|
|
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target) |
|
|
|
test_train_eval(wide_deep_config) |