diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 31f6204bf1..73cc3c1dce 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -61,7 +61,7 @@ if __name__ == "__main__": keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.ckpt_path, config=config_ck) - if args.device_target == "CPU": + if args.device_target != "Ascend": model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) else: model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}, amp_level="O2")