Browse Source

fix eval error in single device and data parallel mode

tags/v1.1.0
huangxinjing 5 years ago
parent
commit
f4ce5768f1
2 changed files with 2 additions and 2 deletions
  1. +1
    -1
      model_zoo/official/recommend/wide_and_deep/train_and_eval.py
  2. +1
    -1
      model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py

+ 1
- 1
model_zoo/official/recommend/wide_and_deep/train_and_eval.py View File

@@ -95,7 +95,7 @@ def test_train_eval(config):
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)

out = model.eval(ds_eval)
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],


+ 1
- 1
model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py View File

@@ -105,7 +105,7 @@ def train_and_eval(config):
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
config=ckptconfig)
out = model.eval(ds_eval)
out = model.eval(ds_eval, dataset_sink_mode=(not sparse))
print("=====" * 5 + "model.eval() initialized: {}".format(out))
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if get_rank() == 0:


Loading…
Cancel
Save