You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_log.py 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code, please add the following at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. For compatibility with the A100, before using the training environment, please start from the platform's recommended
  8. image with CUDA 11. Then adjust the code and submit the image.
  9. The image of this example is: dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191
  10. In the training environment, the uploaded dataset will be automatically placed in the /dataset directory.
  11. If it is a single dataset:
  12. if MnistDataset_torch.zip is selected, then the dataset directories are /dataset/train and /dataset/test;
  13. If it is a multiple dataset:
  14. If MnistDataset_torch.zip and checkpoint_epoch1_0.73.zip are selected,
  15. the dataset directory is /dataset/MnistDataset_torch/train, /dataset/MnistDataset_torch/test
  16. and /dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl
  17. Models are downloaded from the /model directory by default. Please set the model output location to /model,
  18. and the Qizhi platform will provide file downloads for everything under the /model directory.
  19. '''
  20. from model import Model
  21. import numpy as np
  22. import torch
  23. from torchvision.datasets import mnist
  24. from torch.nn import CrossEntropyLoss
  25. from torch.optim import SGD
  26. from torch.utils.data import DataLoader
  27. from torchvision.transforms import ToTensor
  28. import argparse
  29. import datetime
  30. # Training settings
  31. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  32. #The dataset location is placed under /dataset
  33. parser.add_argument('--traindata', default="/dataset/train" ,help='path to train dataset')
  34. parser.add_argument('--testdata', default="/dataset/test" ,help='path to test dataset')
  35. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  36. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  37. def gettime():
  38. timestr = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  39. return timestr
  40. if __name__ == '__main__':
  41. args, unknown = parser.parse_known_args()
  42. #log output
  43. print(gettime(), 'cuda is available:{}'.format(torch.cuda.is_available()))
  44. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  45. batch_size = args.batch_size
  46. train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(),download=False)
  47. test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(),download=False)
  48. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  49. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  50. model = Model().to(device)
  51. sgd = SGD(model.parameters(), lr=1e-1)
  52. cost = CrossEntropyLoss()
  53. epoch = args.epoch_size
  54. print(gettime(), 'epoch_size is:{}'.format(epoch))
  55. for _epoch in range(epoch):
  56. print(gettime(), 'the {} epoch_size begin'.format(_epoch + 1))
  57. model.train()
  58. for idx, (train_x, train_label) in enumerate(train_loader):
  59. train_x = train_x.to(device)
  60. train_label = train_label.to(device)
  61. label_np = np.zeros((train_label.shape[0], 10))
  62. sgd.zero_grad()
  63. predict_y = model(train_x.float())
  64. loss = cost(predict_y, train_label.long())
  65. print(gettime(), 'idx: {}, loss: {}'.format(idx, loss.sum().item()))
  66. if idx % 10 == 0:
  67. print("------------------")
  68. loss.backward()
  69. sgd.step()
  70. correct = 0
  71. _sum = 0
  72. model.eval()
  73. for idx, (test_x, test_label) in enumerate(test_loader):
  74. test_x = test_x
  75. test_label = test_label
  76. predict_y = model(test_x.to(device).float()).detach()
  77. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  78. label_np = test_label.numpy()
  79. _ = predict_ys == test_label
  80. correct += np.sum(_.numpy(), axis=-1)
  81. _sum += _.shape[0]
  82. print(gettime(), 'accuracy: {:.2f}'.format(correct / _sum))
  83. #The model output location is placed under /model
  84. torch.save(model, '/model/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))