Browse Source

更新 'gpu_mnist_example/train_gpu.py'

liuzx-patch-1
liuzxtest02 2 years ago
parent
commit
2286844740
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      gpu_mnist_example/train_gpu.py

+ 4
- 4
gpu_mnist_example/train_gpu.py View File

@@ -89,14 +89,14 @@ if __name__ == '__main__':
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
epochs = args.epoch_size
train_dataset = mnist.MNIST(root=os.path.join(dataset_path, "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(dataset_path, "test"), train=False, transform=ToTensor(),download=False)
train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

#如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(os.path.join(pretrain_model_path, "mnist_epoch1_0.76.pkl")):
checkpoint = torch.load(os.path.join(pretrain_model_path, "mnist_epoch1_0.76.pkl"))
if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl")):
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl"))
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']


Loading…
Cancel
Save