You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_for_multidataset.py 4.8 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. '''
  2. 1,本示例中多数据集训练上传的数据集结构
  3. MnistDataset_torch.zip
  4. ├── test
  5. └── train
  6. checkpoint_epoch1_0.73.zip
  7. ├── mnist_epoch1_0.73.pkl
  8. 2,由于a100的适配性问题,使用训练环境前请使用平台的含有cuda11以上的推荐镜像在调试环境中调试自己的代码,
  9. 本示例的镜像地址是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191,并
  10. 提交镜像,再切到训练环境训练已跑通的代码。
  11. 在训练环境中,上传的数据集会自动放在/dataset目录下,注意:选择单数据集和多数据集时的路径不同!
  12. (1)如果是单数据集:如选择的是MnistDataset_torch.zip,则数据集目录为/dataset/train、/dataset/test;
  13. 本示例中单数据集在训练镜像中的数据集结构
  14. dataset
  15. ├── test
  16. └── train
  17. (2)如选择的是多数据集,如选择的是MnistDataset_torch.zip和checkpoint_epoch1_0.73.zip,则数据集
  18. 目录为/dataset/MnistDataset_torch/train、/dataset/MnistDataset_torch/test
  19. 和/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl
  20. 本示例中多数据集在训练镜像中的数据集结构
  21. dataset
  22. ├── MnistDataset_torch
  23. | ├── test
  24. | └── train
  25. └── checkpoint_epoch1_0.73
  26. ├── mnist_epoch1_0.73.pkl
  27. 模型下载路径默认在/model下,请将模型输出位置指定到/model,启智平台界面会提供/model目录下的文件下载。
  28. '''
  29. from model import Model
  30. import numpy as np
  31. import torch
  32. from torchvision.datasets import mnist
  33. from torch.nn import CrossEntropyLoss
  34. from torch.optim import SGD
  35. from torch.utils.data import DataLoader
  36. from torchvision.transforms import ToTensor
  37. import argparse
  38. # Training settings
  39. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  40. #数据集位置放在/dataset下
  41. parser.add_argument('--traindata', default="/dataset/MnistDataset_torch/train" ,help='path to train dataset')
  42. parser.add_argument('--testdata', default="/dataset/MnistDataset_torch/test" ,help='path to test dataset')
  43. parser.add_argument('--checkpoint', default="/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl" ,help='checkpoint file')
  44. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  45. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  46. if __name__ == '__main__':
  47. args = parser.parse_args()
  48. #日志输出
  49. print('cuda is available:{}'.format(torch.cuda.is_available()))
  50. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  51. batch_size = args.batch_size
  52. train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(),download=False)
  53. test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(),download=False)
  54. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  55. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  56. model = Model().to(device)
  57. sgd = SGD(model.parameters(), lr=1e-1)
  58. cost = CrossEntropyLoss()
  59. epoch = args.epoch_size
  60. #日志输出
  61. print('epoch_size is:{}'.format(epoch))
  62. #加载已训练好的模型:
  63. # path = args.checkpoint
  64. # checkpoint = torch.load(path, map_location=device)
  65. # model.load_state_dict(checkpoint)
  66. #开始训练
  67. for _epoch in range(epoch):
  68. print('the {} epoch_size begin'.format(_epoch + 1))
  69. model.train()
  70. for idx, (train_x, train_label) in enumerate(train_loader):
  71. train_x = train_x.to(device)
  72. train_label = train_label.to(device)
  73. label_np = np.zeros((train_label.shape[0], 10))
  74. sgd.zero_grad()
  75. predict_y = model(train_x.float())
  76. loss = cost(predict_y, train_label.long())
  77. if idx % 10 == 0:
  78. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  79. loss.backward()
  80. sgd.step()
  81. correct = 0
  82. _sum = 0
  83. model.eval()
  84. for idx, (test_x, test_label) in enumerate(test_loader):
  85. test_x = test_x
  86. test_label = test_label
  87. predict_y = model(test_x.to(device).float()).detach()
  88. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  89. label_np = test_label.numpy()
  90. _ = predict_ys == test_label
  91. correct += np.sum(_.numpy(), axis=-1)
  92. _sum += _.shape[0]
  93. #日志输出
  94. print('accuracy: {:.2f}'.format(correct / _sum))
  95. #模型输出位置放在/model下
  96. torch.save(model, '/model/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))