You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

list_data.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import itertools
  3. from collections.abc import Sized
  4. from typing import Any, List, Union
  5. import numpy as np
  6. import torch
  7. from ..utils import flatten as flatten_list
  8. from ..utils import to_hashable
  9. from .base_data_element import BaseDataElement
  # Tensor index types; the CPU and CUDA variants are both listed so that
# ``isinstance`` checks accept tensors on either device.
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

# Every kind of index accepted by ``ListData.__getitem__``. The nested
# tensor unions are flattened by ``typing``, so ``IndexType.__args__`` is a
# flat tuple that can be passed straight to ``isinstance``.
IndexType = Union[
    str,
    slice,
    int,
    list,
    LongTypeTensor,
    BoolTypeTensor,
    np.ndarray,
]
  # Modified from https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py  # noqa
class ListData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in ``data_fields`` are
    expected to have the same length; note that ``__setattr__`` does not
    enforce this. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501

    ListData also supports extra operations on its data fields: indexing,
    slicing and ``cat``. A value in a data field can be a base data structure
    such as ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` or
    ``tuple``. It can also be a customized data structure that has
    ``__len__``, ``__getitem__`` and ``cat`` attributes. ``None`` field values
    stay ``None`` when indexed. They also stay ``None`` under ``cat`` when
    every input holds ``None`` for that field.

    Examples:
        >>> # custom data structure
        >>> import itertools
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> # ``ListData`` is the class defined in this module.
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = ListData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <ListData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <ListData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <ListData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <ListData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Any):
        """Set a data field, or initialize one of the private field registries.

        ``_metainfo_fields`` and ``_data_fields`` may each be assigned only
        once (while the attribute does not exist yet). Any later assignment
        raises :class:`AttributeError`. Every other name is forwarded to the
        parent class's ``__setattr__``.

        Note:
            The value is NOT checked to be a list, nor to have the same
            length as the fields already stored. Callers are responsible for
            keeping field lengths consistent.

        Args:
            name (str): Attribute / field name.
            value (Any): Value to store.

        Raises:
            AttributeError: If ``name`` is a private field registry that has
                already been set.
        """
        if name in ("_metainfo_fields", "_data_fields"):
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(
                    f"{name} has been used as a " "private attribute, which is immutable."
                )
        else:
            super().__setattr__(name, value)

    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> "ListData":
        """Index every data field at once.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Selects the corresponding values.

                - ``str``: returns the raw value of that field.
                - ``int`` / ``slice``: forwarded unchanged to every non-None
                  field. An ``int`` therefore yields each field's bare
                  element rather than a length-1 container.
                - list / ndarray / 1-D tensor of integer positions or a
                  boolean mask: selects elements along the first dimension
                  of every field.

        Returns:
            :obj:`ListData`: A new instance with the same metainfo and the
            selected values. When ``item`` is a ``str``, the raw field value
            is returned instead.

        Raises:
            ValueError: If a field's type cannot be indexed by a tensor
                (it is not a tensor/array/str/list/tuple and lacks ``cat``).
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent (int32 on
            # windows, int64 on linux), and `torch.Tensor` wants int64
            # indices. Any other integer dtype is unsafe too: torch rejects
            # some of them and treats uint8 as a mask, whereas list fields
            # treat the same array as positions. An empty list becomes a
            # float64 array, which neither torch nor numpy accept as an
            # index. Normalize all of these cases to int64.
            # More details in https://github.com/numpy/numpy/issues/9464
            if item.size == 0 or np.issubdtype(item.dtype, np.integer):
                item = item.astype(np.int64)
            item = torch.from_numpy(item)
        if isinstance(item, str):
            return getattr(self, item)

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, "Only support to get the" " values along the first dimension."
            for k, v in self.items():
                if v is None:
                    new_data[k] = None
                elif isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(v, (str, list, tuple)) or (
                    hasattr(v, "__getitem__") and hasattr(v, "cat")
                ):
                    # Sequence-like fields cannot take a tensor index, so
                    # turn the index into plain Python positions first.
                    if isinstance(item, BoolTypeTensor.__args__):
                        indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()
                    # ``v[i::len(v)]`` selects exactly element ``i`` (also for
                    # negative ``i``) while keeping the container type. An
                    # empty selection uses ``v[:0]``.
                    if indexes:
                        slice_list = [slice(index, None, len(v)) for index in indexes]
                    else:
                        slice_list = [slice(None, 0, None)]
                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f"The type of `{k}` is `{type(v)}`, which has no "
                        "attribute of `cat`, so it does not "
                        "support slice with `bool`"
                    )
        else:
            # item is a slice or int: forward it to every field as-is.
            for k, v in self.items():
                new_data[k] = None if v is None else v[item]
        return new_data  # type:ignore

    @staticmethod
    def cat(instances_list: List["ListData"]) -> "ListData":
        """Concatenate the fields of all :obj:`ListData` in the list.

        Note: To ensure that cat returns as expected, make sure that all
        elements in the list have exactly the same keys.

        Args:
            instances_list (list[:obj:`ListData`]): A non-empty list of
                :obj:`ListData`.

        Returns:
            :obj:`ListData`: The single element itself when the list has
            length 1. Otherwise, a new instance that carries the first
            element's metainfo and the concatenated data fields. A field
            that is ``None`` in every input stays ``None``.

        Raises:
            ValueError: If a field's type has no known way to be
                concatenated.
        """
        assert all(isinstance(results, ListData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        field_keys_list = [instances.all_keys() for instances in instances_list]
        assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len(
            set(itertools.chain(*field_keys_list))
        ) == len(field_keys_list[0]), (
            "There are different keys in "
            "`instances_list`, which may "
            "cause the cat operation "
            "to fail. Please make sure all "
            "elements in `instances_list` "
            "have the exact same key."
        )

        new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            v0 = values[0]
            if all(v is None for v in values):
                # Mirror ``__getitem__``, which keeps ``None`` fields as None.
                new_values = None
            elif isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Copy first so the caller's list is never mutated by ``+=``.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, "cat"):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`"
                )
            new_data[k] = new_values
        return new_data  # type:ignore

    def flatten(self, item: IndexType) -> List:
        """Flatten ``self[item]`` with the ``..utils.flatten`` helper.

        Args:
            item (IndexType): Index forwarded to :meth:`__getitem__`. A
                ``str`` selects one raw field value.

        Returns:
            list: Flattened data.
        """
        return flatten_list(self[item])

    def elements_num(self, item: IndexType) -> int:
        """int: The number of elements in ``self.flatten(item)``."""
        return len(self.flatten(item))

    def to_tuple(self, item: IndexType) -> tuple:
        """tuple: ``self[item]`` converted by the ``..utils.to_hashable`` helper."""
        return to_hashable(self[item])

    def __len__(self) -> int:
        """int: Length of one data field (the first yielded by iterating
        ``_data_fields``), or 0 when there are no data fields.

        Assumes all fields share the same length, which is not enforced.
        """
        if len(self._data_fields) > 0:
            one_element = next(iter(self._data_fields))
            return len(getattr(self, one_element))
        return 0

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.