You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

inference.py 8.9 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import warnings
  3. import mmcv
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from mmcv.ops import RoIPool
  8. from mmcv.parallel import collate, scatter
  9. from mmcv.runner import load_checkpoint
  10. from mmdet.core import get_classes
  11. from mmdet.datasets import replace_ImageToTensor
  12. from mmdet.datasets.pipelines import Compose
  13. from mmdet.models import build_detector
  14. def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
  15. """Initialize a detector from config file.
  16. Args:
  17. config (str or :obj:`mmcv.Config`): Config file path or the config
  18. object.
  19. checkpoint (str, optional): Checkpoint path. If left as None, the model
  20. will not load any weights.
  21. cfg_options (dict): Options to override some settings in the used
  22. config.
  23. Returns:
  24. nn.Module: The constructed detector.
  25. """
  26. if isinstance(config, str):
  27. config = mmcv.Config.fromfile(config)
  28. elif not isinstance(config, mmcv.Config):
  29. raise TypeError('config must be a filename or Config object, '
  30. f'but got {type(config)}')
  31. if cfg_options is not None:
  32. config.merge_from_dict(cfg_options)
  33. config.model.pretrained = None
  34. config.model.train_cfg = None
  35. model = build_detector(config.model, test_cfg=config.get('test_cfg'))
  36. if checkpoint is not None:
  37. checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
  38. if 'CLASSES' in checkpoint.get('meta', {}):
  39. model.CLASSES = checkpoint['meta']['CLASSES']
  40. else:
  41. warnings.simplefilter('once')
  42. warnings.warn('Class names are not saved in the checkpoint\'s '
  43. 'meta data, use COCO classes by default.')
  44. model.CLASSES = get_classes('coco')
  45. model.cfg = config # save the config in the model for convenience
  46. model.to(device)
  47. model.eval()
  48. return model
  49. class LoadImage:
  50. """Deprecated.
  51. A simple pipeline to load image.
  52. """
  53. def __call__(self, results):
  54. """Call function to load images into results.
  55. Args:
  56. results (dict): A result dict contains the file name
  57. of the image to be read.
  58. Returns:
  59. dict: ``results`` will be returned containing loaded image.
  60. """
  61. warnings.simplefilter('once')
  62. warnings.warn('`LoadImage` is deprecated and will be removed in '
  63. 'future releases. You may use `LoadImageFromWebcam` '
  64. 'from `mmdet.datasets.pipelines.` instead.')
  65. if isinstance(results['img'], str):
  66. results['filename'] = results['img']
  67. results['ori_filename'] = results['img']
  68. else:
  69. results['filename'] = None
  70. results['ori_filename'] = None
  71. img = mmcv.imread(results['img'])
  72. results['img'] = img
  73. results['img_fields'] = ['img']
  74. results['img_shape'] = img.shape
  75. results['ori_shape'] = img.shape
  76. return results
  77. def inference_detector(model, imgs, feat=False):
  78. """Inference image(s) with the detector.
  79. Args:
  80. model (nn.Module): The loaded detector.
  81. imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
  82. Either image files or loaded images.
  83. Returns:
  84. If imgs is a list or tuple, the same length list type results
  85. will be returned, otherwise return the detection results directly.
  86. """
  87. if isinstance(imgs, (list, tuple)):
  88. is_batch = True
  89. else:
  90. imgs = [imgs]
  91. is_batch = False
  92. cfg = model.cfg
  93. device = next(model.parameters()).device # model device
  94. if isinstance(imgs[0], np.ndarray):
  95. cfg = cfg.copy()
  96. # set loading pipeline type
  97. cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
  98. cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
  99. test_pipeline = Compose(cfg.data.test.pipeline)
  100. datas = []
  101. for img in imgs:
  102. # prepare data
  103. if isinstance(img, np.ndarray):
  104. # directly add img
  105. data = dict(img=img)
  106. else:
  107. # add information into dict
  108. data = dict(img_info=dict(filename=img), img_prefix=None)
  109. # build the data pipeline
  110. data = test_pipeline(data)
  111. datas.append(data)
  112. data = collate(datas, samples_per_gpu=len(imgs))
  113. # just get the actual data from DataContainer
  114. data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
  115. data['img'] = [img.data[0] for img in data['img']]
  116. if next(model.parameters()).is_cuda:
  117. # scatter to specified GPU
  118. data = scatter(data, [device])[0]
  119. else:
  120. for m in model.modules():
  121. assert not isinstance(
  122. m, RoIPool
  123. ), 'CPU inference with RoIPool is not supported currently.'
  124. # forward the model
  125. if feat:
  126. with torch.no_grad():
  127. results = model(return_loss=False, rescale=True, **data)
  128. feat = model.extract_feat(data['img'][0])
  129. feat_0 = F.adaptive_avg_pool2d(feat[0], 1).flatten(1)
  130. feat_1 = F.adaptive_avg_pool2d(feat[1], 1).flatten(1)
  131. feat_2 = F.adaptive_avg_pool2d(feat[2], 1).flatten(1)
  132. feat_3 = F.adaptive_avg_pool2d(feat[3], 1).flatten(1)
  133. feat_4 = F.adaptive_avg_pool2d(feat[4], 1).flatten(1)
  134. feat_final = torch.cat((feat_0, feat_1, feat_2, feat_3, feat_4), 1)
  135. feat_final = feat_final.cpu().numpy()
  136. results_feat = []
  137. for i in range(feat_final.shape[0]):
  138. results_feat.append(feat_final[i].tolist())
  139. if not is_batch:
  140. return results[0], results_feat[0]
  141. else:
  142. return results, results_feat
  143. else:
  144. with torch.no_grad():
  145. results = model(return_loss=False, rescale=True, **data)
  146. if not is_batch:
  147. return results[0]
  148. else:
  149. return results
  150. async def async_inference_detector(model, imgs):
  151. """Async inference image(s) with the detector.
  152. Args:
  153. model (nn.Module): The loaded detector.
  154. img (str | ndarray): Either image files or loaded images.
  155. Returns:
  156. Awaitable detection results.
  157. """
  158. if not isinstance(imgs, (list, tuple)):
  159. imgs = [imgs]
  160. cfg = model.cfg
  161. device = next(model.parameters()).device # model device
  162. if isinstance(imgs[0], np.ndarray):
  163. cfg = cfg.copy()
  164. # set loading pipeline type
  165. cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
  166. cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
  167. test_pipeline = Compose(cfg.data.test.pipeline)
  168. datas = []
  169. for img in imgs:
  170. # prepare data
  171. if isinstance(img, np.ndarray):
  172. # directly add img
  173. data = dict(img=img)
  174. else:
  175. # add information into dict
  176. data = dict(img_info=dict(filename=img), img_prefix=None)
  177. # build the data pipeline
  178. data = test_pipeline(data)
  179. datas.append(data)
  180. data = collate(datas, samples_per_gpu=len(imgs))
  181. # just get the actual data from DataContainer
  182. data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
  183. data['img'] = [img.data[0] for img in data['img']]
  184. if next(model.parameters()).is_cuda:
  185. # scatter to specified GPU
  186. data = scatter(data, [device])[0]
  187. else:
  188. for m in model.modules():
  189. assert not isinstance(
  190. m, RoIPool
  191. ), 'CPU inference with RoIPool is not supported currently.'
  192. # We don't restore `torch.is_grad_enabled()` value during concurrent
  193. # inference since execution can overlap
  194. torch.set_grad_enabled(False)
  195. results = await model.aforward_test(rescale=True, **data)
  196. return results
  197. def show_result_pyplot(model,
  198. img,
  199. result,
  200. score_thr=0.3,
  201. title='result',
  202. wait_time=0):
  203. """Visualize the detection results on the image.
  204. Args:
  205. model (nn.Module): The loaded detector.
  206. img (str or np.ndarray): Image filename or loaded image.
  207. result (tuple[list] or list): The detection result, can be either
  208. (bbox, segm) or just bbox.
  209. score_thr (float): The threshold to visualize the bboxes and masks.
  210. title (str): Title of the pyplot figure.
  211. wait_time (float): Value of waitKey param.
  212. Default: 0.
  213. """
  214. if hasattr(model, 'module'):
  215. model = model.module
  216. model.show_result(
  217. img,
  218. result,
  219. score_thr=score_thr,
  220. show=True,
  221. wait_time=wait_time,
  222. win_name=title,
  223. bbox_color=(72, 101, 241),
  224. text_color=(72, 101, 241))

No Description

Contributors (3)