You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #!/usr/bin/python
  2. #coding=utf-8
  3. '''
  4. If the code contains Chinese comments, please add the following two lines at the beginning of the file:
  5. #!/usr/bin/python
  6. #coding=utf-8
  7. The dataset used in this example is MnistDataset_torch.zip
  8. The dataset structure is:
  9. MnistDataset_torch.zip
  10. ├── test
  11. │ ├── MNIST/processed/test.pt
  12. │ └── MNIST/processed/training.pt
  13. │ ├── MNIST/raw/train-images-idx3-ubyte
  14. │ └── MNIST/raw/train-labels-idx1-ubyte
  15. │ ├── MNIST/raw/t10k-images-idx3-ubyte
  16. │ └── MNIST/raw/t10k-labels-idx1-ubyte
  17. ├── train
  18. │ ├── MNIST/processed/test.pt
  19. │ └── MNIST/processed/training.pt
  20. │ ├── MNIST/raw/train-images-idx3-ubyte
  21. │ └── MNIST/raw/train-labels-idx1-ubyte
  22. │ ├── MNIST/raw/t10k-images-idx3-ubyte
  23. │ └── MNIST/raw/t10k-labels-idx1-ubyte
  24. '''
  25. import torch
  26. from model import Model
  27. import numpy as np
  28. from torchvision.datasets import mnist
  29. from torch.nn import CrossEntropyLoss
  30. from torch.optim import SGD
  31. from torch.utils.data import DataLoader
  32. from torchvision.transforms import ToTensor
  33. import argparse
  34. import os
  35. #导入c2net包
  36. from c2net.context import prepare, upload_output
  37. import importlib.util
  38. def is_torch_dtu_available():
  39. if importlib.util.find_spec("torch_dtu") is None:
  40. return False
  41. if importlib.util.find_spec("torch_dtu.core") is None:
  42. return False
  43. return importlib.util.find_spec("torch_dtu.core.dtu_model") is not None
  44. # Training settings
  45. parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  46. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  47. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  48. if __name__ == '__main__':
  49. args, unknown = parser.parse_known_args()
  50. #初始化导入数据集和预训练模型到容器内
  51. c2net_context = prepare()
  52. #获取数据集路径
  53. MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch"
  54. #获取预训练模型路径
  55. Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model"
  56. #获取输出路径
  57. output_path = c2net_context.output_path
  58. # load DPU envs-xx.sh
  59. DTU_FLAG = True
  60. if is_torch_dtu_available():
  61. import torch_dtu
  62. import torch_dtu.distributed as dist
  63. import torch_dtu.core.dtu_model as dm
  64. from torch_dtu.nn.parallel import DistributedDataParallel as torchDDP
  65. print('dtu is available: True')
  66. device = dm.dtu_device()
  67. DTU_FLAG = True
  68. else:
  69. print('dtu is available: False')
  70. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  71. DTU_FLAG = False
  72. # 参数声明
  73. model = Model().to(device)
  74. optimizer = SGD(model.parameters(), lr=1e-1)
  75. args, unknown = parser.parse_known_args()
  76. #log output
  77. batch_size = args.batch_size
  78. train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "train"), train=True, transform=ToTensor(),download=False)
  79. test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "test"), train=False, transform=ToTensor(),download=False)
  80. train_loader = DataLoader(train_dataset, batch_size=batch_size)
  81. test_loader = DataLoader(test_dataset, batch_size=batch_size)
  82. model = Model().to(device)
  83. sgd = SGD(model.parameters(), lr=1e-1)
  84. cost = CrossEntropyLoss()
  85. epochs = args.epoch_size
  86. print('epoch_size is:{}'.format(epochs))
  87. # 如果有保存的模型,则加载模型,并在其基础上继续训练
  88. if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl")):
  89. checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl"))
  90. model.load_state_dict(checkpoint['model'])
  91. optimizer.load_state_dict(checkpoint['optimizer'])
  92. start_epoch = checkpoint['epoch']
  93. print('加载 epoch {} 权重成功!'.format(start_epoch))
  94. else:
  95. start_epoch = 0
  96. print('无保存模型,将从头开始训练!')
  97. for _epoch in range(start_epoch, epochs):
  98. print('the {} epoch_size begin'.format(_epoch + 1))
  99. model.train()
  100. for idx, (train_x, train_label) in enumerate(train_loader):
  101. train_x = train_x.to(device)
  102. train_label = train_label.to(device)
  103. label_np = np.zeros((train_label.shape[0], 10))
  104. sgd.zero_grad()
  105. predict_y = model(train_x.float())
  106. loss = cost(predict_y, train_label.long())
  107. if idx % 10 == 0:
  108. print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
  109. loss.backward()
  110. if DTU_FLAG:
  111. dm.optimizer_step(sgd, barrier=True)
  112. else:
  113. sgd.step()
  114. correct = 0
  115. _sum = 0
  116. model.eval()
  117. for idx, (test_x, test_label) in enumerate(test_loader):
  118. test_x = test_x
  119. test_label = test_label
  120. predict_y = model(test_x.to(device).float()).detach()
  121. predict_ys = np.argmax(predict_y.cpu(), axis=-1)
  122. label_np = test_label.numpy()
  123. _ = predict_ys == test_label
  124. correct += np.sum(_.numpy(), axis=-1)
  125. _sum += _.shape[0]
  126. print('accuracy: {:.2f}'.format(correct / _sum))
  127. state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':_epoch+1}
  128. torch.save(state, '{}/mnist_epoch{}_{:.2f}.pkl'.format(output_path, _epoch+1, correct / _sum))
  129. print(os.listdir('{}'.format(output_path)))

No Description