
pretrain_for_c2net.py 5.9 kB

#!/usr/bin/python
#coding=utf-8
'''
If the code contains Chinese comments, add the following two lines at the top of the file:
#!/usr/bin/python
#coding=utf-8

In the training environment:
(1) The code is automatically placed in the /tmp/code directory.
(2) The uploaded dataset is automatically placed in the /tmp/dataset directory.
    Note: the paths differ depending on whether a single dataset or multiple datasets are selected.
    If a single dataset is selected, e.g. MnistDataset_torch.zip, the dataset directories are
    /tmp/dataset/train and /tmp/dataset/test.
    Dataset structure in the training image for the single-dataset case in this example:
    tmp
    └── dataset
        ├── test
        └── train
    If multiple datasets are selected, such as MnistDataset_torch.zip and checkpoint_epoch1_0.73.zip,
    the dataset directories are /tmp/dataset/MnistDataset_torch/train, /tmp/dataset/MnistDataset_torch/test
    and /tmp/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl.
    Dataset structure in the training image for the multiple-dataset case in this example:
    tmp
    └── dataset
        ├── MnistDataset_torch
        │   ├── test
        │   └── train
        └── checkpoint_epoch1_0.73
            └── mnist_epoch1_0.73.pkl
(3) The model output path defaults to /tmp/output; please write model files to /tmp/output,
    as the Qizhi platform provides file downloads for the /tmp/output directory.
(4) If a pre-trained model file is selected, it is automatically placed in the /tmp/pretrainmodel
    directory. For example, if a model file is selected, it can be accessed as:
    '/tmp/pretrainmodel/' + args.pretrainmodelname
In addition, if you want to retrieve the model files after each epoch of training, call the
uploader_for_gpu tool, which is invoked as:
import os
os.system("cd /tmp/script_for_grampus/ && ./uploader_for_gpu " + "/tmp/output/")
'''
from model import Model
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
import os
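
# --- Illustrative sketch (not part of the original example): choosing the dataset root. ---
# The docstring above notes that a single selected dataset is unpacked to /tmp/dataset/train
# and /tmp/dataset/test, while multiple selected datasets land under /tmp/dataset/<dataset_name>/.
# The helper below is an assumption-based sketch of how the MNIST root directory could be
# picked for either layout; the function name and the dataset_name default are illustrative
# and not a platform API. It is not called by this script.
def resolve_mnist_root(dataset_name="MnistDataset_torch"):
    multi_root = os.path.join("/tmp/dataset", dataset_name)  # multi-dataset layout
    single_root = "/tmp/dataset"                             # single-dataset layout
    # The returned directory is the one containing train/ and test/, so callers would pass
    # e.g. os.path.join(root, "train") as --traindata and os.path.join(root, "test") as --testdata.
    return multi_root if os.path.isdir(multi_root) else single_root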
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
# The dataset is placed under /tmp/dataset
parser.add_argument('--traindata', default="/tmp/dataset/train", help='path to train dataset')
parser.add_argument('--testdata', default="/tmp/dataset/test", help='path to test dataset')
parser.add_argument('--epoch_size', type=int, default=10, help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=256, help='batch size used per step')
# Path of the pre-trained model / checkpoint file
parser.add_argument('--ckpt_url', default="", help='pretrain model path')
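
# Example invocations (the checkpoint path is taken from the docstring above; the exact
# command line is an assumption and depends on how the training task is configured):
#   python /tmp/code/pretrain_for_c2net.py
#   python /tmp/code/pretrain_for_c2net.py \
#       --ckpt_url /tmp/dataset/checkpoint_epoch1_0.73/mnist_epoch1_0.73.pkl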
# Parameter declarations
WORKERS = 0  # number of DataLoader worker threads
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = SGD(model.parameters(), lr=1e-1)
cost = CrossEntropyLoss()
# Model training
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        y_hat = model(x)
        loss = cost(y_hat, y)
        loss.backward()
        optimizer.step()
        # Accumulate a Python float rather than the loss tensor so the autograd graph is freed
        train_loss += loss.item()
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))
# Model evaluation
def test(model, test_loader, test_data):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            y_hat = model(x)
            test_loss += cost(y_hat, y).item()
            # The prediction is the class with the highest logit
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
def main():
    # If a saved checkpoint exists, load it and resume training from it
    if os.path.exists(args.ckpt_url):
        checkpoint = torch.load(args.ckpt_url)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('Successfully loaded weights from epoch {}!'.format(start_epoch))
    else:
        start_epoch = 0
        print('No saved model found; training from scratch!')
    for epoch in range(start_epoch + 1, epochs):
        train(model, train_loader, epoch)
        test(model, test_loader, test_dataset)
        # Save the model, optimizer state and epoch so training can be resumed later
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, '/tmp/output/mnist_epoch{}.pkl'.format(epoch))
        # After each epoch, uploader_for_gpu sends the result files under /tmp/output back to Qizhi
        os.system("cd /tmp/script_for_grampus/ && ./uploader_for_gpu " + "/tmp/output/")
if __name__ == '__main__':
    args, unknown = parser.parse_known_args()
    # Log whether CUDA is available
    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size
    epochs = args.epoch_size
    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=WORKERS)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=WORKERS)
    main()