From dc2bf0d76f74329dac5368d487b62dc87a260294 Mon Sep 17 00:00:00 2001 From: liuzx Date: Tue, 14 May 2024 14:46:14 +0800 Subject: [PATCH] update train.py --- npu_mnist_example/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/npu_mnist_example/train.py b/npu_mnist_example/train.py index 74adb65..7be1f84 100644 --- a/npu_mnist_example/train.py +++ b/npu_mnist_example/train.py @@ -68,7 +68,10 @@ if __name__ == "__main__": net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + # load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + if os.path.exists(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")): + load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + if args.device_target != "Ascend": model = Model(network, net_loss,