Browse Source

modify lenet to mixed precision

tags/v1.1.0
guoqi 5 years ago
parent
commit
6606c1e431
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      model_zoo/official/cv/lenet/train.py

+ 1
- 1
model_zoo/official/cv/lenet/train.py View File

@@ -60,7 +60,7 @@ if __name__ == "__main__":
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")

print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],


Loading…
Cancel
Save