You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

loading.py 20 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import os.path as osp
  3. import mmcv
  4. import numpy as np
  5. import pycocotools.mask as maskUtils
  6. from mmdet.core import BitmapMasks, PolygonMasks
  7. from ..builder import PIPELINES
  8. try:
  9. from panopticapi.utils import rgb2id
  10. except ImportError:
  11. rgb2id = None
  12. @PIPELINES.register_module()
  13. class LoadImageFromFile:
  14. """Load an image from file.
  15. Required keys are "img_prefix" and "img_info" (a dict that must contain the
  16. key "filename"). Added or updated keys are "filename", "img", "img_shape",
  17. "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
  18. "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
  19. Args:
  20. to_float32 (bool): Whether to convert the loaded image to a float32
  21. numpy array. If set to False, the loaded image is an uint8 array.
  22. Defaults to False.
  23. color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
  24. Defaults to 'color'.
  25. file_client_args (dict): Arguments to instantiate a FileClient.
  26. See :class:`mmcv.fileio.FileClient` for details.
  27. Defaults to ``dict(backend='disk')``.
  28. """
  29. def __init__(self,
  30. to_float32=False,
  31. color_type='color',
  32. file_client_args=dict(backend='disk')):
  33. self.to_float32 = to_float32
  34. self.color_type = color_type
  35. self.file_client_args = file_client_args.copy()
  36. self.file_client = None
  37. def __call__(self, results):
  38. """Call functions to load image and get image meta information.
  39. Args:
  40. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  41. Returns:
  42. dict: The dict contains loaded image and meta information.
  43. """
  44. if self.file_client is None:
  45. self.file_client = mmcv.FileClient(**self.file_client_args)
  46. if results['img_prefix'] is not None:
  47. filename = osp.join(results['img_prefix'],
  48. results['img_info']['filename'])
  49. else:
  50. filename = results['img_info']['filename']
  51. img_bytes = self.file_client.get(filename)
  52. img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
  53. if self.to_float32:
  54. img = img.astype(np.float32)
  55. results['filename'] = filename
  56. results['ori_filename'] = results['img_info']['filename']
  57. results['img'] = img
  58. results['img_shape'] = img.shape
  59. results['ori_shape'] = img.shape
  60. results['img_fields'] = ['img']
  61. return results
  62. def __repr__(self):
  63. repr_str = (f'{self.__class__.__name__}('
  64. f'to_float32={self.to_float32}, '
  65. f"color_type='{self.color_type}', "
  66. f'file_client_args={self.file_client_args})')
  67. return repr_str
  68. @PIPELINES.register_module()
  69. class LoadImageFromWebcam(LoadImageFromFile):
  70. """Load an image from webcam.
  71. Similar with :obj:`LoadImageFromFile`, but the image read from webcam is in
  72. ``results['img']``.
  73. """
  74. def __call__(self, results):
  75. """Call functions to add image meta information.
  76. Args:
  77. results (dict): Result dict with Webcam read image in
  78. ``results['img']``.
  79. Returns:
  80. dict: The dict contains loaded image and meta information.
  81. """
  82. img = results['img']
  83. if self.to_float32:
  84. img = img.astype(np.float32)
  85. results['filename'] = None
  86. results['ori_filename'] = None
  87. results['img'] = img
  88. results['img_shape'] = img.shape
  89. results['ori_shape'] = img.shape
  90. results['img_fields'] = ['img']
  91. return results
  92. @PIPELINES.register_module()
  93. class LoadMultiChannelImageFromFiles:
  94. """Load multi-channel images from a list of separate channel files.
  95. Required keys are "img_prefix" and "img_info" (a dict that must contain the
  96. key "filename", which is expected to be a list of filenames).
  97. Added or updated keys are "filename", "img", "img_shape",
  98. "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
  99. "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
  100. Args:
  101. to_float32 (bool): Whether to convert the loaded image to a float32
  102. numpy array. If set to False, the loaded image is an uint8 array.
  103. Defaults to False.
  104. color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
  105. Defaults to 'color'.
  106. file_client_args (dict): Arguments to instantiate a FileClient.
  107. See :class:`mmcv.fileio.FileClient` for details.
  108. Defaults to ``dict(backend='disk')``.
  109. """
  110. def __init__(self,
  111. to_float32=False,
  112. color_type='unchanged',
  113. file_client_args=dict(backend='disk')):
  114. self.to_float32 = to_float32
  115. self.color_type = color_type
  116. self.file_client_args = file_client_args.copy()
  117. self.file_client = None
  118. def __call__(self, results):
  119. """Call functions to load multiple images and get images meta
  120. information.
  121. Args:
  122. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  123. Returns:
  124. dict: The dict contains loaded images and meta information.
  125. """
  126. if self.file_client is None:
  127. self.file_client = mmcv.FileClient(**self.file_client_args)
  128. if results['img_prefix'] is not None:
  129. filename = [
  130. osp.join(results['img_prefix'], fname)
  131. for fname in results['img_info']['filename']
  132. ]
  133. else:
  134. filename = results['img_info']['filename']
  135. img = []
  136. for name in filename:
  137. img_bytes = self.file_client.get(name)
  138. img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
  139. img = np.stack(img, axis=-1)
  140. if self.to_float32:
  141. img = img.astype(np.float32)
  142. results['filename'] = filename
  143. results['ori_filename'] = results['img_info']['filename']
  144. results['img'] = img
  145. results['img_shape'] = img.shape
  146. results['ori_shape'] = img.shape
  147. # Set initial values for default meta_keys
  148. results['pad_shape'] = img.shape
  149. results['scale_factor'] = 1.0
  150. num_channels = 1 if len(img.shape) < 3 else img.shape[2]
  151. results['img_norm_cfg'] = dict(
  152. mean=np.zeros(num_channels, dtype=np.float32),
  153. std=np.ones(num_channels, dtype=np.float32),
  154. to_rgb=False)
  155. return results
  156. def __repr__(self):
  157. repr_str = (f'{self.__class__.__name__}('
  158. f'to_float32={self.to_float32}, '
  159. f"color_type='{self.color_type}', "
  160. f'file_client_args={self.file_client_args})')
  161. return repr_str
  162. @PIPELINES.register_module()
  163. class LoadAnnotations:
  164. """Load multiple types of annotations.
  165. Args:
  166. with_bbox (bool): Whether to parse and load the bbox annotation.
  167. Default: True.
  168. with_label (bool): Whether to parse and load the label annotation.
  169. Default: True.
  170. with_mask (bool): Whether to parse and load the mask annotation.
  171. Default: False.
  172. with_seg (bool): Whether to parse and load the semantic segmentation
  173. annotation. Default: False.
  174. poly2mask (bool): Whether to convert the instance masks from polygons
  175. to bitmaps. Default: True.
  176. file_client_args (dict): Arguments to instantiate a FileClient.
  177. See :class:`mmcv.fileio.FileClient` for details.
  178. Defaults to ``dict(backend='disk')``.
  179. """
  180. def __init__(self,
  181. with_bbox=True,
  182. with_label=True,
  183. with_mask=False,
  184. with_seg=False,
  185. poly2mask=True,
  186. file_client_args=dict(backend='disk')):
  187. self.with_bbox = with_bbox
  188. self.with_label = with_label
  189. self.with_mask = with_mask
  190. self.with_seg = with_seg
  191. self.poly2mask = poly2mask
  192. self.file_client_args = file_client_args.copy()
  193. self.file_client = None
  194. def _load_bboxes(self, results):
  195. """Private function to load bounding box annotations.
  196. Args:
  197. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  198. Returns:
  199. dict: The dict contains loaded bounding box annotations.
  200. """
  201. ann_info = results['ann_info']
  202. results['gt_bboxes'] = ann_info['bboxes'].copy()
  203. gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
  204. if gt_bboxes_ignore is not None:
  205. results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
  206. results['bbox_fields'].append('gt_bboxes_ignore')
  207. results['bbox_fields'].append('gt_bboxes')
  208. return results
  209. def _load_labels(self, results):
  210. """Private function to load label annotations.
  211. Args:
  212. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  213. Returns:
  214. dict: The dict contains loaded label annotations.
  215. """
  216. results['gt_labels'] = results['ann_info']['labels'].copy()
  217. return results
  218. def _poly2mask(self, mask_ann, img_h, img_w):
  219. """Private function to convert masks represented with polygon to
  220. bitmaps.
  221. Args:
  222. mask_ann (list | dict): Polygon mask annotation input.
  223. img_h (int): The height of output mask.
  224. img_w (int): The width of output mask.
  225. Returns:
  226. numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
  227. """
  228. if isinstance(mask_ann, list):
  229. # polygon -- a single object might consist of multiple parts
  230. # we merge all parts into one mask rle code
  231. rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
  232. rle = maskUtils.merge(rles)
  233. elif isinstance(mask_ann['counts'], list):
  234. # uncompressed RLE
  235. rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
  236. else:
  237. # rle
  238. rle = mask_ann
  239. mask = maskUtils.decode(rle)
  240. return mask
  241. def process_polygons(self, polygons):
  242. """Convert polygons to list of ndarray and filter invalid polygons.
  243. Args:
  244. polygons (list[list]): Polygons of one instance.
  245. Returns:
  246. list[numpy.ndarray]: Processed polygons.
  247. """
  248. polygons = [np.array(p) for p in polygons]
  249. valid_polygons = []
  250. for polygon in polygons:
  251. if len(polygon) % 2 == 0 and len(polygon) >= 6:
  252. valid_polygons.append(polygon)
  253. return valid_polygons
  254. def _load_masks(self, results):
  255. """Private function to load mask annotations.
  256. Args:
  257. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  258. Returns:
  259. dict: The dict contains loaded mask annotations.
  260. If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
  261. :obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
  262. """
  263. h, w = results['img_info']['height'], results['img_info']['width']
  264. gt_masks = results['ann_info']['masks']
  265. if self.poly2mask:
  266. gt_masks = BitmapMasks(
  267. [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
  268. else:
  269. gt_masks = PolygonMasks(
  270. [self.process_polygons(polygons) for polygons in gt_masks], h,
  271. w)
  272. results['gt_masks'] = gt_masks
  273. results['mask_fields'].append('gt_masks')
  274. return results
  275. def _load_semantic_seg(self, results):
  276. """Private function to load semantic segmentation annotations.
  277. Args:
  278. results (dict): Result dict from :obj:`dataset`.
  279. Returns:
  280. dict: The dict contains loaded semantic segmentation annotations.
  281. """
  282. if self.file_client is None:
  283. self.file_client = mmcv.FileClient(**self.file_client_args)
  284. filename = osp.join(results['seg_prefix'],
  285. results['ann_info']['seg_map'])
  286. img_bytes = self.file_client.get(filename)
  287. results['gt_semantic_seg'] = mmcv.imfrombytes(
  288. img_bytes, flag='unchanged').squeeze()
  289. results['seg_fields'].append('gt_semantic_seg')
  290. return results
  291. def __call__(self, results):
  292. """Call function to load multiple types annotations.
  293. Args:
  294. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  295. Returns:
  296. dict: The dict contains loaded bounding box, label, mask and
  297. semantic segmentation annotations.
  298. """
  299. if self.with_bbox:
  300. results = self._load_bboxes(results)
  301. if results is None:
  302. return None
  303. if self.with_label:
  304. results = self._load_labels(results)
  305. if self.with_mask:
  306. results = self._load_masks(results)
  307. if self.with_seg:
  308. results = self._load_semantic_seg(results)
  309. return results
  310. def __repr__(self):
  311. repr_str = self.__class__.__name__
  312. repr_str += f'(with_bbox={self.with_bbox}, '
  313. repr_str += f'with_label={self.with_label}, '
  314. repr_str += f'with_mask={self.with_mask}, '
  315. repr_str += f'with_seg={self.with_seg}, '
  316. repr_str += f'poly2mask={self.poly2mask}, '
  317. repr_str += f'poly2mask={self.file_client_args})'
  318. return repr_str
  319. @PIPELINES.register_module()
  320. class LoadPanopticAnnotations(LoadAnnotations):
  321. """Load multiple types of panoptic annotations.
  322. Args:
  323. with_bbox (bool): Whether to parse and load the bbox annotation.
  324. Default: True.
  325. with_label (bool): Whether to parse and load the label annotation.
  326. Default: True.
  327. with_mask (bool): Whether to parse and load the mask annotation.
  328. Default: True.
  329. with_seg (bool): Whether to parse and load the semantic segmentation
  330. annotation. Default: True.
  331. file_client_args (dict): Arguments to instantiate a FileClient.
  332. See :class:`mmcv.fileio.FileClient` for details.
  333. Defaults to ``dict(backend='disk')``.
  334. """
  335. def __init__(self,
  336. with_bbox=True,
  337. with_label=True,
  338. with_mask=True,
  339. with_seg=True,
  340. file_client_args=dict(backend='disk')):
  341. if rgb2id is None:
  342. raise RuntimeError(
  343. 'panopticapi is not installed, please install it by: '
  344. 'pip install git+https://github.com/cocodataset/'
  345. 'panopticapi.git.')
  346. super(LoadPanopticAnnotations,
  347. self).__init__(with_bbox, with_label, with_mask, with_seg, True,
  348. file_client_args)
  349. def _load_masks_and_semantic_segs(self, results):
  350. """Private function to load mask and semantic segmentation annotations.
  351. In gt_semantic_seg, the foreground label is from `0` to
  352. `num_things - 1`, the background label is from `num_things` to
  353. `num_things + num_stuff - 1`, 255 means the ignored label (`VOID`).
  354. Args:
  355. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  356. Returns:
  357. dict: The dict contains loaded mask and semantic segmentation
  358. annotations. `BitmapMasks` is used for mask annotations.
  359. """
  360. if self.file_client is None:
  361. self.file_client = mmcv.FileClient(**self.file_client_args)
  362. filename = osp.join(results['seg_prefix'],
  363. results['ann_info']['seg_map'])
  364. img_bytes = self.file_client.get(filename)
  365. pan_png = mmcv.imfrombytes(
  366. img_bytes, flag='color', channel_order='rgb').squeeze()
  367. pan_png = rgb2id(pan_png)
  368. gt_masks = []
  369. gt_seg = np.zeros_like(pan_png) + 255 # 255 as ignore
  370. for mask_info in results['ann_info']['masks']:
  371. mask = (pan_png == mask_info['id'])
  372. gt_seg = np.where(mask, mask_info['category'], gt_seg)
  373. # The legal thing masks
  374. if mask_info.get('is_thing'):
  375. gt_masks.append(mask.astype(np.uint8))
  376. if self.with_mask:
  377. h, w = results['img_info']['height'], results['img_info']['width']
  378. gt_masks = BitmapMasks(gt_masks, h, w)
  379. results['gt_masks'] = gt_masks
  380. results['mask_fields'].append('gt_masks')
  381. if self.with_seg:
  382. results['gt_semantic_seg'] = gt_seg
  383. results['seg_fields'].append('gt_semantic_seg')
  384. return results
  385. def __call__(self, results):
  386. """Call function to load multiple types panoptic annotations.
  387. Args:
  388. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  389. Returns:
  390. dict: The dict contains loaded bounding box, label, mask and
  391. semantic segmentation annotations.
  392. """
  393. if self.with_bbox:
  394. results = self._load_bboxes(results)
  395. if results is None:
  396. return None
  397. if self.with_label:
  398. results = self._load_labels(results)
  399. if self.with_mask or self.with_seg:
  400. # The tasks completed by '_load_masks' and '_load_semantic_segs'
  401. # in LoadAnnotations are merged to one function.
  402. results = self._load_masks_and_semantic_segs(results)
  403. return results
  404. @PIPELINES.register_module()
  405. class LoadProposals:
  406. """Load proposal pipeline.
  407. Required key is "proposals". Updated keys are "proposals", "bbox_fields".
  408. Args:
  409. num_max_proposals (int, optional): Maximum number of proposals to load.
  410. If not specified, all proposals will be loaded.
  411. """
  412. def __init__(self, num_max_proposals=None):
  413. self.num_max_proposals = num_max_proposals
  414. def __call__(self, results):
  415. """Call function to load proposals from file.
  416. Args:
  417. results (dict): Result dict from :obj:`mmdet.CustomDataset`.
  418. Returns:
  419. dict: The dict contains loaded proposal annotations.
  420. """
  421. proposals = results['proposals']
  422. if proposals.shape[1] not in (4, 5):
  423. raise AssertionError(
  424. 'proposals should have shapes (n, 4) or (n, 5), '
  425. f'but found {proposals.shape}')
  426. proposals = proposals[:, :4]
  427. if self.num_max_proposals is not None:
  428. proposals = proposals[:self.num_max_proposals]
  429. if len(proposals) == 0:
  430. proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
  431. results['proposals'] = proposals
  432. results['bbox_fields'].append('proposals')
  433. return results
  434. def __repr__(self):
  435. return self.__class__.__name__ + \
  436. f'(num_max_proposals={self.num_max_proposals})'
  437. @PIPELINES.register_module()
  438. class FilterAnnotations:
  439. """Filter invalid annotations.
  440. Args:
  441. min_gt_bbox_wh (tuple[int]): Minimum width and height of ground truth
  442. boxes.
  443. """
  444. def __init__(self, min_gt_bbox_wh):
  445. # TODO: add more filter options
  446. self.min_gt_bbox_wh = min_gt_bbox_wh
  447. def __call__(self, results):
  448. assert 'gt_bboxes' in results
  449. gt_bboxes = results['gt_bboxes']
  450. w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
  451. h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
  452. keep = (w > self.min_gt_bbox_wh[0]) & (h > self.min_gt_bbox_wh[1])
  453. if not keep.any():
  454. return None
  455. else:
  456. keys = ('gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg')
  457. for key in keys:
  458. if key in results:
  459. results[key] = results[key][keep]
  460. return results

No Description

Contributors (1)