Browse Source

更新 'gpu_mnist_example/train_gpu.py'

liuzx-patch-1
liuzxtest02 2 years ago
parent
commit
2be3c0eab1
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      gpu_mnist_example/train_gpu.py

+ 1
- 1
gpu_mnist_example/train_gpu.py View File

@@ -103,7 +103,7 @@ if __name__ == '__main__':


#如果有保存的模型,则加载模型,并在其基础上继续训练 #如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")): if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")):
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_Example_model_zjdt", "mnist_epoch1_0.76.pkl"))
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl"))
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] start_epoch = checkpoint['epoch']


Loading…
Cancel
Save