You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

async_benchmark.py 3.2 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
import shutil
import urllib
import urllib.request

import mmcv
import torch

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector)
from mmdet.utils.contextmanagers import concurrent
from mmdet.utils.profiling import profile_time
async def main():
    """Benchmark between async and synchronous inference interfaces.

    Sample runs for 20 demo images on K80 GPU, model - mask_rcnn_r50_fpn_1x:

    async       sync

    7981.79 ms  9660.82 ms
    8074.52 ms  9660.94 ms
    7976.44 ms  9406.83 ms

    Async variant takes about 0.83-0.85 of the time of the synchronous
    interface.
    """
    # Resolve the repository root relative to this file so the script works
    # regardless of the current working directory.
    project_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
    project_dir = os.path.join(project_dir, '..')

    config_file = os.path.join(
        project_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
    checkpoint_file = os.path.join(
        project_dir,
        'checkpoints/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth')
    if not os.path.exists(checkpoint_file):
        # urlretrieve with no target path downloads to a temp file; moving it
        # into place afterwards avoids leaving a partial file at the final
        # checkpoint path if the download is interrupted.
        url = ('https://download.openmmlab.com/mmdetection/v2.0'
               '/mask_rcnn/mask_rcnn_r50_fpn_1x_coco'
               '/mask_rcnn_r50_fpn_1x_coco_20200205-d4b0c5d6.pth')
        print(f'Downloading {url} ...')
        local_filename, _ = urllib.request.urlretrieve(url)
        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
        shutil.move(local_filename, checkpoint_file)
        print(f'Saved as {checkpoint_file}')
    else:
        print(f'Using existing checkpoint {checkpoint_file}')

    device = 'cuda:0'
    model = init_detector(
        config_file, checkpoint=checkpoint_file, device=device)

    # queue is used for concurrent inference of multiple images
    streamqueue = asyncio.Queue()
    # queue size defines concurrency level
    streamqueue_size = 4

    for _ in range(streamqueue_size):
        streamqueue.put_nowait(torch.cuda.Stream(device=device))

    # test a single image and show the results
    img = mmcv.imread(os.path.join(project_dir, 'demo/demo.jpg'))

    # warmup
    await async_inference_detector(model, img)

    async def detect(img):
        # Borrow one CUDA stream from the queue for the duration of a single
        # inference; `concurrent` returns it to the queue on exit, capping
        # in-flight inferences at streamqueue_size.
        async with concurrent(streamqueue):
            return await async_inference_detector(model, img)

    num_of_images = 20

    # Async benchmark: all images submitted at once, gathered together.
    with profile_time('benchmark', 'async'):
        tasks = [
            asyncio.create_task(detect(img)) for _ in range(num_of_images)
        ]
        async_results = await asyncio.gather(*tasks)

    # Sync benchmark: one-by-one inference pinned to the default CUDA stream.
    with torch.cuda.stream(torch.cuda.default_stream()):
        with profile_time('benchmark', 'sync'):
            sync_results = [
                inference_detector(model, img) for _ in range(num_of_images)
            ]

    # Render the first result from each run side by side for a sanity check.
    result_dir = os.path.join(project_dir, 'demo')
    model.show_result(
        img,
        async_results[0],
        score_thr=0.5,
        show=False,
        out_file=os.path.join(result_dir, 'result_async.jpg'))
    model.show_result(
        img,
        sync_results[0],
        score_thr=0.5,
        show=False,
        out_file=os.path.join(result_dir, 'result_sync.jpg'))
if __name__ == '__main__':
    # Script entry point: run the benchmark coroutine on a fresh event loop.
    asyncio.run(main())

No Description

Contributors (3)