test_head.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from functools import partial

import mmcv
import numpy as np
import pytest
import torch
from mmcv.cnn import Scale

from mmdet import digit_version
from mmdet.models import build_detector
from mmdet.models.dense_heads import (FCOSHead, FSAFHead, RetinaHead, SSDHead,
                                      YOLOV3Head)

from .utils import ort_validate

data_path = osp.join(osp.dirname(__file__), 'data')

if digit_version(torch.__version__) <= digit_version('1.5.0'):
    pytest.skip(
        'ort backend does not support version below 1.5.0',
        allow_module_level=True)
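

# `ort_validate` is imported from ./utils; per the docstrings below it runs a
# callable both in eager torch and through ONNX Runtime and compares the
# outputs. The helper below is only a minimal sketch of that idea (an
# assumption, not the actual implementation from ./utils); it handles just a
# tensor or a flat tuple of tensors.
def _ort_validate_sketch(func, inputs, rtol=1e-3, atol=1e-5):
    import io

    import onnxruntime as rt

    class _Wrapper(torch.nn.Module):

        def forward(self, *args):
            return func(*args)

    if isinstance(inputs, torch.Tensor):
        inputs = (inputs, )
    buf = io.BytesIO()
    # Trace the callable to an in-memory ONNX graph.
    torch.onnx.export(_Wrapper(), inputs, buf, opset_version=11)
    sess = rt.InferenceSession(buf.getvalue())
    ort_inputs = {
        i.name: t.detach().numpy()
        for i, t in zip(sess.get_inputs(), inputs)
    }
    ort_outs = sess.run(None, ort_inputs)
    torch_outs = func(*inputs)
    if isinstance(torch_outs, torch.Tensor):
        torch_outs = (torch_outs, )
    # Element-wise comparison between eager torch and onnxruntime results.
    for t_out, o_out in zip(torch_outs, ort_outs):
        np.testing.assert_allclose(
            t_out.detach().numpy(), o_out, rtol=rtol, atol=atol)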


def test_cascade_onnx_export():
    config_path = './configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py'
    cfg = mmcv.Config.fromfile(config_path)
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    with torch.no_grad():
        model.forward = partial(model.forward, img_metas=[[dict()]])

        dynamic_axes = {
            'input_img': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }

        torch.onnx.export(
            model, [torch.rand(1, 3, 400, 500)],
            'tmp.onnx',
            output_names=['dets', 'labels'],
            input_names=['input_img'],
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11,
            dynamic_axes=dynamic_axes)


def test_faster_onnx_export():
    config_path = './configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
    cfg = mmcv.Config.fromfile(config_path)
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    with torch.no_grad():
        model.forward = partial(model.forward, img_metas=[[dict()]])

        dynamic_axes = {
            'input_img': {
                0: 'batch',
                2: 'width',
                3: 'height'
            },
            'dets': {
                0: 'batch',
                1: 'num_dets',
            },
            'labels': {
                0: 'batch',
                1: 'num_dets',
            },
        }

        torch.onnx.export(
            model, [torch.rand(1, 3, 400, 500)],
            'tmp.onnx',
            output_names=['dets', 'labels'],
            input_names=['input_img'],
            keep_initializers_as_inputs=True,
            do_constant_folding=True,
            verbose=False,
            opset_version=11,
            dynamic_axes=dynamic_axes)
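

# The two export tests above only assert that tracing succeeds. A quick way
# to also exercise the produced file (a sketch, assuming onnxruntime is
# installed and 'tmp.onnx' was written by one of the tests above) is to feed
# it an input of a different spatial size, which the declared dynamic axes
# should accept:
#
#   import onnxruntime as rt
#   sess = rt.InferenceSession('tmp.onnx')
#   img = np.random.rand(1, 3, 320, 480).astype(np.float32)
#   dets, labels = sess.run(['dets', 'labels'], {'input_img': img})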


def retinanet_config():
    """RetinaNet Head Config."""
    head_cfg = dict(
        stacked_convs=6,
        feat_channels=2,
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=4,
            scales_per_octave=3,
            ratios=[0.5, 1.0, 2.0],
            strides=[8, 16, 32, 64, 128]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[1.0, 1.0, 1.0, 1.0]))

    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    model = RetinaHead(
        num_classes=4, in_channels=1, test_cfg=test_cfg, **head_cfg)
    model.requires_grad_(False)
    return model


def test_retina_head_forward_single():
    """Test RetinaNet Head single forward in torch and onnxruntime env."""
    retina_model = retinanet_config()
    feat = torch.rand(1, retina_model.in_channels, 32, 32)
    # validate the result between torch and ort
    ort_validate(retina_model.forward_single, feat)


def test_retina_head_forward():
    """Test RetinaNet Head forward in torch and onnxruntime env."""
    retina_model = retinanet_config()
    s = 128
    # The RetinaNet head expects multiple levels of features per image
    feats = [
        torch.rand(1, retina_model.in_channels, s // (2**(i + 2)),
                   s // (2**(i + 2)))  # [32, 16, 8, 4, 2]
        for i in range(len(retina_model.prior_generator.strides))
    ]
    ort_validate(retina_model.forward, feats)


def test_retinanet_head_onnx_export():
    """Test RetinaNet Head _get_bboxes() in torch and onnxruntime env."""
    retina_model = retinanet_config()
    s = 128
    img_metas = [{
        'img_shape_for_onnx': torch.Tensor([s, s]),
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 2)
    }]
    # The data of retina_head_get_bboxes.pkl contains two parts:
    # cls_score(list(Tensor)) and bboxes(list(Tensor)),
    # where each torch.Tensor is generated by torch.rand().
    # the cls_score's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2).
    # the bboxes's size: (1, 36, 32, 32), (1, 36, 16, 16),
    # (1, 36, 8, 8), (1, 36, 4, 4), (1, 36, 2, 2)
    retina_head_data = 'retina_head_get_bboxes.pkl'
    feats = mmcv.load(osp.join(data_path, retina_head_data))
    cls_score = feats[:5]
    bboxes = feats[5:]

    retina_model.onnx_export = partial(
        retina_model.onnx_export, img_metas=img_metas, with_nms=False)
    ort_validate(retina_model.onnx_export, (cls_score, bboxes))
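

# The .pkl fixture above is just a list of stored random tensors. Equivalent
# inputs could be rebuilt from the shape comment (a sketch; 36 channels per
# level because the anchor generator above yields 9 anchors per location and
# the head is built with num_classes=4, so 9 * 4 channels for both the class
# scores and the box deltas):
#
#   cls_score = [torch.rand(1, 36, n, n) for n in (32, 16, 8, 4, 2)]
#   bboxes = [torch.rand(1, 36, n, n) for n in (32, 16, 8, 4, 2)]
#
# Passing with_nms=False keeps the NMS step out of the exported graph, so the
# comparison runs on the raw scores and decoded boxes.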


def yolo_config():
    """YOLOv3 Head Config."""
    head_cfg = dict(
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'))

    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            conf_thr=0.005,
            nms=dict(type='nms', iou_threshold=0.45),
            max_per_img=100))

    model = YOLOV3Head(
        num_classes=4,
        in_channels=[1, 1, 1],
        out_channels=[16, 8, 4],
        test_cfg=test_cfg,
        **head_cfg)
    model.requires_grad_(False)
    # The YOLOv3 head needs eval() before export
    model.cpu().eval()
    return model


def test_yolov3_head_forward():
    """Test YOLOv3 head forward() in torch and ort env."""
    yolo_model = yolo_config()
    # The YOLOv3 head expects multiple levels of features per image
    feats = [
        torch.rand(1, 1, 64 // (2**(i + 2)), 64 // (2**(i + 2)))
        for i in range(len(yolo_model.in_channels))
    ]
    ort_validate(yolo_model.forward, feats)


def test_yolov3_head_onnx_export():
    """Test YOLOv3 head get_bboxes() in torch and ort env."""
    yolo_model = yolo_config()
    s = 128
    img_metas = [{
        'img_shape_for_onnx': torch.Tensor([s, s]),
        'img_shape': (s, s, 3),
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3)
    }]
    # The data of yolov3_head_get_bboxes.pkl contains
    # a list of torch.Tensor, where each torch.Tensor
    # is generated by torch.rand and each tensor size is:
    # (1, 27, 32, 32), (1, 27, 16, 16), (1, 27, 8, 8).
    yolo_head_data = 'yolov3_head_get_bboxes.pkl'
    pred_maps = mmcv.load(osp.join(data_path, yolo_head_data))

    yolo_model.onnx_export = partial(
        yolo_model.onnx_export, img_metas=img_metas, with_nms=False)
    ort_validate(yolo_model.onnx_export, pred_maps)
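

# The 27 channels per prediction map follow from the config above: 3 anchors
# per level, each predicting 4 box offsets, 1 objectness score and 4 class
# scores, i.e. 3 * (4 + 1 + 4) = 27.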


def fcos_config():
    """FCOS Head Config."""
    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    model = FCOSHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
    model.requires_grad_(False)
    return model


def test_fcos_head_forward_single():
    """Test fcos forward single in torch and ort env."""
    fcos_model = fcos_config()
    feat = torch.rand(1, fcos_model.in_channels, 32, 32)
    fcos_model.forward_single = partial(
        fcos_model.forward_single,
        scale=Scale(1.0).requires_grad_(False),
        stride=(4, ))
    ort_validate(fcos_model.forward_single, feat)
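

# forward_single() takes a per-level Scale module and stride as extra
# arguments: FCOS learns one scale factor per pyramid level to rescale its
# bbox regression outputs, so the test binds a fixed Scale(1.0) and stride
# via partial() to obtain a single-input callable for export.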


def test_fcos_head_forward():
    """Test fcos forward on multi-level feature maps."""
    fcos_model = fcos_config()
    s = 128
    feats = [
        torch.rand(1, 1, s // feat_size, s // feat_size)
        for feat_size in [4, 8, 16, 32, 64]
    ]
    ort_validate(fcos_model.forward, feats)


def test_fcos_head_onnx_export():
    """Test fcos head get_bboxes() in ort."""
    fcos_model = fcos_config()
    s = 128
    img_metas = [{
        'img_shape_for_onnx': torch.Tensor([s, s]),
        'img_shape': (s, s, 3),
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3)
    }]
    cls_scores = [
        torch.rand(1, fcos_model.num_classes, s // feat_size, s // feat_size)
        for feat_size in [4, 8, 16, 32, 64]
    ]
    bboxes = [
        torch.rand(1, 4, s // feat_size, s // feat_size)
        for feat_size in [4, 8, 16, 32, 64]
    ]
    centerness = [
        torch.rand(1, 1, s // feat_size, s // feat_size)
        for feat_size in [4, 8, 16, 32, 64]
    ]

    fcos_model.onnx_export = partial(
        fcos_model.onnx_export, img_metas=img_metas, with_nms=False)
    ort_validate(fcos_model.onnx_export, (cls_scores, bboxes, centerness))
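

# Being anchor-free, FCOS predicts per-location outputs, hence the channel
# counts above: num_classes (4) class scores, 4 distances to the box sides,
# and a single centerness score that down-weights low-quality predictions.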


def fsaf_config():
    """FSAF Head Config."""
    cfg = dict(
        anchor_generator=dict(
            type='AnchorGenerator',
            octave_base_scale=1,
            scales_per_octave=1,
            ratios=[1.0],
            strides=[8, 16, 32, 64, 128]))

    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100))

    model = FSAFHead(num_classes=4, in_channels=1, test_cfg=test_cfg, **cfg)
    model.requires_grad_(False)
    return model


def test_fsaf_head_forward_single():
    """Test FSAF Head forward_single() in torch and onnxruntime env."""
    fsaf_model = fsaf_config()
    feat = torch.rand(1, fsaf_model.in_channels, 32, 32)
    ort_validate(fsaf_model.forward_single, feat)


def test_fsaf_head_forward():
    """Test FSAF Head forward in torch and onnxruntime env."""
    fsaf_model = fsaf_config()
    s = 128
    feats = [
        torch.rand(1, fsaf_model.in_channels, s // (2**(i + 2)),
                   s // (2**(i + 2)))
        for i in range(len(fsaf_model.anchor_generator.strides))
    ]
    ort_validate(fsaf_model.forward, feats)


def test_fsaf_head_onnx_export():
    """Test FSAF Head get_bboxes in torch and onnxruntime env."""
    fsaf_model = fsaf_config()
    s = 256
    img_metas = [{
        'img_shape_for_onnx': torch.Tensor([s, s]),
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 2)
    }]
    # The data of fsaf_head_get_bboxes.pkl contains two parts:
    # cls_score(list(Tensor)) and bboxes(list(Tensor)),
    # where each torch.Tensor is generated by torch.rand().
    # the cls_score's size: (1, 4, 64, 64), (1, 4, 32, 32),
    # (1, 4, 16, 16), (1, 4, 8, 8), (1, 4, 4, 4).
    # the bboxes's size: (1, 4, 64, 64), (1, 4, 32, 32),
    # (1, 4, 16, 16), (1, 4, 8, 8), (1, 4, 4, 4).
    fsaf_head_data = 'fsaf_head_get_bboxes.pkl'
    feats = mmcv.load(osp.join(data_path, fsaf_head_data))
    cls_score = feats[:5]
    bboxes = feats[5:]

    fsaf_model.onnx_export = partial(
        fsaf_model.onnx_export, img_metas=img_metas, with_nms=False)
    ort_validate(fsaf_model.onnx_export, (cls_score, bboxes))
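

# FSAF is effectively an anchor-free variant built on the RetinaNet head,
# which is why the config above collapses to a single anchor per location
# (octave_base_scale=1, scales_per_octave=1, one ratio): each level then
# carries 1 * 4 = 4 cls channels and 1 * 4 = 4 bbox channels, matching the
# fixture shapes listed above.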


def ssd_config():
    """SSD Head Config."""
    cfg = dict(
        anchor_generator=dict(
            type='SSDAnchorGenerator',
            scale_major=False,
            input_size=300,
            basesize_ratio_range=(0.15, 0.9),
            strides=[8, 16, 32, 64, 100, 300],
            ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]]),
        bbox_coder=dict(
            type='DeltaXYWHBBoxCoder',
            target_means=[.0, .0, .0, .0],
            target_stds=[0.1, 0.1, 0.2, 0.2]))

    test_cfg = mmcv.Config(
        dict(
            deploy_nms_pre=0,
            nms=dict(type='nms', iou_threshold=0.45),
            min_bbox_size=0,
            score_thr=0.02,
            max_per_img=200))

    model = SSDHead(
        num_classes=4,
        in_channels=(4, 8, 4, 2, 2, 2),
        test_cfg=test_cfg,
        **cfg)
    model.requires_grad_(False)
    return model


def test_ssd_head_forward():
    """Test SSD Head forward in torch and onnxruntime env."""
    ssd_model = ssd_config()
    # Standard SSD300 feature-map sizes, one per head level
    featmap_size = [38, 19, 10, 5, 3, 1]
    feats = [
        torch.rand(1, ssd_model.in_channels[i], featmap_size[i],
                   featmap_size[i]) for i in range(len(ssd_model.in_channels))
    ]
    ort_validate(ssd_model.forward, feats)


def test_ssd_head_onnx_export():
    """Test SSD Head get_bboxes in torch and onnxruntime env."""
    ssd_model = ssd_config()
    s = 300
    img_metas = [{
        'img_shape_for_onnx': torch.Tensor([s, s]),
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 2)
    }]
    # The data of ssd_head_get_bboxes.pkl contains two parts:
    # cls_score(list(Tensor)) and bboxes(list(Tensor)),
    # where each torch.Tensor is generated by torch.rand().
    # the cls_score's size: (1, 20, 38, 38), (1, 30, 19, 19),
    # (1, 30, 10, 10), (1, 30, 5, 5), (1, 20, 3, 3), (1, 20, 1, 1).
    # the bboxes's size: (1, 16, 38, 38), (1, 24, 19, 19),
    # (1, 24, 10, 10), (1, 24, 5, 5), (1, 16, 3, 3), (1, 16, 1, 1).
    ssd_head_data = 'ssd_head_get_bboxes.pkl'
    feats = mmcv.load(osp.join(data_path, ssd_head_data))
    cls_score = feats[:6]
    bboxes = feats[6:]

    ssd_model.onnx_export = partial(
        ssd_model.onnx_export, img_metas=img_metas, with_nms=False)
    ort_validate(ssd_model.onnx_export, (cls_score, bboxes))
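

# The fixture channel counts follow from the anchor config above: levels with
# ratios [2] place 4 anchors per location and levels with [2, 3] place 6, and
# SSD scores num_classes + 1 (background) classes, so with num_classes=4 the
# cls maps carry 4 * 5 = 20 or 6 * 5 = 30 channels and the bbox maps carry
# 4 * 4 = 16 or 6 * 4 = 24 channels.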
