train_for_c2net.py
#!/usr/bin/python
#coding=utf-8
'''
If the code contains Chinese comments, add the following two lines at the top of the file:
#!/usr/bin/python
#coding=utf-8
In the training environment:
- the code is automatically placed in the /tmp/code directory,
- the uploaded dataset is automatically placed in the /tmp/dataset directory,
- model output should be written to /tmp/output (the default model download path);
  the Qizhi platform provides file downloads for everything under /tmp/output.
'''
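# Note: the two lines below are not part of the original script; they are a small
# defensive sketch that makes sure the /tmp/output directory used by torch.save()
# at the end of each epoch exists before training starts.
import os
os.makedirs('/tmp/output', exist_ok=True)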
from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# The dataset is placed under /tmp/dataset by the platform
parser.add_argument('--traindata', default="/tmp/dataset/train", help='path to the training dataset')
parser.add_argument('--testdata', default="/tmp/dataset/test", help='path to the test dataset')
parser.add_argument('--epoch_size', type=int, default=1, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size used per training step')
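# Example local invocation (illustrative only; the ./data paths below are assumptions --
# on the Qizhi platform the /tmp/dataset defaults above are used and no flags are needed):
#   python train_for_c2net.py --traindata ./data/train --testdata ./data/test --epoch_size 2 --batch_size 128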
if __name__ == '__main__':
    args, unknown = parser.parse_known_args()

    # Log whether a GPU is available and pick the device accordingly
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    batch_size = args.batch_size
    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    model = Model().to(device)
    sgd = SGD(model.parameters(), lr=1e-1)
    cost = CrossEntropyLoss()
    epoch = args.epoch_size
    print('epoch_size is:{}'.format(epoch))

    for _epoch in range(epoch):
        print('epoch {} begins'.format(_epoch + 1))

        # Training loop
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            train_x = train_x.to(device)
            train_label = train_label.to(device)
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = cost(predict_y, train_label.long())
            if idx % 10 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.item()))
            loss.backward()
            sgd.step()

        # Evaluation loop
        correct = 0
        _sum = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.to(device).float()).detach()
            predict_ys = np.argmax(predict_y.cpu().numpy(), axis=-1)
            matches = predict_ys == test_label.numpy()
            correct += np.sum(matches, axis=-1)
            _sum += matches.shape[0]
        print('accuracy: {:.2f}'.format(correct / _sum))

        # The model output is written under /tmp/output, which the platform exposes for download
        torch.save(model, '/tmp/output/mnist_epoch{}_{:.2f}.pkl'.format(_epoch + 1, correct / _sum))
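The script imports Model from model.py, which is not shown on this page. For reference, below is a minimal sketch of a compatible definition, assuming a small CNN that takes (N, 1, 28, 28) MNIST batches and returns raw 10-class logits for CrossEntropyLoss; the layer choices are assumptions, not the repository's actual model.py.

    # model.py (hypothetical sketch, for illustration only)
    import torch
    from torch import nn

    class Model(nn.Module):
        """A small CNN matching the interface train_for_c2net.py expects:
        input (N, 1, 28, 28) float tensors, output (N, 10) raw logits."""

        def __init__(self):
            super().__init__()
            self.features = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=3, padding=1),   # -> (N, 16, 28, 28)
                nn.ReLU(),
                nn.MaxPool2d(2),                              # -> (N, 16, 14, 14)
                nn.Conv2d(16, 32, kernel_size=3, padding=1),  # -> (N, 32, 14, 14)
                nn.ReLU(),
                nn.MaxPool2d(2),                              # -> (N, 32, 7, 7)
            )
            self.classifier = nn.Linear(32 * 7 * 7, 10)

        def forward(self, x):
            x = self.features(x)
            x = torch.flatten(x, 1)
            # CrossEntropyLoss applies log-softmax internally, so return raw logits.
            return self.classifier(x)

Any architecture with the same input and output shapes would drop in, since the training script only calls model(x) and passes the resulting logits to CrossEntropyLoss.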