You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

gpu_train.py 3.5 kB

3 years ago
Lines 1–74
  1. '''
  2. 由于a100的适配性问题,使用训练环境前请使用平台的含有cuda11以上的推荐镜像在调试环境中调试自己的代码,
  3. 本示例的镜像地址是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191,并
  4. 提交镜像,再切到训练环境训练已跑通的代码。
  5. 在训练环境中,上传的数据集会自动放在/dataset目录下,模型下载路径默认在/model下,请将模型输出位置指定到/model,
  6. 启智平台界面会提供/model目录下的文件下载。
  7. '''
  8. from model import Model
  9. import numpy as np
  10. import torch
  11. from torchvision.datasets import mnist
  12. from torch.nn import CrossEntropyLoss
  13. from torch.optim import SGD
  14. from torch.utils.data import DataLoader
  15. from torchvision.transforms import ToTensor
  16. import argparse
  17. # Training settings
  18. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  19. #数据集位置放在/dataset下
  20. parser.add_argument('--traindata', default="/dataset/train" ,help='path to train dataset')
  21. parser.add_argument('--testdata', default="/dataset/test" ,help='path to test dataset')
  22. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  23. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  24. if __name__ == '__main__':
  25. args = parser.parse_args()
  26. #日志输出
  27. print('cuda is available:{}'.format(torch.cuda.is_available()))
  28. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  29. batch_size = args.batch_size
  30. train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(),download=False)
  31. test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(),download=False)
  32. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  33. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  34. model = Model().to(device)
  35. sgd = SGD(model.parameters(), lr=1e-1)
  36. cost = CrossEntropyLoss()
  37. epoch = args.epoch_size
  38. #日志输出
  39. print('epoch_size is:{}'.format(epoch))
  40. for _epoch in range(epoch):
  41. print('the {} epoch_size begin'.format(_epoch + 1))
  42. model.train()
  43. for idx, (train_x, train_label) in enumerate(train_loader):
  44. train_x = train_x.to(device)
  45. train_label = train_label.to(device)
  46. label_np = np.zeros((train_label.shape[0], 10))
  47. sgd.zero_grad()
  48. predict_y = model(train_x.float())
  49. loss = cost(predict_y, train_label.long())
  50. if idx % 10 == 0:
  51. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  52. loss.backward()
  53. sgd.step()
  54. correct = 0
  55. _sum = 0
  56. model.eval()
  57. for idx, (test_x, test_label) in enumerate(test_loader):
  58. test_x = test_x
  59. test_label = test_label
  60. predict_y = model(test_x.to(device).float()).detach()
  61. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  62. label_np = test_label.numpy()
  63. _ = predict_ys == test_label
  64. correct += np.sum(_.numpy(), axis=-1)
  65. _sum += _.shape[0]
  66. #日志输出
  67. print('accuracy: {:.2f}'.format(correct / _sum))
  68. #模型输出位置放在/model下
  69. torch.save(model, '/model/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))

No Description