You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

gpu_train_resnet50.py 1.4 kB

3 years ago
123456789101112131415161718192021222324252627282930
  1. '''
  2. 由于a100的适配性问题,使用训练环境前请使用平台的含有cuda11以上的推荐镜像在调试环境中调试自己的代码,
  3. 本示例的镜像地址是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191,并
  4. 提交镜像,再切到训练环境训练已跑通的代码。
  5. 在训练环境中,上传的数据集会自动放在/dataset目录下,模型下载路径默认在/model下,请将模型输出位置指定到/model,
  6. 启智平台界面会提供/model目录下的文件下载。
  7. '''
  8. import torchvision
  9. from torch.autograd import Variable
  10. import torch
  11. import argparse
  12. # Training settings
  13. parser = argparse.ArgumentParser(description='Resnet50 Example')
  14. #数据集位置放在/dataset下
  15. parser.add_argument('--traindata', default="/dataset/train" ,help='path to train dataset')
  16. parser.add_argument('--testdata', default="/dataset/test" ,help='path to test dataset')
  17. parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
  18. parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
  19. if __name__ == '__main__':
  20. input_name = ['input']
  21. output_name = ['output']
  22. input = Variable(torch.randn(1, 3, 224, 224)).cuda()
  23. model = torchvision.models.resnet50(pretrained=True).cuda()
  24. #模型输出位置放在/model下
  25. torch.save(model, '/model/resnet50.pth')