You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_gcu.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If there are Chinese comments in the code, please add at the beginning:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. 示例选用的数据集是MnistDataset_torch.zip
  8. 数据集结构是:
  9. MnistDataset_torch.zip
  10. ├── test
  11. │ ├── MNIST/processed/test.pt
  12. │ └── MNIST/processed/training.pt
  13. │ ├── MNIST/raw/train-images-idx3-ubyte
  14. │ └── MNIST/raw/train-labels-idx1-ubyte
  15. │ ├── MNIST/raw/t10k-images-idx3-ubyte
  16. │ └── MNIST/raw/t10k-labels-idx1-ubyte
  17. ├── train
  18. │ ├── MNIST/processed/test.pt
  19. │ └── MNIST/processed/training.pt
  20. │ ├── MNIST/raw/train-images-idx3-ubyte
  21. │ └── MNIST/raw/train-labels-idx1-ubyte
  22. │ ├── MNIST/raw/t10k-images-idx3-ubyte
  23. │ └── MNIST/raw/t10k-labels-idx1-ubyte
  24. 示例选用的预训练模型文件为:mnist_epoch1_0.86.pkl
  25. '''
  26. import os
  27. print("begin:")
  28. os.system("pip uninstall openi-test")
  29. os.system("pip install {}".format(os.getenv("OPENI_SDK_PATH")))
  30. import torch
  31. from model import Model
  32. import numpy as np
  33. from torchvision.datasets import mnist
  34. from torch.nn import CrossEntropyLoss
  35. from torch.optim import SGD
  36. from torch.utils.data import DataLoader
  37. from torchvision.transforms import ToTensor
  38. import argparse
  39. from openi.context import prepare, upload_openi
  40. import importlib.util
  41. def is_torch_dtu_available():
  42. if importlib.util.find_spec("torch_dtu") is None:
  43. return False
  44. if importlib.util.find_spec("torch_dtu.core") is None:
  45. return False
  46. return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None
  47. # Training settings
  48. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  49. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  50. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  51. if __name__ == '__main__':
  52. #获取参数并忽略超参数报错
  53. args, unknown = parser.parse_known_args()
  54. #初始化导入数据集和预训练模型到容器内
  55. openi_context = prepare()
  56. #获取数据集路径,预训练模型路径,输出路径
  57. dataset_path = openi_context.dataset_path
  58. pretrain_model_path = openi_context.pretrain_model_path
  59. output_path = openi_context.output_path
  60. dataset_path_A = dataset_path + "/MnistDataset"
  61. pretrain_model_path_A = pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j"
  62. print("dataset_path:")
  63. print(os.listdir(dataset_path))
  64. os.listdir(dataset_path)
  65. print("pretrain_model_path:")
  66. print(os.listdir(pretrain_model_path))
  67. os.listdir(pretrain_model_path)
  68. print("output_path:")
  69. print(os.listdir(output_path))
  70. os.listdir(output_path)
  71. # load DPU envs-xx.sh
  72. DTU_FLAG = True
  73. if is_torch_dtu_available():
  74. import torch_dtu
  75. import torch_dtu.distributed as dist
  76. import torch_dtu.core.dtu_model as dm
  77. from torch_dtu.nn.parallel import DistributedDataParallel as torchDDP
  78. print('dtu is available: True')
  79. device = dm.dtu_device()
  80. DTU_FLAG = True
  81. else:
  82. print('dtu is available: False')
  83. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  84. DTU_FLAG = False
  85. # 参数声明
  86. model = Model().to(device)
  87. optimizer = SGD(model.parameters(), lr=1e-1)
  88. #log output
  89. batch_size = args.batch_size
  90. train_dataset = mnist.MNIST(root=dataset_path_A + "/train", train=True, transform=ToTensor(),download=False)
  91. test_dataset = mnist.MNIST(root=dataset_path_A + "/test", train=False, transform=ToTensor(),download=False)
  92. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  93. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  94. model = Model().to(device)
  95. sgd = SGD(model.parameters(), lr=1e-1)
  96. cost = CrossEntropyLoss()
  97. epochs = args.epoch_size
  98. print('epoch_size is:{}'.format(epochs))
  99. # 如果有保存的模型,则加载模型,并在其基础上继续训练
  100. if os.path.exists(pretrain_model_path_A+"/mnist_epoch1_0.70.pkl"):
  101. checkpoint = torch.load(pretrain_model_path_A+"/mnist_epoch1_0.70.pkl")
  102. model.load_state_dict(checkpoint['model'])
  103. optimizer.load_state_dict(checkpoint['optimizer'])
  104. start_epoch = checkpoint['epoch']
  105. print('加载 epoch {} 权重成功!'.format(start_epoch))
  106. else:
  107. start_epoch = 0
  108. print('无保存模型,将从头开始训练!')
  109. for _epoch in range(start_epoch, epochs):
  110. print('the {} epoch_size begin'.format(_epoch + 1))
  111. model.train()
  112. for idx, (train_x, train_label) in enumerate(train_loader):
  113. train_x = train_x.to(device)
  114. train_label = train_label.to(device)
  115. label_np = np.zeros((train_label.shape[0], 10))
  116. sgd.zero_grad()
  117. predict_y = model(train_x.float())
  118. loss = cost(predict_y, train_label.long())
  119. if idx % 10 == 0:
  120. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  121. loss.backward()
  122. if DTU_FLAG:
  123. dm.optimizer_step(sgd, barrier=True)
  124. else:
  125. sgd.step()
  126. correct = 0
  127. _sum = 0
  128. model.eval()
  129. for idx, (test_x, test_label) in enumerate(test_loader):
  130. test_x = test_x
  131. test_label = test_label
  132. predict_y = model(test_x.to(device).float()).detach()
  133. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  134. label_np = test_label.numpy()
  135. _ = predict_ys == test_label
  136. correct += np.sum(_.numpy(), axis=-1)
  137. _sum += _.shape[0]
  138. print('accuracy: {:.2f}'.format(correct / _sum))
  139. #The model output location is placed under /tmp/output
  140. state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':_epoch+1}
  141. torch.save(state, '/tmp/output/mnist_epoch{}_{:.2f}.pkl'.format(_epoch+1, correct / _sum))
  142. print('test:')
  143. print(os.listdir("/tmp/output"))

No Description