"""Export a pickled PyTorch model to ONNX.

Loads ``/dataset/<--model>`` with ``torch.load``, prints the model and the
shape of each of its parameters, then calls ``torch.onnx.export`` with a random
input tensor of shape ``(n, c, h, w)`` and writes the result to
``/model/<model name without extension>.onnx``.
"""
import argparse
import os

# torchvision and onnx are not referenced directly below; they are kept because
# the original script imported them (presumably so unpickling torchvision
# models works and a missing onnx install fails early -- confirm before
# removing). Original relative import order of third-party modules is preserved.
import torchvision
import torch
from torch.autograd import Variable
import onnx

print(torch.__version__)

# NOTE(review): the description and the --model help text below do not match
# what the script does (--model is a model file name resolved under /dataset/,
# and the script exports a PyTorch model). They are left untouched because they
# are user-visible runtime strings.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--model',
                    type=str,
                    help='path to training/inference dataset folder'
                    )
parser.add_argument('--n',
                    type=int,
                    default=256,
                    help='batch size for input shape type'
                    )
parser.add_argument('--c',
                    type=int,
                    default=1,
                    help='channel for input shape type'
                    )
parser.add_argument('--h',
                    type=int,
                    default=28,
                    help='height for input shape type'
                    )
parser.add_argument('--w',
                    type=int,
                    default=28,
                    help='width for input shape type'
                    )


def onnx_output_path(model_name):
    """Return the ONNX output path under ``/model/`` for *model_name*.

    The file extension of *model_name* (if any) is replaced by ``.onnx``; a
    name without an extension simply gets ``.onnx`` appended.

    The original code used ``str.rindex``, which raises ``ValueError`` when the
    name contains no dot, so its "no extension" fallback could never run.
    ``os.path.splitext`` handles that case and only looks at the final path
    component, so a dot in a directory name is not mistaken for an extension.
    """
    stem, _extension = os.path.splitext(model_name)
    return '/model/' + stem + ".onnx"


def main():
    """Parse the command line, load the model, and export it to ONNX.

    Raises:
        TypeError: if the loaded object is not a ``torch.nn.Module``
            (for example when the file only contains a state dict).
    """
    args = parser.parse_args()
    print('args:')
    print(args)

    model_file = '/dataset/' + args.model
    print(model_file)

    # SECURITY: torch.load unpickles the file and can therefore execute
    # arbitrary code; only run this on model files from a trusted source.
    model = torch.load(model_file)
    print(model)
    print(type(model))

    if not isinstance(model, torch.nn.Module):
        # Fail with a clear message instead of an AttributeError on
        # named_parameters() below.
        raise TypeError(
            "expected %s to contain a torch.nn.Module, got %s"
            % (model_file, type(model).__name__)
        )

    for k, v in model.named_parameters():
        print("k:", k)
        print("v:", v.shape)

    out_file = onnx_output_path(args.model)
    print(out_file)

    input_name = ['input']
    output_name = ['output']
    # Random example input with the shape given on the command line.
    # Variable is kept as in the original for compatibility with older
    # PyTorch releases.
    dummy_input = Variable(torch.randn(args.n, args.c, args.h, args.w))
    torch.onnx.export(model, dummy_input, out_file,
                      input_names=input_name,
                      output_names=output_name,
                      verbose=True)


if __name__ == "__main__":
    main()