|
|
|
@@ -12,7 +12,7 @@ If there are Chinese comments in the code,please add at the beginning: |
|
|
|
|
|
|
|
预训练模型文件夹结构是: |
|
|
|
Torch_MNIST_Example_Model |
|
|
|
├── mnist_epoch1_0.76.pkl |
|
|
|
├── mnist_epoch1.pkl |
|
|
|
|
|
|
|
''' |
|
|
|
|
|
|
|
@@ -99,8 +99,8 @@ if __name__ == '__main__': |
|
|
|
test_loader = DataLoader(test_dataset, batch_size=batch_size) |
|
|
|
|
|
|
|
#如果有保存的模型,则加载模型,并在其基础上继续训练 |
|
|
|
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl")) |
|
|
|
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1.pkl")) |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
|
|
start_epoch = checkpoint['epoch'] |
|
|
|
|