You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_pytorch.py 1.8 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
import argparse
import os

import onnx
import torch
import torchvision
from torch.autograd import Variable
  6. print(torch.__version__)
  7. parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
  8. parser.add_argument('--model',
  9. type=str,
  10. help='path to training/inference dataset folder'
  11. )
  12. parser.add_argument('--n',
  13. type=int,
  14. default=256,
  15. help='batch size for input shape type'
  16. )
  17. parser.add_argument('--c',
  18. type=int,
  19. default=1,
  20. help='channel for input shape type'
  21. )
  22. parser.add_argument('--h',
  23. type=int,
  24. default=28,
  25. help='height for input shape type'
  26. )
  27. parser.add_argument('--w',
  28. type=int,
  29. default=28,
  30. help='width for input shape type'
  31. )
  32. if __name__ == "__main__":
  33. args = parser.parse_args()
  34. print('args:')
  35. print(args)
  36. model_file = '/dataset/' + args.model
  37. print(model_file)
  38. model = torch.load(model_file)
  39. print(model)
  40. print(type(model))
  41. for k, v in model.named_parameters():
  42. print("k:",k)
  43. print("v:",v.shape)
  44. suffix = args.model.rindex(".")
  45. out_file = '/model/' + args.model + ".onnx"
  46. if suffix!=-1 :
  47. out_file = '/model/' + args.model[0:suffix] + ".onnx"
  48. print(out_file)
  49. input_name = ['input']
  50. output_name = ['output']
  51. input = Variable(torch.randn(args.n, args.c, args.h, args.w))
  52. torch.onnx.export(model, input, out_file, input_names=input_name, output_names=output_name, verbose=True)