You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

8 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from torchvision.transforms import transforms
  2. from torchvision.datasets import CIFAR10
  3. from torch.utils.data import DataLoader
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from torch.optim import Adam
  8. from torch.autograd import Variable
  9. import argparse
  10. import time
  11. import shutil
  12. from c2net.context import prepare,upload_output
  13. import argparse
  14. parser = argparse.ArgumentParser(description='忽略超参数不存在的报错问题')
  15. #添加自定义参数
  16. parser.add_argument("--test")
  17. parser.add_argument('--epoch', type=int, default=1)
  18. parser.add_argument('--card', type=str, default='cuda:0')
  19. args = parser.parse_args()
  20. args, unknown = parser.parse_known_args()
  21. #初始化导入数据集和预训练模型到容器内
  22. c2net_context = prepare()
  23. codePath = c2net_context.code_path
  24. test = codePath + '/pytorch-cnn-cifar10-dcu' + '/test.py'
  25. #获取数据集路径
  26. cifar_10_python_path = c2net_context.dataset_path+"/"+"cifar-10-python"
  27. #输出结果必须保存在该目录
  28. outputPath = c2net_context.output_path
  29. #回传结果到openi,只有训练任务才能回传
  30. upload_output()
  31. class Network(nn.Module):
  32. def __init__(self):
  33. super(Network, self).__init__()
  34. self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
  35. self.bn1 = nn.BatchNorm2d(12)
  36. self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
  37. self.bn2 = nn.BatchNorm2d(12)
  38. self.pool = nn.MaxPool2d(2, 2)
  39. self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
  40. self.bn4 = nn.BatchNorm2d(24)
  41. self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
  42. self.bn5 = nn.BatchNorm2d(24)
  43. self.fc1 = nn.Linear(24 * 10 * 10, 10)
  44. def forward(self, input):
  45. output = F.relu(self.bn1(self.conv1(input)))
  46. output = F.relu(self.bn2(self.conv2(output)))
  47. output = self.pool(output)
  48. output = F.relu(self.bn4(self.conv4(output)))
  49. output = F.relu(self.bn5(self.conv5(output)))
  50. output = output.view(-1, 24 * 10 * 10)
  51. output = self.fc1(output)
  52. return output
  53. def saveModel():
  54. path = outputPath + '/' + 'test.pth'
  55. torch.save(model.state_dict(), path)
  56. zipfileName = outputPath + '/' + 'test_database'
  57. save_zipfile(zipfileName, outputPath)
  58. def testAccuracy(card):
  59. model.eval()
  60. accuracy = 0.0
  61. total = 0.0
  62. # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  63. device = torch.device(card)
  64. model.to(device)
  65. test_dataloader = Get_dataloader(False)
  66. with torch.no_grad():
  67. for data in test_dataloader:
  68. images, labels = data
  69. images = Variable(images.to(device))
  70. labels = Variable(labels.to(device))
  71. outputs = model(images)
  72. _, predicted = torch.max(outputs.data, 1)
  73. total += labels.size(0)
  74. accuracy += (predicted == labels).sum().item()
  75. accuracy = (100 * accuracy / total)
  76. return (accuracy)
  77. def train(num_epochs, card):
  78. # best_accuracy = 0.0
  79. # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  80. device = torch.device(card)
  81. model.to(device)
  82. train_dataloader = Get_dataloader(True)
  83. for epoch in range(num_epochs):
  84. running_loss = 0.0
  85. running_acc = 0.0
  86. for i, (images, labels) in enumerate(train_dataloader, 0):
  87. images = Variable(images.to(device))
  88. labels = Variable(labels.to(device))
  89. optimizer.zero_grad()
  90. outputs = model(images)
  91. loss = loss_fn(outputs, labels)
  92. loss.backward()
  93. optimizer.step()
  94. running_loss += loss.item()
  95. if i % 1000 == 999:
  96. print('[%d, %5d] loss: %.3f' %
  97. (epoch + 1, i + 1, running_loss / 1000))
  98. running_loss = 0.0
  99. accuracy = testAccuracy(card)
  100. print('For epoch', epoch + 1, 'the test accuracy over the whole test set is %d %%' % (accuracy))
  101. # saveModel()
  102. def Get_dataloader(train):
  103. transform_fn = transforms.Compose([
  104. transforms.ToTensor(),
  105. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  106. ])
  107. dataset = CIFAR10(root=DATA_ROOT, train=train, transform=transform_fn)
  108. data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0)
  109. return data_loader
  110. def save_zipfile(filename, dest):
  111. shutil.make_archive(filename, 'zip', dest)
  112. DATA_ROOT = cifar_10_python_path
  113. classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  114. batch_size = 10
  115. number_of_labels = 10
  116. model = Network()
  117. loss_fn = nn.CrossEntropyLoss()
  118. optimizer = Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
  119. if __name__ == "__main__":
  120. accuracy = testAccuracy(args.card)
  121. print('before training, accuracy for test data is: ', accuracy)
  122. start = time.perf_counter()
  123. train(args.epoch, args.card)
  124. end = time.perf_counter()
  125. print(f"training completed in {end - start:0.4f} seconds")
  126. saveModel()

No Description