|
|
|
@@ -30,7 +30,7 @@ from src.metrics import AUCMetric |
|
|
|
from src.config import WideDeepConfig |
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True) |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) |
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) |
|
|
|
init() |
|
|
|
|
|
|
|
@@ -71,8 +71,8 @@ def test_train_eval(): |
|
|
|
test_train_eval |
|
|
|
""" |
|
|
|
np.random.seed(1000) |
|
|
|
config = WideDeepConfig |
|
|
|
data_path = Config.data_path |
|
|
|
config = WideDeepConfig() |
|
|
|
data_path = config.data_path |
|
|
|
batch_size = config.batch_size |
|
|
|
epochs = config.epochs |
|
|
|
print("epochs is {}".format(epochs)) |
|
|
|
@@ -94,8 +94,14 @@ def test_train_eval(): |
|
|
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) |
|
|
|
|
|
|
|
callback = LossCallBack(config=config) |
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) |
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', |
|
|
|
directory=config.ckpt_path, config=ckptconfig) |
|
|
|
out = model.eval(ds_eval) |
|
|
|
print("=====" * 5 + "model.eval() initialized: {}".format(out)) |
|
|
|
model.train(epochs, ds_train, |
|
|
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_train_eval() |