formating.py

# Copyright (c) OpenMMLab. All rights reserved.
from collections.abc import Sequence

import mmcv
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC

from ..builder import PIPELINES


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """

    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not mmcv.is_str(data):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f'type {type(data)} cannot be converted to tensor.')
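
# Illustrative examples (not part of the original file) showing what each
# branch of `to_tensor` returns; they rely only on the numpy/torch imports
# above:
# >>> to_tensor(np.array([1, 2, 3]))   # -> tensor([1, 2, 3])
# >>> to_tensor(5)                     # -> tensor([5])      (LongTensor)
# >>> to_tensor(0.5)                   # -> tensor([0.5000]) (FloatTensor)
# >>> to_tensor('abc')                 # raises TypeError (str is excluded)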
@PIPELINES.register_module()
class ToTensor:
    """Convert some results to :obj:`torch.Tensor` by given keys.

    Args:
        keys (Sequence[str]): Keys that need to be converted to Tensor.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function to convert data in results to :obj:`torch.Tensor`.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data converted
                to :obj:`torch.Tensor`.
        """
        for key in self.keys:
            results[key] = to_tensor(results[key])
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
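
# Illustrative usage sketch (not part of the original file); the key name
# 'gt_labels' is just an example entry in the results dict:
# >>> transform = ToTensor(keys=['gt_labels'])
# >>> out = transform(dict(gt_labels=np.array([1, 2, 3])))
# >>> isinstance(out['gt_labels'], torch.Tensor)
# True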
@PIPELINES.register_module()
class ImageToTensor:
    """Convert image to :obj:`torch.Tensor` by given keys.

    The dimension order of input image is (H, W, C). The pipeline will convert
    it to (C, H, W). If only 2 dimensions (H, W) are given, the output will be
    (1, H, W).

    Args:
        keys (Sequence[str]): Key of images to be converted to Tensor.
    """

    def __init__(self, keys):
        self.keys = keys

    def __call__(self, results):
        """Call function to convert image in results to :obj:`torch.Tensor`
        and transpose the channel order.

        Args:
            results (dict): Result dict contains the image data to convert.

        Returns:
            dict: The result dict contains the image converted
                to :obj:`torch.Tensor` and transposed to (C, H, W) order.
        """
        for key in self.keys:
            img = results[key]
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            results[key] = (to_tensor(img.transpose(2, 0, 1))).contiguous()
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(keys={self.keys})'
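
# Illustrative usage sketch (not part of the original file); a dummy
# 32x32 RGB image is used to show the (H, W, C) -> (C, H, W) conversion:
# >>> transform = ImageToTensor(keys=['img'])
# >>> out = transform(dict(img=np.zeros((32, 32, 3), dtype=np.float32)))
# >>> out['img'].shape
# torch.Size([3, 32, 32])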
@PIPELINES.register_module()
class Transpose:
    """Transpose some results by given keys.

    Args:
        keys (Sequence[str]): Keys of results to be transposed.
        order (Sequence[int]): Order of transpose.
    """

    def __init__(self, keys, order):
        self.keys = keys
        self.order = order

    def __call__(self, results):
        """Call function to transpose the channel order of data in results.

        Args:
            results (dict): Result dict contains the data to transpose.

        Returns:
            dict: The result dict contains the data transposed to \
                ``self.order``.
        """
        for key in self.keys:
            results[key] = results[key].transpose(self.order)
        return results

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, order={self.order})'
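
# Illustrative usage sketch (not part of the original file); the value must
# support `.transpose(order)`, e.g. a numpy array:
# >>> transform = Transpose(keys=['img'], order=(2, 0, 1))
# >>> out = transform(dict(img=np.zeros((32, 32, 3))))
# >>> out['img'].shape
# (3, 32, 32)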
@PIPELINES.register_module()
class ToDataContainer:
    """Convert results to :obj:`mmcv.DataContainer` by given fields.

    Args:
        fields (Sequence[dict]): Each field is a dict like
            ``dict(key='xxx', **kwargs)``. The ``key`` in result will
            be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
            Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
            dict(key='gt_labels'))``.
    """

    def __init__(self,
                 fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
                         dict(key='gt_labels'))):
        self.fields = fields

    def __call__(self, results):
        """Call function to convert data in results to
        :obj:`mmcv.DataContainer`.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data converted to \
                :obj:`mmcv.DataContainer`.
        """
        for field in self.fields:
            field = field.copy()
            key = field.pop('key')
            results[key] = DC(results[key], **field)
        return results

    def __repr__(self):
        return self.__class__.__name__ + f'(fields={self.fields})'
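
# Illustrative usage sketch (not part of the original file); everything
# except the popped 'key' is forwarded to the DataContainer constructor:
# >>> transform = ToDataContainer(fields=(dict(key='img', stack=True),))
# >>> out = transform(dict(img=torch.zeros(3, 32, 32)))
# >>> out['img']   # a DataContainer wrapping the tensor, with stack=True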
@PIPELINES.register_module()
class DefaultFormatBundle:
    """Default formatting bundle.

    It simplifies the pipeline of formatting common fields, including "img",
    "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
    These fields are formatted as follows.

    - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True)
    - proposals: (1) to tensor, (2) to DataContainer
    - gt_bboxes: (1) to tensor, (2) to DataContainer
    - gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
    - gt_labels: (1) to tensor, (2) to DataContainer
    - gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
    - gt_semantic_seg: (1) unsqueeze dim-0, (2) to tensor, \
        (3) to DataContainer (stack=True)
    """

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The result dict contains the data that is formatted with \
                default bundle.
        """
        if 'img' in results:
            img = results['img']
            # add default meta keys
            results = self._add_default_meta_keys(results)
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            results['img'] = DC(to_tensor(img), stack=True)
        for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]))
        if 'gt_masks' in results:
            results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
        if 'gt_semantic_seg' in results:
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
        return results

    def _add_default_meta_keys(self, results):
        """Add default meta keys.

        We set default meta keys including `pad_shape`, `scale_factor` and
        `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
        `Pad` are implemented during the whole pipeline.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            results (dict): Updated result dict contains the data to convert.
        """
        img = results['img']
        results.setdefault('pad_shape', img.shape)
        results.setdefault('scale_factor', 1.0)
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results.setdefault(
            'img_norm_cfg',
            dict(
                mean=np.zeros(num_channels, dtype=np.float32),
                std=np.ones(num_channels, dtype=np.float32),
                to_rgb=False))
        return results

    def __repr__(self):
        return self.__class__.__name__
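
# Illustrative usage sketch (not part of the original file): with a bare
# (H, W, C) image the bundle wraps 'img' in a stacked DataContainer holding a
# (C, H, W) tensor and fills in the default meta keys ('pad_shape',
# 'scale_factor', 'img_norm_cfg'):
# >>> bundle = DefaultFormatBundle()
# >>> out = bundle(dict(img=np.zeros((32, 32, 3), dtype=np.float32)))
# >>> out['img'].data.shape
# torch.Size([3, 32, 32])
# >>> out['scale_factor']
# 1.0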
@PIPELINES.register_module()
class Collect:
    """Collect data from the loader relevant to the specific task.

    This is usually the last stage of the data loader pipeline. Typically keys
    is set to some subset of "img", "proposals", "gt_bboxes",
    "gt_bboxes_ignore", "gt_labels", and/or "gt_masks".

    The "img_meta" item is always populated. The contents of the "img_meta"
    dictionary depend on "meta_keys". By default this includes:

    - "img_shape": shape of the image input to the network as a tuple \
        (h, w, c). Note that images may be zero padded on the \
        bottom/right if the batch tensor is larger than this shape.
    - "scale_factor": a float indicating the preprocessing scale
    - "flip": a boolean indicating if image flip transform was used
    - "filename": path to the image file
    - "ori_shape": original shape of the image as a tuple (h, w, c)
    - "pad_shape": image shape after padding
    - "img_norm_cfg": a dict of normalization information:

        - mean - per channel mean subtraction
        - std - per channel std divisor
        - to_rgb - bool indicating if bgr was converted to rgb

    Args:
        keys (Sequence[str]): Keys of results to be collected in ``data``.
        meta_keys (Sequence[str], optional): Meta keys to be converted to
            ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
            Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
            'pad_shape', 'scale_factor', 'flip', 'flip_direction',
            'img_norm_cfg')``
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Call function to collect keys in results. The keys in ``meta_keys``
        will be converted to :obj:`mmcv.DataContainer`.

        Args:
            results (dict): Result dict contains the data to collect.

        Returns:
            dict: The result dict contains the following keys

                - keys in ``self.keys``
                - ``img_metas``
        """
        data = {}
        img_meta = {}
        for key in self.meta_keys:
            img_meta[key] = results[key]
        data['img_metas'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]
        return data

    def __repr__(self):
        return self.__class__.__name__ + \
            f'(keys={self.keys}, meta_keys={self.meta_keys})'
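
# Illustrative usage sketch (not part of the original file); a reduced
# meta_keys tuple is used so the example dict only needs the listed entries:
# >>> collect = Collect(keys=['img'], meta_keys=('ori_shape',))
# >>> out = collect(dict(img=torch.zeros(3, 32, 32), ori_shape=(32, 32, 3)))
# >>> sorted(out.keys())
# ['img', 'img_metas']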
@PIPELINES.register_module()
class WrapFieldsToLists:
    """Wrap fields of the data dictionary into lists for evaluation.

    This class can be used as a last step of a test or validation
    pipeline for single image evaluation or inference.

    Example:
        >>> test_pipeline = [
        >>>    dict(type='LoadImageFromFile'),
        >>>    dict(type='Normalize',
                    mean=[123.675, 116.28, 103.53],
                    std=[58.395, 57.12, 57.375],
                    to_rgb=True),
        >>>    dict(type='Pad', size_divisor=32),
        >>>    dict(type='ImageToTensor', keys=['img']),
        >>>    dict(type='Collect', keys=['img']),
        >>>    dict(type='WrapFieldsToLists')
        >>> ]
    """

    def __call__(self, results):
        """Call function to wrap fields into lists.

        Args:
            results (dict): Result dict contains the data to wrap.

        Returns:
            dict: The result dict where each value is wrapped into a list.
        """

        # Wrap dict fields into lists
        for key, val in results.items():
            results[key] = [val]
        return results

    def __repr__(self):
        return f'{self.__class__.__name__}()'
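
# Illustrative usage sketch (not part of the original file): every value in
# the results dict is wrapped in a single-element list:
# >>> wrap = WrapFieldsToLists()
# >>> out = wrap(dict(img=np.zeros((3, 32, 32)), img_shape=(32, 32, 3)))
# >>> out['img_shape']
# [(32, 32, 3)]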
