You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_async.py 2.6 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Tests for async interface."""
  3. import asyncio
  4. import os
  5. import sys
  6. import asynctest
  7. import mmcv
  8. import torch
  9. from mmdet.apis import async_inference_detector, init_detector
  10. if sys.version_info >= (3, 7):
  11. from mmdet.utils.contextmanagers import concurrent
  12. class AsyncTestCase(asynctest.TestCase):
  13. use_default_loop = False
  14. forbid_get_event_loop = True
  15. TEST_TIMEOUT = int(os.getenv('ASYNCIO_TEST_TIMEOUT', '30'))
  16. def _run_test_method(self, method):
  17. result = method()
  18. if asyncio.iscoroutine(result):
  19. self.loop.run_until_complete(
  20. asyncio.wait_for(result, timeout=self.TEST_TIMEOUT))
  21. class MaskRCNNDetector:
  22. def __init__(self,
  23. model_config,
  24. checkpoint=None,
  25. streamqueue_size=3,
  26. device='cuda:0'):
  27. self.streamqueue_size = streamqueue_size
  28. self.device = device
  29. # build the model and load checkpoint
  30. self.model = init_detector(
  31. model_config, checkpoint=None, device=self.device)
  32. self.streamqueue = None
  33. async def init(self):
  34. self.streamqueue = asyncio.Queue()
  35. for _ in range(self.streamqueue_size):
  36. stream = torch.cuda.Stream(device=self.device)
  37. self.streamqueue.put_nowait(stream)
  38. if sys.version_info >= (3, 7):
  39. async def apredict(self, img):
  40. if isinstance(img, str):
  41. img = mmcv.imread(img)
  42. async with concurrent(self.streamqueue):
  43. result = await async_inference_detector(self.model, img)
  44. return result
  45. class AsyncInferenceTestCase(AsyncTestCase):
  46. if sys.version_info >= (3, 7):
  47. async def test_simple_inference(self):
  48. if not torch.cuda.is_available():
  49. import pytest
  50. pytest.skip('test requires GPU and torch+cuda')
  51. ori_grad_enabled = torch.is_grad_enabled()
  52. root_dir = os.path.dirname(os.path.dirname(__name__))
  53. model_config = os.path.join(
  54. root_dir, 'configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py')
  55. detector = MaskRCNNDetector(model_config)
  56. await detector.init()
  57. img_path = os.path.join(root_dir, 'demo/demo.jpg')
  58. bboxes, _ = await detector.apredict(img_path)
  59. self.assertTrue(bboxes)
  60. # asy inference detector will hack grad_enabled,
  61. # so restore here to avoid it to influence other tests
  62. torch.set_grad_enabled(ori_grad_enabled)

No Description

Contributors (3)