diff --git a/npu_mnist_example/train.py b/npu_mnist_example/train.py index 74adb65..7be1f84 100644 --- a/npu_mnist_example/train.py +++ b/npu_mnist_example/train.py @@ -68,7 +68,10 @@ if __name__ == "__main__": net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + # load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + if os.path.exists(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")): + load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) + if args.device_target != "Ascend": model = Model(network, net_loss,