diff --git a/gpu_mnist_example/train_gpu.py b/gpu_mnist_example/train_gpu.py index c2b396e..31b026d 100644 --- a/gpu_mnist_example/train_gpu.py +++ b/gpu_mnist_example/train_gpu.py @@ -103,7 +103,7 @@ if __name__ == '__main__': #如果有保存的模型,则加载模型,并在其基础上继续训练 if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")): - checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl")) + checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']