
test.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import mmcv
from mmcv import Config, DictAction
from mmcv.parallel import MMDataParallel

from mmdet.apis import single_gpu_test
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) an exported model using the '
        'ONNXRuntime or TensorRT backend')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='input model file')
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--backend',
        required=True,
        choices=['onnxruntime', 'tensorrt'],
        help='Backend for the input model to run on.')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depend on the dataset, e.g., "bbox", '
        '"segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--show-score-thr',
        type=float,
        default=0.3,
        help='score threshold (default: 0.3)')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or '
        'key=a,b. It also allows nested list/tuple values, e.g. '
        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
        'and that no white space is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation; the key-value pairs in xxx=yyy '
        'format will be kwargs for the dataset.evaluate() function')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results) with the argument "--out", "--eval", "--format-only", '
         '"--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot both be specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # wrap the exported model so it behaves like a regular MMDet detector
    if args.backend == 'onnxruntime':
        from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
        model = ONNXRuntimeDetector(
            args.model, class_names=dataset.CLASSES, device_id=0)
    elif args.backend == 'tensorrt':
        from mmdet.core.export.model_wrappers import TensorRTDetector
        model = TensorRTDetector(
            args.model, class_names=dataset.CLASSES, device_id=0)

    model = MMDataParallel(model, device_ids=[0])
    outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                              args.show_score_thr)

    if args.out:
        print(f'\nwriting results to {args.out}')
        mmcv.dump(outputs, args.out)
    kwargs = {} if args.eval_options is None else args.eval_options
    if args.format_only:
        dataset.format_results(outputs, **kwargs)
    if args.eval:
        eval_kwargs = cfg.get('evaluation', {}).copy()
        # hard-coded way to remove EvalHook args
        for key in [
                'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                'rule'
        ]:
            eval_kwargs.pop(key, None)
        eval_kwargs.update(dict(metric=args.eval, **kwargs))
        print(dataset.evaluate(outputs, **eval_kwargs))


if __name__ == '__main__':
    main()
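
Usage note: based on the argparse definitions above, the script takes an MMDetection test config, an exported model file, and a required --backend flag, plus optional evaluation/output flags. The paths and metric below are placeholders, not files tracked in this repository:

    python test.py path/to/test_config.py path/to/model.onnx \
        --backend onnxruntime --eval bbox --out results.pkl

With --backend tensorrt, the same command applies but the model argument should point to the TensorRT-exported model file, and the TensorRTDetector wrapper is used instead of ONNXRuntimeDetector.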
