train_for_multidataset.py
#!/usr/bin/python
#coding=utf-8
'''
If the code contains Chinese comments, please add the following two lines at the top of the file:
#!/usr/bin/python
#coding=utf-8

1. The multi-dataset structure used in this example:
MnistDataset_torch.zip
├── test
└── train
checkpoint_epoch1_0.73.zip
└── mnist_epoch1_0.73.pkl

2. Because of A100 compatibility, before using the training environment please start from the platform's
recommended image with CUDA 11, then adjust the code and submit the image.
The image used in this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
In the training environment, uploaded datasets are automatically placed under the /dataset directory.
Note: the paths differ depending on whether a single dataset or multiple datasets are selected.
(1) Single dataset: if only MnistDataset_torch.zip is selected,
the dataset directories are /dataset/train and /dataset/test.
The single-dataset structure inside the training image in this example:
dataset
├── test
└── train
(2) Multiple datasets: if MnistDataset_torch.zip and checkpoint_epoch1_0.73.zip are both selected,
the dataset directories are /dataset/MnistDataset_torch/train, /dataset/MnistDataset_torch/test
and /dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl.
The multi-dataset structure inside the training image in this example:
dataset
├── MnistDataset_torch
│   ├── test
│   └── train
└── checkpoint_epoch1_0.73
    └── mnist_epoch1_0.73.pkl

The model download path is /model by default. Please write the model output to /model;
the Qizhi platform will provide file downloads from the /model directory.
(A path-resolution sketch follows this docstring.)
'''
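
# --- Path-resolution sketch (illustrative; not part of the original training flow) ---
# A minimal example, assuming the /dataset layouts described in the docstring, of how a
# script could pick between the single-dataset and multi-dataset directory structures.
# The helper name `resolve_data_dirs` is hypothetical and used only for illustration.
import os

def resolve_data_dirs(dataset_root='/dataset', dataset_name='MnistDataset_torch'):
    """Return (train_dir, test_dir) for either the single- or multi-dataset layout."""
    multi_train = os.path.join(dataset_root, dataset_name, 'train')
    if os.path.isdir(multi_train):
        # Multi-dataset layout: /dataset/<dataset_name>/{train, test}
        return multi_train, os.path.join(dataset_root, dataset_name, 'test')
    # Single-dataset layout: /dataset/{train, test}
    return os.path.join(dataset_root, 'train'), os.path.join(dataset_root, 'test')
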
from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# The datasets are placed under /dataset in the training environment
parser.add_argument('--traindata', default="/dataset/MnistDataset_torch/train", help='path to train dataset')
parser.add_argument('--testdata', default="/dataset/MnistDataset_torch/test", help='path to test dataset')
parser.add_argument('--checkpoint', default="/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl", help='checkpoint file')
parser.add_argument('--epoch_size', type=int, default=1, help='how many epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size for each training step')
# Get the model file name
parser.add_argument('--modelname', help='model name')
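# Example invocation (the paths below assume the multi-dataset layout described in the
# docstring; adjust them to your own datasets):
#   python train_for_multidataset.py --traindata /dataset/MnistDataset_torch/train \
#       --testdata /dataset/MnistDataset_torch/test \
#       --checkpoint /dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl \
#       --epoch_size 1 --batch_size 256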
if __name__ == '__main__':
    args, unknown = parser.parse_known_args()
    # Log whether CUDA is available
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    model = Model().to(device)
    sgd = SGD(model.parameters(), lr=1e-1)
    cost = CrossEntropyLoss()
    epoch = args.epoch_size
    print('epoch_size is:{}'.format(epoch))
    # Load the trained model (the checkpoint file comes from the second dataset, checkpoint_epoch1_0.73.zip)
    # path = args.checkpoint
    # checkpoint = torch.load(path, map_location=device)
    # model.load_state_dict(checkpoint)
    for _epoch in range(epoch):
        print('the {} epoch_size begin'.format(_epoch + 1))
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = cost(predict_y, train_label.long())
            if idx % 10 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
            loss.backward()
            sgd.step()

        # Evaluate on the test set after each epoch
        correct = 0
        _sum = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.to(device).float()).detach()
            predict_ys = np.argmax(predict_y.cpu().numpy(), axis=-1)
            correct += np.sum(predict_ys == test_label.numpy(), axis=-1)
            _sum += test_label.shape[0]
        print('accuracy: {:.2f}'.format(correct / _sum))
        # The model output location is /model; the Qizhi platform serves files
        # under /model for download after training.
        torch.save(model, '/model/mnist_epoch{}_{:.2f}.pkl'.format(_epoch + 1, correct / _sum))
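
# --- Reloading sketch (illustrative; not part of the original script) ---
# torch.save(model, ...) above pickles the whole model object, so a file downloaded from
# /model can only be reloaded where the Model class is importable. The path below is an
# example only:
#   from model import Model  # needed so the pickled class can be resolved
#   restored = torch.load('/model/mnist_epoch1_0.73.pkl', map_location='cpu')
#   restored.eval()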