|
|
|
@@ -109,8 +109,11 @@ def train_and_eval(config): |
|
|
|
directory=config.ckpt_path, config=ckptconfig) |
|
|
|
out = model.eval(ds_eval) |
|
|
|
print("=====" * 5 + "model.eval() initialized: {}".format(out)) |
|
|
|
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] |
|
|
|
if get_rank() == 0: |
|
|
|
callback_list.append(ckpoint_cb) |
|
|
|
model.train(epochs, ds_train, |
|
|
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], |
|
|
|
callbacks=callback_list, |
|
|
|
sink_size=ds_train.get_dataset_size()) |
|
|
|
|
|
|
|
|
|
|
|
|