You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

train_for_c2net.py 3.2 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. '''
  2. 在训练环境中,代码会自动放在/tmp/code目录下,上传的数据集会自动放在/tmp/dataset目录下,模型下载路径默认在/tmp/output下,请将模型输出位置指定到/tmp/output,
  3. 启智平台界面会提供/tmp/output目录下的文件下载。
  4. '''
  5. from model import Model
  6. import numpy as np
  7. import torch
  8. from torchvision.datasets import mnist
  9. from torch.nn import CrossEntropyLoss
  10. from torch.optim import SGD
  11. from torch.utils.data import DataLoader
  12. from torchvision.transforms import ToTensor
  13. import argparse
  14. # Training settings
  15. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  16. #数据集位置放在/tmp/dataset下
  17. parser.add_argument('--traindata', default="/tmp/dataset/train" ,help='path to train dataset')
  18. parser.add_argument('--testdata', default="/tmp/dataset/test" ,help='path to test dataset')
  19. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  20. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  21. if __name__ == '__main__':
  22. args = parser.parse_args()
  23. #日志输出
  24. print('cuda is available:{}'.format(torch.cuda.is_available()))
  25. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  26. batch_size = args.batch_size
  27. train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(),download=False)
  28. test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(),download=False)
  29. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  30. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  31. model = Model().to(device)
  32. sgd = SGD(model.parameters(), lr=1e-1)
  33. cost = CrossEntropyLoss()
  34. epoch = args.epoch_size
  35. #日志输出
  36. print('epoch_size is:{}'.format(epoch))
  37. for _epoch in range(epoch):
  38. print('the {} epoch_size begin'.format(_epoch + 1))
  39. model.train()
  40. for idx, (train_x, train_label) in enumerate(train_loader):
  41. train_x = train_x.to(device)
  42. train_label = train_label.to(device)
  43. label_np = np.zeros((train_label.shape[0], 10))
  44. sgd.zero_grad()
  45. predict_y = model(train_x.float())
  46. loss = cost(predict_y, train_label.long())
  47. if idx % 10 == 0:
  48. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  49. loss.backward()
  50. sgd.step()
  51. correct = 0
  52. _sum = 0
  53. model.eval()
  54. for idx, (test_x, test_label) in enumerate(test_loader):
  55. test_x = test_x
  56. test_label = test_label
  57. predict_y = model(test_x.to(device).float()).detach()
  58. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  59. label_np = test_label.numpy()
  60. _ = predict_ys == test_label
  61. correct += np.sum(_.numpy(), axis=-1)
  62. _sum += _.shape[0]
  63. #日志输出
  64. print('accuracy: {:.2f}'.format(correct / _sum))
  65. #模型输出位置放在/tmp/output下
  66. torch.save(model, '/tmp/output/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))

No Description