pretrain.py 5.2 kB

#!/usr/bin/python
#coding=utf-8
'''
If the code contains Chinese comments, keep the two lines above at the top of the file:
#!/usr/bin/python
#coding=utf-8

1. Dataset structure of the single dataset used in this example:
   MnistDataset_torch.zip
   ├── test
   └── train

2. For A100 compatibility, use the platform's recommended CUDA 11 image before using the
   training environment, then adjust the code and submit the image.
   The image used in this example is:
   dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191

In the training environment, the uploaded dataset is placed automatically under the /dataset
directory. Note: the paths differ depending on whether a single dataset or multiple datasets
are selected.
(1) Single dataset: if MnistDataset_torch.zip is selected, the dataset directories are
    /dataset/train and /dataset/test.
    Multiple datasets: if MnistDataset_torch.zip is selected, the dataset directories are
    /dataset/MnistDataset_torch/train and /dataset/MnistDataset_torch/test.
(2) If a pre-trained model file is selected, it is placed automatically in the
    /pretrainmodel directory and can be loaded as '/pretrainmodel/' + args.pretrainmodelname.

The model output path defaults to /model. Write model outputs to /model; the Qizhi platform
provides file downloads for everything under the /model directory.
'''
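
# Illustrative sketch only, not part of the original script: one way to resolve the
# train/test directories for the single-dataset vs. multiple-dataset layouts described in
# the docstring above. The helper name and the default dataset name "MnistDataset_torch"
# are assumptions used purely for illustration; adjust them to the dataset actually selected.
def resolve_dataset_dirs(single_dataset=True, dataset_name="MnistDataset_torch"):
    import os
    base = "/dataset" if single_dataset else os.path.join("/dataset", dataset_name)
    return os.path.join(base, "train"), os.path.join(base, "test")
# Example: resolve_dataset_dirs(single_dataset=False) returns
# ('/dataset/MnistDataset_torch/train', '/dataset/MnistDataset_torch/test').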
from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
import os
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# The dataset is mounted under /dataset
parser.add_argument('--traindata', default="/dataset/train", help='path to train dataset')
parser.add_argument('--testdata', default="/dataset/test", help='path to test dataset')
parser.add_argument('--epoch_size', type=int, default=10, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size per iteration')
# Path of the pre-trained model (checkpoint) file
parser.add_argument('--ckpt_url', default="", help='pretrain model path')

# Parameter declarations
WORKERS = 0  # number of DataLoader worker threads
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()

# Model training
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = cost(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))

# Model evaluation
def test(model, test_loader, test_data):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss += cost(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
    test_loss /= (i + 1)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_data), 100. * correct / len(test_data)))

def main():
    # If a saved checkpoint exists, load it and continue training from there
    if os.path.exists(args.ckpt_url):
        checkpoint = torch.load(args.ckpt_url)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Successfully loaded weights from epoch {}!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model found, training from scratch!')
    for epoch in range(start_epoch + 1, epochs + 1):
        train(model, train_loader, epoch)
        test(model, test_loader, test_dataset)
        # Save the model; the Qizhi platform serves downloads from /model
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, '/model/mnist_epoch{}.pkl'.format(epoch))

if __name__ == '__main__':
    args, unknown = parser.parse_known_args()
    # Log whether CUDA is available
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    epochs = args.epoch_size
    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=WORKERS)
    main()
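
    # Illustrative sketch, not part of the original script: reload the checkpoint saved for
    # the final epoch from /model and switch to evaluation mode, e.g. for a quick sanity
    # check or export. Assumes training started from scratch, so the last saved file is
    # mnist_epoch{epoch_size}.pkl; the guard keeps this a no-op if the file is absent.
    final_ckpt = '/model/mnist_epoch{}.pkl'.format(args.epoch_size)
    if os.path.exists(final_ckpt):
        checkpoint = torch.load(final_ckpt, map_location=device)
        model.load_state_dict(checkpoint['model'])
        model.eval()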