|
|
@@ -170,7 +170,7 @@ def train(): |
|
|
ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
|
|
|
ckpoint_cb = ModelCheckpoint(prefix=args.model, directory=args.train_dir, config=config_ck)
|
|
|
cbs.append(ckpoint_cb)
|
|
|
cbs.append(ckpoint_cb)
|
|
|
|
|
|
|
|
|
model.train(args.train_epochs, dataset, callbacks=cbs)
|
|
|
|
|
|
|
|
|
model.train(args.train_epochs, dataset, callbacks=cbs, dataset_sink_mode=(args.device_target != "CPU"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
if __name__ == '__main__':
|
|
|
|