You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

list_data.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. from typing import List, Union
  3. import numpy as np
  4. import torch
  5. from ...utils import flatten as flatten_list
  6. from ...utils import to_hashable
  7. from .base_data_element import BaseDataElement
  8. BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
  9. LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
  10. IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]
  11. # Modified from
  12. # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
  13. class ListData(BaseDataElement):
  14. """
  15. Abstract Data Interface used throughout the ABL-Package.
  16. ``ListData`` is the underlying data structure used in the ABL-Package,
  17. designed to manage diverse forms of data dynamically generated throughout the
  18. Abductive Learning (ABL) framework. This includes handling raw data, predicted
  19. pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc.
  20. As a fundamental data structure in ABL, ``ListData`` is essential for the smooth
  21. transfer and manipulation of data across various components of the ABL framework,
  22. such as prediction, abductive reasoning, and training phases. It provides a
  23. unified data format across these stages, ensuring compatibility and flexibility
  24. in handling diverse data forms in the ABL framework.
  25. The attributes in ``ListData`` are divided into two parts,
  26. the ``metainfo`` and the ``data`` respectively.
  27. - ``metainfo``: Usually used to store basic information about data examples,
  28. such as symbol number, image size, etc. The attributes can be accessed or
  29. modified by dict-like or object-like operations, such as ``.`` (for data
  30. access and modification), ``in``, ``del``, ``pop(str)``, ``get(str)``,
  31. ``metainfo_keys()``, ``metainfo_values()``, ``metainfo_items()``,
  32. ``set_metainfo()`` (for set or change key-value pairs in metainfo).
  33. - ``data``: raw data, labels, predictions, and abduced results are stored.
  34. The attributes can be accessed or modified by dict-like or object-like operations,
  35. such as ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
  36. ``values()``, ``items()``. Users can also apply tensor-like
  37. methods to all :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``,
  38. ``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``, ``.detach()``.
  39. ListData supports ``index`` and ``slice`` for data field. The type of value in
  40. data field can be either ``None`` or ``list`` of base data structures such as
  41. ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.
  42. This design is inspired by and extends the functionalities of the ``BaseDataElement``
  43. class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_.
  44. Examples:
  45. >>> from abl.data.structures import ListData
  46. >>> import numpy as np
  47. >>> import torch
  48. >>> data_examples = ListData()
  49. >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
  50. >>> data_examples.Y = [1, 2, 3]
  51. >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
  52. >>> len(data_examples)
  53. 3
  54. >>> print(data_examples)
  55. <ListData(
  56. META INFORMATION
  57. DATA FIELDS
  58. Y: [1, 2, 3]
  59. gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
  60. X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]]
  61. ) at 0x7f3bbf1991c0>
  62. >>> print(data_examples[:1])
  63. <ListData(
  64. META INFORMATION
  65. DATA FIELDS
  66. Y: [1]
  67. gt_pseudo_label: [[1, 2]]
  68. X: [[tensor(1.1949), tensor(-0.9378)]]
  69. ) at 0x7f3bbf1a3580>
  70. >>> print(data_examples.elements_num("X"))
  71. 6
  72. >>> print(data_examples.flatten("gt_pseudo_label"))
  73. [1, 2, 3, 4, 5, 6]
  74. >>> print(data_examples.to_tuple("Y"))
  75. (1, 2, 3)
  76. """
  77. def __setattr__(self, name: str, value: list):
  78. """setattr is only used to set data.
  79. The value must have the attribute of `__len__` and have the same length
  80. of `ListData`.
  81. """
  82. if name in ("_metainfo_fields", "_data_fields"):
  83. if not hasattr(self, name):
  84. super().__setattr__(name, value)
  85. else:
  86. raise AttributeError(
  87. f"{name} has been used as a " "private attribute, which is immutable."
  88. )
  89. else:
  90. # assert isinstance(value, list), "value must be of type `list`"
  91. # if len(self) > 0:
  92. # assert len(value) == len(self), (
  93. # "The length of "
  94. # f"values {len(value)} is "
  95. # "not consistent with "
  96. # "the length of this "
  97. # ":obj:`ListData` "
  98. # f"{len(self)}"
  99. # )
  100. super().__setattr__(name, value)
  101. __setitem__ = __setattr__
  102. def __getitem__(self, item: IndexType) -> "ListData":
  103. """
  104. Args:
  105. item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
  106. :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
  107. Get the corresponding values according to item.
  108. Returns:
  109. :obj:`ListData`: Corresponding values.
  110. """
  111. assert isinstance(item, IndexType.__args__)
  112. if isinstance(item, list):
  113. item = np.array(item)
  114. if isinstance(item, np.ndarray):
  115. # The default int type of numpy is platform dependent, int32 for
  116. # windows and int64 for linux. `torch.Tensor` requires the index
  117. # should be int64, therefore we simply convert it to int64 here.
  118. # More details in https://github.com/numpy/numpy/issues/9464
  119. item = item.astype(np.int64) if item.dtype == np.int32 else item
  120. item = torch.from_numpy(item)
  121. if isinstance(item, str):
  122. return getattr(self, item)
  123. new_data = self.__class__(metainfo=self.metainfo)
  124. if isinstance(item, torch.Tensor):
  125. assert item.dim() == 1, "Only support to get the" " values along the first dimension."
  126. for k, v in self.items():
  127. if v is None:
  128. new_data[k] = None
  129. elif isinstance(v, torch.Tensor):
  130. new_data[k] = v[item]
  131. elif isinstance(v, np.ndarray):
  132. new_data[k] = v[item.cpu().numpy()]
  133. elif isinstance(v, (str, list, tuple)) or (
  134. hasattr(v, "__getitem__") and hasattr(v, "cat")
  135. ):
  136. # convert to indexes from BoolTensor
  137. if isinstance(item, BoolTypeTensor.__args__):
  138. indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
  139. else:
  140. indexes = item.cpu().numpy().tolist()
  141. slice_list = []
  142. if indexes:
  143. for index in indexes:
  144. slice_list.append(slice(index, None, len(v)))
  145. else:
  146. slice_list.append(slice(None, 0, None))
  147. r_list = [v[s] for s in slice_list]
  148. if isinstance(v, (str, list, tuple)):
  149. new_value = r_list[0]
  150. for r in r_list[1:]:
  151. new_value = new_value + r
  152. else:
  153. new_value = v.cat(r_list)
  154. new_data[k] = new_value
  155. else:
  156. raise ValueError(
  157. f"The type of `{k}` is `{type(v)}`, which has no "
  158. "attribute of `cat`, so it does not "
  159. "support slice with `bool`"
  160. )
  161. else:
  162. # item is a slice or int
  163. for k, v in self.items():
  164. if v is None:
  165. new_data[k] = None
  166. else:
  167. new_data[k] = v[item]
  168. return new_data # type:ignore
  169. def flatten(self, item: str) -> List:
  170. """
  171. Flatten the list of the attribute specified by ``item``.
  172. Parameters
  173. ----------
  174. item
  175. Name of the attribute to be flattened.
  176. Returns
  177. -------
  178. list
  179. The flattened list of the attribute specified by ``item``.
  180. """
  181. return flatten_list(self[item])
  182. def elements_num(self, item: str) -> int:
  183. """
  184. Return the number of elements in the attribute specified by ``item``.
  185. Parameters
  186. ----------
  187. item : str
  188. Name of the attribute for which the number of elements is to be determined.
  189. Returns
  190. -------
  191. int
  192. The number of elements in the attribute specified by ``item``.
  193. """
  194. return len(self.flatten(item))
  195. def to_tuple(self, item: str) -> tuple:
  196. """
  197. Convert the attribute specified by ``item`` to a tuple.
  198. Parameters
  199. ----------
  200. item : str
  201. Name of the attribute to be converted.
  202. Returns
  203. -------
  204. tuple
  205. The attribute after conversion to a tuple.
  206. """
  207. return to_hashable(self[item])
  208. def __len__(self) -> int:
  209. """int: The length of ListData."""
  210. iterator = iter(self._data_fields)
  211. data = next(iterator)
  212. while getattr(self, data) is None:
  213. try:
  214. data = next(iterator)
  215. except StopIteration:
  216. break
  217. if getattr(self, data) is None:
  218. raise ValueError("All data fields are None.")
  219. else:
  220. return len(getattr(self, data))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.