You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_gcu.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If the code contains Chinese comments, add these lines at the beginning of the file:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. 示例选用的数据集是MnistDataset_torch.zip
  8. 数据集结构是:
  9. MnistDataset_torch.zip
  10. ├── test
  11. │   ├── MNIST/processed/test.pt
  12. │   ├── MNIST/processed/training.pt
  13. │   ├── MNIST/raw/train-images-idx3-ubyte
  14. │   ├── MNIST/raw/train-labels-idx1-ubyte
  15. │   ├── MNIST/raw/t10k-images-idx3-ubyte
  16. │   └── MNIST/raw/t10k-labels-idx1-ubyte
  17. └── train
  18.     ├── MNIST/processed/test.pt
  19.     ├── MNIST/processed/training.pt
  20.     ├── MNIST/raw/train-images-idx3-ubyte
  21.     ├── MNIST/raw/train-labels-idx1-ubyte
  22.     ├── MNIST/raw/t10k-images-idx3-ubyte
  23.     └── MNIST/raw/t10k-labels-idx1-ubyte
  24. 示例选用的预训练模型文件为:mnist_epoch1_0.86.pkl
  25. 代码会自动放置在/tmp/code目录下。
  26. 数据集在界面选择后,会自动放置在/tmp/dataset目录下。
  27. 预训练模型文件在界面选择后,会自动放置在/tmp/pretrainmodel目录下。
  28. 输出的模型文件也需要放置在/tmp/output目录下,平台会自动下载/tmp/output目录下的文件。
  29. 如果选用了多数据集,则应在/tmp/dataset后带上数据集名称,比如/tmp/dataset/MnistDataset_torch/train
  30. '''
  31. import torch
  32. from model import Model
  33. import numpy as np
  34. from torchvision.datasets import mnist
  35. from torch.nn import CrossEntropyLoss
  36. from torch.optim import SGD
  37. from torch.utils.data import DataLoader
  38. from torchvision.transforms import ToTensor
  39. import argparse
  40. import os
  # Install the OpenI SDK whose location the platform passes in OPENI_SDK_PATH.
  # - `python -m pip` via sys.executable installs into the interpreter that is
  #   running this script, not whichever `pip` happens to be first on PATH.
  # - An argument list (no shell) keeps the path from being shell-interpreted.
  # - Best effort: a failed install is reported but does not stop the script.
  # - When the variable is unset the step is skipped instead of running
  #   `pip install None`.
  import subprocess
  import sys
  _openi_sdk_path = os.getenv("OPENI_SDK_PATH")
  if _openi_sdk_path:
      _pip_result = subprocess.run(
          [sys.executable, "-m", "pip", "install", _openi_sdk_path],
          check=False,
      )
      if _pip_result.returncode != 0:
          print("pip install {} failed with exit code {}".format(
              _openi_sdk_path, _pip_result.returncode))
  else:
      print("OPENI_SDK_PATH is not set; skipping OpenI SDK installation")
  42. import importlib.util
  def is_torch_dtu_available():
      """Return True if ``torch_dtu.core.dtu_model`` can be located for import.

      ``importlib.util.find_spec`` on a dotted name first imports the parent
      packages (``torch_dtu`` and ``torch_dtu.core``). The following all mean
      the DTU backend is not usable, so they are reported as ``False``
      instead of crashing the device probe:

      - a missing parent raises ``ModuleNotFoundError``;
      - a parent whose import fails raises ``ImportError``;
      - a module already in ``sys.modules`` whose ``__spec__`` is ``None``
        raises ``ValueError``.
      """
      try:
          return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None
      except (ImportError, ValueError):
          return False
  # Command-line options for the training run. Both options are integers:
  # the number of epochs to train and the DataLoader batch size.
  parser = argparse.ArgumentParser(
      description='PyTorch MNIST Example',
  )
  parser.add_argument(
      '--epoch_size',
      type=int,
      default=1,
      help='how much epoch to train',
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=256,
      help='how much batch_size in epoch',
  )
  # Checkpoint that training resumes from when it is present in the
  # pretrained-model directory.
  # NOTE(review): the module docstring names mnist_epoch1_0.86.pkl as the
  # example pretrained model, but this script looks for mnist_epoch1_0.70.pkl;
  # confirm which file is intended.
  CHECKPOINT_NAME = "mnist_epoch1_0.70.pkl"
  # Directory the per-epoch checkpoints are written to; the module docstring
  # states the platform downloads files placed under /tmp/output.
  OUTPUT_DIR = "/tmp/output"


  def _evaluate(model, loader, device):
      """Return the top-1 accuracy of ``model`` over every batch of ``loader``.

      Each batch is an ``(inputs, labels)`` pair. Inputs are moved to
      ``device`` and cast to float. The predicted class is the argmax over the
      last dimension of the model output, compared with ``labels`` on the CPU.
      Gradients are disabled because this is inference only.

      Raises:
          ZeroDivisionError: if ``loader`` yields no samples.
      """
      model.eval()
      correct = 0
      total = 0
      with torch.no_grad():
          for inputs, labels in loader:
              logits = model(inputs.to(device).float())
              predictions = logits.argmax(dim=-1).cpu()
              correct += int((predictions == labels).sum().item())
              total += labels.shape[0]
      return correct / total


  if __name__ == '__main__':
      # Parse only the declared options; extra hyper-parameters passed by the
      # platform are ignored instead of raising an error.
      args, _unknown = parser.parse_known_args()

      # Stage the dataset and pretrained model into the container, then read
      # the dataset, pretrained-model and output paths from the context.
      # NOTE(review): `prepare` is not imported anywhere in the visible file;
      # presumably it comes from the OpenI SDK installed near the top of the
      # file. Confirm and add the import.
      openi_context = prepare()
      dataset_path = openi_context.dataset_path
      pretrain_model_path = openi_context.pretrain_model_path
      output_path = openi_context.output_path
      dataset_path_A = dataset_path + "/MnistDataset"
      pretrain_model_path_A = pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j"
      print("dataset_path:")
      print(os.listdir(dataset_path))
      print("pretrain_model_path:")
      print(os.listdir(pretrain_model_path))
      print("output_path:")
      print(os.listdir(output_path))

      # Prefer the DTU device when torch_dtu is importable; otherwise fall
      # back to the first CUDA GPU, or the CPU.
      if is_torch_dtu_available():
          import torch_dtu.core.dtu_model as dm
          print('dtu is available: True')
          device = dm.dtu_device()
          DTU_FLAG = True
      else:
          print('dtu is available: False')
          device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
          DTU_FLAG = False

      # One model and ONE optimizer. The optimizer that steps the weights is
      # also the one restored from and written to the checkpoint, so resumed
      # runs keep their optimizer state.
      model = Model().to(device)
      optimizer = SGD(model.parameters(), lr=1e-1)
      cost = CrossEntropyLoss()

      batch_size = args.batch_size
      train_dataset = mnist.MNIST(root=dataset_path_A + "/train", train=True,
                                  transform=ToTensor(), download=False)
      test_dataset = mnist.MNIST(root=dataset_path_A + "/test", train=False,
                                 transform=ToTensor(), download=False)
      train_loader = DataLoader(train_dataset, batch_size=batch_size)
      test_loader = DataLoader(test_dataset, batch_size=batch_size)
      epochs = args.epoch_size
      print('epoch_size is:{}'.format(epochs))

      # Resume from a saved checkpoint if one exists; otherwise start fresh.
      checkpoint_file = os.path.join(pretrain_model_path_A, CHECKPOINT_NAME)
      if os.path.exists(checkpoint_file):
          # Load onto the CPU first. load_state_dict copies the tensors onto the
          # device of the model's / optimizer's parameters, so a checkpoint
          # saved on a different device still loads.
          checkpoint = torch.load(checkpoint_file, map_location="cpu")
          model.load_state_dict(checkpoint['model'])
          optimizer.load_state_dict(checkpoint['optimizer'])
          start_epoch = checkpoint['epoch']
          print('加载 epoch {} 权重成功!'.format(start_epoch))
      else:
          start_epoch = 0
          print('无保存模型,将从头开始训练!')

      for _epoch in range(start_epoch, epochs):
          print('the {} epoch_size begin'.format(_epoch + 1))
          model.train()
          for idx, (train_x, train_label) in enumerate(train_loader):
              train_x = train_x.to(device)
              train_label = train_label.to(device)
              optimizer.zero_grad()
              predict_y = model(train_x.float())
              loss = cost(predict_y, train_label.long())
              if idx % 10 == 0:
                  print('idx: {}, loss: {}'.format(idx, loss.item()))
              loss.backward()
              if DTU_FLAG:
                  # The DTU backend steps the optimizer through its own helper.
                  dm.optimizer_step(optimizer, barrier=True)
              else:
                  optimizer.step()

          accuracy = _evaluate(model, test_loader, device)
          print('accuracy: {:.2f}'.format(accuracy))
          # Save a per-epoch checkpoint whose name encodes the epoch and accuracy.
          state = {'model': model.state_dict(),
                   'optimizer': optimizer.state_dict(),
                   'epoch': _epoch + 1}
          torch.save(state, os.path.join(
              OUTPUT_DIR, 'mnist_epoch{}_{:.2f}.pkl'.format(_epoch + 1, accuracy)))

      print('test:')
      print(os.listdir(OUTPUT_DIR))

No Description