From 2286844740786ae0609d3adf9b46ce14361145aa Mon Sep 17 00:00:00 2001 From: liuzxtest02 <134442@163.com> Date: Fri, 3 Nov 2023 17:22:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'gpu=5Fmnist=5Fexample/tra?= =?UTF-8?q?in=5Fgpu.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gpu_mnist_example/train_gpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gpu_mnist_example/train_gpu.py b/gpu_mnist_example/train_gpu.py index f163b42..7bdebcc 100644 --- a/gpu_mnist_example/train_gpu.py +++ b/gpu_mnist_example/train_gpu.py @@ -89,14 +89,14 @@ if __name__ == '__main__': device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") batch_size = args.batch_size epochs = args.epoch_size - train_dataset = mnist.MNIST(root=os.path.join(dataset_path, "train"), train=True, transform=ToTensor(),download=False) - test_dataset = mnist.MNIST(root=os.path.join(dataset_path, "test"), train=False, transform=ToTensor(),download=False) + train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False) + test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False) train_loader = DataLoader(train_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=batch_size) #如果有保存的模型,则加载模型,并在其基础上继续训练 - if os.path.exists(os.path.join(pretrain_model_path, "mnist_epoch1_0.76.pkl")): - checkpoint = torch.load(os.path.join(pretrain_model_path, "mnist_epoch1_0.76.pkl")) + if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl")): + checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl")) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']