|
|
|
@@ -61,7 +61,7 @@ if __name__ == "__main__": |
|
|
|
keep_checkpoint_max=cfg.keep_checkpoint_max) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck) |
|
|
|
|
|
|
|
if args.device_target == "CPU": |
|
|
|
if args.device_target != "Ascend": |
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) |
|
|
|
else: |
|
|
|
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2") |
|
|
|
|