
list_data.py 10 kB

  1. """
  2. Copyright (c) OpenMMLab. All rights reserved.
  3. Modified from
  4. https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa: E501 pylint: disable=line-too-long
  5. """
  6. from typing import List, Union
  7. import numpy as np
  8. import torch
  9. from ...utils import flatten as flatten_list
  10. from ...utils import to_hashable
  11. from .base_data_element import BaseDataElement
  12. BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
  13. LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
  14. IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndarray]
  15. class ListData(BaseDataElement):
  16. """
  17. Abstract Data Interface used throughout the ABL kit.
  18. ``ListData`` is the underlying data structure used in the ABL kit,
  19. designed to manage diverse forms of data dynamically generated throughout the
  20. Abductive Learning (ABL) framework. This includes handling raw data, predicted
  21. pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc.
  22. As a fundamental data structure in ABL, ``ListData`` is essential for the smooth
  23. transfer and manipulation of data across various components of the ABL framework,
  24. such as prediction, abductive reasoning, and training phases. It provides a
  25. unified data format across these stages, ensuring compatibility and flexibility
  26. in handling diverse data forms in the ABL framework.
  27. The attributes in ``ListData`` are divided into two parts,
  28. the ``metainfo`` and the ``data`` respectively.
  29. - ``metainfo``: Usually used to store basic information about data examples,
  30. such as symbol number, image size, etc. The attributes can be accessed or
  31. modified by dict-like or object-like operations, such as ``.`` (for data
  32. access and modification), ``in``, ``del``, ``pop(str)``, ``get(str)``,
  33. ``metainfo_keys()``, ``metainfo_values()``, ``metainfo_items()``,
  34. ``set_metainfo()`` (for set or change key-value pairs in metainfo).
  35. - ``data``: raw data, labels, predictions, and abduced results are stored.
  36. The attributes can be accessed or modified by dict-like or object-like operations,
  37. such as ``.``, ``in``, ``del``, ``pop(str)``, ``get(str)``, ``keys()``,
  38. ``values()``, ``items()``. Users can also apply tensor-like
  39. methods to all :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``,
  40. ``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``, ``.detach()``.
  41. ListData supports ``index`` and ``slice`` for data field. The type of value in
  42. data field can be either ``None`` or ``list`` of base data structures such as
  43. ``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.
  44. This design is inspired by and extends the functionalities of the ``BaseDataElement``
  45. class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501 pylint: disable=line-too-long
  46. Examples:
  47. >>> from ablkit.data.structures import ListData
  48. >>> import numpy as np
  49. >>> import torch
  50. >>> data_examples = ListData()
  51. >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)]
  52. >>> data_examples.Y = [1, 2, 3]
  53. >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
  54. >>> len(data_examples)
  55. 3
  56. >>> print(data_examples)
  57. <ListData(
  58. META INFORMATION
  59. DATA FIELDS
  60. Y: [1, 2, 3]
  61. gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
  62. X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501 pylint: disable=line-too-long
  63. ) at 0x7f3bbf1991c0>
  64. >>> print(data_examples[:1])
  65. <ListData(
  66. META INFORMATION
  67. DATA FIELDS
  68. Y: [1]
  69. gt_pseudo_label: [[1, 2]]
  70. X: [[tensor(1.1949), tensor(-0.9378)]]
  71. ) at 0x7f3bbf1a3580>
  72. >>> print(data_examples.elements_num("X"))
  73. 6
  74. >>> print(data_examples.flatten("gt_pseudo_label"))
  75. [1, 2, 3, 4, 5, 6]
  76. >>> print(data_examples.to_tuple("Y"))
  77. (1, 2, 3)
  78. """
  79. def __setattr__(self, name: str, value: list):
  80. """setattr is only used to set data.
  81. The value must have the attribute of `__len__` and have the same length
  82. of `ListData`.
  83. """
  84. if name in ("_metainfo_fields", "_data_fields"):
  85. if not hasattr(self, name):
  86. super().__setattr__(name, value)
  87. else:
  88. raise AttributeError(
  89. f"{name} has been used as a " "private attribute, which is immutable."
  90. )
  91. else:
  92. # assert isinstance(value, list), "value must be of type `list`"
  93. # if len(self) > 0:
  94. # assert len(value) == len(self), (
  95. # "The length of "
  96. # f"values {len(value)} is "
  97. # "not consistent with "
  98. # "the length of this "
  99. # ":obj:`ListData` "
  100. # f"{len(self)}"
  101. # )
  102. super().__setattr__(name, value)
  103. __setitem__ = __setattr__
  104. def __getitem__(self, item: IndexType) -> "ListData":
  105. """
  106. Args:
  107. item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
  108. :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
  109. Get the corresponding values according to item.
  110. Returns:
  111. :obj:`ListData`: Corresponding values.
  112. """
  113. assert isinstance(item, IndexType.__args__)
  114. if isinstance(item, list):
  115. item = np.array(item)
  116. if isinstance(item, np.ndarray):
  117. # The default int type of numpy is platform dependent, int32 for
  118. # windows and int64 for linux. `torch.Tensor` requires the index
  119. # should be int64, therefore we simply convert it to int64 here.
  120. # More details in https://github.com/numpy/numpy/issues/9464
  121. item = item.astype(np.int64) if item.dtype == np.int32 else item
  122. item = torch.from_numpy(item)
  123. if isinstance(item, str):
  124. return getattr(self, item)
  125. new_data = self.__class__(metainfo=self.metainfo)
  126. if isinstance(item, torch.Tensor):
  127. assert item.dim() == 1, "Only support to get the" " values along the first dimension."
  128. for k, v in self.items():
  129. if v is None:
  130. new_data[k] = None
  131. elif isinstance(v, torch.Tensor):
  132. new_data[k] = v[item]
  133. elif isinstance(v, np.ndarray):
  134. new_data[k] = v[item.cpu().numpy()]
  135. elif isinstance(v, (str, list, tuple)) or (
  136. hasattr(v, "__getitem__") and hasattr(v, "cat")
  137. ):
  138. # convert to indexes from BoolTensor
  139. if isinstance(item, BoolTypeTensor.__args__):
  140. indexes = torch.nonzero(item).view(-1).cpu().numpy().tolist()
  141. else:
  142. indexes = item.cpu().numpy().tolist()
  143. slice_list = []
  144. if indexes:
  145. for index in indexes:
  146. slice_list.append(slice(index, None, len(v)))
  147. else:
  148. slice_list.append(slice(None, 0, None))
  149. r_list = [v[s] for s in slice_list]
  150. if isinstance(v, (str, list, tuple)):
  151. new_value = r_list[0]
  152. for r in r_list[1:]:
  153. new_value = new_value + r
  154. else:
  155. new_value = v.cat(r_list)
  156. new_data[k] = new_value
  157. else:
  158. raise ValueError(
  159. f"The type of `{k}` is `{type(v)}`, which has no "
  160. "attribute of `cat`, so it does not "
  161. "support slice with `bool`"
  162. )
  163. else:
  164. # item is a slice or int
  165. for k, v in self.items():
  166. if v is None:
  167. new_data[k] = None
  168. else:
  169. new_data[k] = v[item]
  170. return new_data # type:ignore
  171. def flatten(self, item: str) -> List:
  172. """
  173. Flatten the list of the attribute specified by ``item``.
  174. Parameters
  175. ----------
  176. item
  177. Name of the attribute to be flattened.
  178. Returns
  179. -------
  180. list
  181. The flattened list of the attribute specified by ``item``.
  182. """
  183. return flatten_list(self[item])
  184. def elements_num(self, item: str) -> int:
  185. """
  186. Return the number of elements in the attribute specified by ``item``.
  187. Parameters
  188. ----------
  189. item : str
  190. Name of the attribute for which the number of elements is to be determined.
  191. Returns
  192. -------
  193. int
  194. The number of elements in the attribute specified by ``item``.
  195. """
  196. return len(self.flatten(item))
  197. def to_tuple(self, item: str) -> tuple:
  198. """
  199. Convert the attribute specified by ``item`` to a tuple.
  200. Parameters
  201. ----------
  202. item : str
  203. Name of the attribute to be converted.
  204. Returns
  205. -------
  206. tuple
  207. The attribute after conversion to a tuple.
  208. """
  209. return to_hashable(self[item])
  210. def __len__(self) -> int:
  211. """int: The length of ListData."""
  212. iterator = iter(self._data_fields)
  213. data = next(iterator)
  214. while getattr(self, data) is None:
  215. try:
  216. data = next(iterator)
  217. except StopIteration:
  218. break
  219. if getattr(self, data) is None:
  220. raise ValueError("All data fields are None.")
  221. else:
  222. return len(getattr(self, data))

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
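For reference, below is a minimal usage sketch of ``ListData`` (assuming ``ablkit`` is installed). The field names ``X``, ``gt_pseudo_label``, ``Y`` and the metainfo key ``symbol_num`` are illustrative choices mirroring the docstring examples above; only behavior documented in the class docstring is exercised here.

import torch

from ablkit.data.structures import ListData

# Build a ListData with three examples; every data field is a list of length 3.
data_examples = ListData()
data_examples.set_metainfo(dict(symbol_num=2))  # metainfo key is illustrative
data_examples.X = [list(torch.randn(2)) for _ in range(3)]
data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]]
data_examples.Y = [1, 2, 3]

# Indexing with an int, slice, list, LongTensor, or BoolTensor returns a new
# ListData whose data fields are selected consistently, while metainfo is
# carried over unchanged.
subset = data_examples[torch.tensor([0, 2])]
print(len(subset))             # 2
print(subset.gt_pseudo_label)  # [[1, 2], [5, 6]]

# Utility methods operate on a single named data field.
print(data_examples.flatten("gt_pseudo_label"))  # [1, 2, 3, 4, 5, 6]
print(data_examples.elements_num("X"))           # 6
print(data_examples.to_tuple("Y"))               # (1, 2, 3)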