
pytorch2onnx.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
import sys
import warnings
from functools import partial

# Make the local mmdetection checkout importable.
sys.path.append("/home/shanwei-luo/userdata/mmdetection")

import numpy as np
import onnx
import onnxruntime as ort
import torch
from mmcv import Config, DictAction

from mmdet.core.export import build_model_from_cfg, preprocess_example_input
from mmdet.core.export.model_wrappers import ONNXRuntimeDetector

# Sanity check that onnxruntime can see the GPU.
print(f'onnxruntime device: {ort.get_device()}')  # output: GPU
print(f'ort available providers: {ort.get_available_providers()}')
# output: ['CUDAExecutionProvider', 'CPUExecutionProvider']

def pytorch2onnx(model,
                 input_img,
                 input_shape,
                 normalize_cfg,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False,
                 test_img=None,
                 do_simplify=False,
                 dynamic_export=None,
                 skip_postprocess=False):
    input_config = {
        'input_shape': input_shape,
        'input_path': input_img,
        'normalize_cfg': normalize_cfg
    }
    # prepare input
    one_img, one_meta = preprocess_example_input(input_config)
    img_list, img_meta_list = [one_img], [[one_meta]]

    if skip_postprocess:
        warnings.warn('Not all models support export onnx without post '
                      'process, especially two stage detectors!')
        model.forward = model.forward_dummy
        torch.onnx.export(
            model,
            one_img,
            output_file,
            input_names=['input'],
            export_params=True,
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=show,
            opset_version=opset_version)

        print(f'Successfully exported ONNX model without '
              f'post process: {output_file}')
        return

    # replace original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward,
        img_metas=img_meta_list,
        return_loss=False,
        rescale=False)

    output_names = ['dets', 'labels']
    if model.with_mask:
        output_names.append('masks')
    input_name = 'input'
    dynamic_axes = None
    if dynamic_export:
        dynamic_axes = {
            input_name: {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }
        if model.with_mask:
            dynamic_axes['masks'] = {0: 'batch', 1: 'num_dets'}

    torch.onnx.export(
        model,
        img_list,
        output_file,
        input_names=[input_name],
        output_names=output_names,
        export_params=True,
        keep_initializers_as_inputs=True,
        do_constant_folding=True,
        verbose=show,
        opset_version=opset_version,
        dynamic_axes=dynamic_axes)

    model.forward = origin_forward

    # get the custom op path
    ort_custom_op_path = ''
    try:
        from mmcv.ops import get_onnxruntime_op_path
        ort_custom_op_path = get_onnxruntime_op_path()
    except (ImportError, ModuleNotFoundError):
        warnings.warn('If input model has custom op from mmcv, '
                      'you may have to build mmcv with ONNXRuntime from source.')

    if do_simplify:
        import onnxsim

        from mmdet import digit_version

        min_required_version = '0.3.0'
        assert digit_version(onnxsim.__version__) >= digit_version(
            min_required_version
        ), f'Requires to install onnx-simplify>={min_required_version}'

        input_dic = {'input': img_list[0].detach().cpu().numpy()}
        model_opt, check_ok = onnxsim.simplify(
            output_file,
            input_data=input_dic,
            custom_lib=ort_custom_op_path,
            dynamic_input_shape=dynamic_export)
        if check_ok:
            onnx.save(model_opt, output_file)
            print(f'Successfully simplified ONNX model: {output_file}')
        else:
            warnings.warn('Failed to simplify ONNX model.')
    print(f'Successfully exported ONNX model: {output_file}')

    if verify:
        # check by onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # wrap onnx model
        onnx_model = ONNXRuntimeDetector(output_file, model.CLASSES, 0)
        if dynamic_export:
            # scale up to test dynamic shape
            h, w = [int((_ * 1.5) // 32 * 32) for _ in input_shape[2:]]
            h, w = min(1344, h), min(1344, w)
            input_config['input_shape'] = (1, 3, h, w)
        if test_img is None:
            input_config['input_path'] = input_img

        # prepare input once again
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]

        # get pytorch output
        with torch.no_grad():
            pytorch_results = model(
                img_list,
                img_metas=img_meta_list,
                return_loss=False,
                rescale=True)[0]

        img_list = [_.cuda().contiguous() for _ in img_list]
        if dynamic_export:
            img_list = img_list + [_.flip(-1).contiguous() for _ in img_list]
            img_meta_list = img_meta_list * 2
        # get onnx output
        onnx_results = onnx_model(
            img_list, img_metas=img_meta_list, return_loss=False)[0]

        # visualize predictions
        score_thr = 0.3
        if show:
            out_file_ort, out_file_pt = None, None
        else:
            out_file_ort, out_file_pt = 'show-ort.png', 'show-pt.png'
        show_img = one_meta['show_img']
        model.show_result(
            show_img,
            pytorch_results,
            score_thr=score_thr,
            show=True,
            win_name='PyTorch',
            out_file=out_file_pt)
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)

        # compare a part of result
        # debug: dump the input shape and the raw outputs of both backends
        print(input_config['input_shape'])
        print(one_img)
        print(len(onnx_results))
        print(len(pytorch_results))
        print(onnx_results)
        print(pytorch_results)
        for i in range(len(onnx_results)):
            print(onnx_results[i].shape)
        for i in range(len(pytorch_results)):
            print(pytorch_results[i].shape)

        if model.with_mask:
            compare_pairs = list(zip(onnx_results, pytorch_results))
        else:
            compare_pairs = [(onnx_results, pytorch_results)]
        err_msg = 'The numerical values are different between Pytorch' + \
                  ' and ONNX, but it does not necessarily mean the' + \
                  ' exported ONNX model is problematic.'
        # check the numerical value
        for onnx_res, pytorch_res in compare_pairs:
            for o_res, p_res in zip(onnx_res, pytorch_res):
                np.testing.assert_allclose(
                    o_res, p_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between Pytorch and ONNX')

def parse_normalize_cfg(test_pipeline):
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, '`norm_config` should only have one'
    norm_config = norm_config_li[0]
    return norm_config

def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--input-img', type=str, help='Images for input')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Show onnx graph and detection outputs')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--test-img', type=str, default=None, help='Images for test')
    parser.add_argument(
        '--dataset',
        type=str,
        default='coco',
        help='Dataset name. This argument is deprecated and will be removed '
        'in future releases.')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='verify the onnx model output against pytorch output')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Whether to simplify onnx model.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[800, 1216],
        help='input image size')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='mean value used to preprocess input data. This argument '
        'is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='standard deviation value used to preprocess input data. '
        'This argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='Override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether to export onnx with dynamic axis.')
    parser.add_argument(
        '--skip-postprocess',
        action='store_true',
        help='Whether to export model without post process. Experimental '
        'option. We do not guarantee the correctness of the exported '
        'model.')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    warnings.warn('Arguments like `--mean`, `--std`, `--dataset` would be '
                  'parsed directly from config file and are deprecated and '
                  'will be removed in future releases.')

    assert args.opset_version == 11, 'MMDet only support opset 11 now'

    try:
        from mmcv.onnx.symbolic import register_extra_symbolics
    except ModuleNotFoundError:
        raise NotImplementedError('please update mmcv to version>=v1.0.4')
    register_extra_symbolics(args.opset_version)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    if args.shape is None:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])
    elif len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (1, 3) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # build the model and load checkpoint
    model = build_model_from_cfg(args.config, args.checkpoint,
                                 args.cfg_options)

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__),
                                  '../../demo/demo.jpg')

    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    # convert model to onnx file
    pytorch2onnx(
        model,
        args.input_img,
        input_shape,
        normalize_cfg,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify,
        test_img=args.test_img,
        do_simplify=args.simplify,
        dynamic_export=args.dynamic_export,
        skip_postprocess=args.skip_postprocess)
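
For reference, here is a minimal sketch of driving the export programmatically instead of through the command line; it simply mirrors the __main__ block above. The config, checkpoint, and image paths below are placeholders, not files shipped with this repository.

from mmcv import Config
from mmdet.core.export import build_model_from_cfg

cfg_path = 'configs/retinanet/retinanet_r50_fpn_1x_coco.py'  # placeholder config
ckpt_path = 'retinanet_r50_fpn_1x_coco.pth'                  # placeholder checkpoint

cfg = Config.fromfile(cfg_path)
# Load the detector and its weights, as done in the __main__ block;
# cfg_options is left as None here.
model = build_model_from_cfg(cfg_path, ckpt_path, None)
normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

pytorch2onnx(
    model,
    'demo/demo.jpg',      # placeholder input image
    (1, 3, 800, 1216),    # (N, C, H, W), matching the --shape default
    normalize_cfg,
    opset_version=11,
    output_file='tmp.onnx',
    verify=True)

The equivalent CLI run passes the config and checkpoint as positional arguments together with flags such as --output-file, --shape, --verify and --dynamic-export.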
