From ea0a5a366808cf64c86843920cd624fdbe9b12cc Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 21 Dec 2023 03:50:41 +0800 Subject: [PATCH] [DOC] modify doc of abl.structures --- abl/structures/base_data_element.py | 3 +- abl/structures/list_data.py | 211 +++++++++------------------- docs/API/abl.structures.rst | 4 +- docs/Intro/Datasets.rst | 8 +- 4 files changed, 74 insertions(+), 152 deletions(-) diff --git a/abl/structures/base_data_element.py b/abl/structures/base_data_element.py index 269f4d2..79cfa61 100644 --- a/abl/structures/base_data_element.py +++ b/abl/structures/base_data_element.py @@ -5,7 +5,8 @@ from typing import Any, Iterator, Optional, Tuple, Type, Union import numpy as np import torch - +# Modified from +# https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py class BaseDataElement: """A base data interface that supports Tensor-like and dict-like operations. diff --git a/abl/structures/list_data.py b/abl/structures/list_data.py index f972133..dbc8c2d 100644 --- a/abl/structures/list_data.py +++ b/abl/structures/list_data.py @@ -18,106 +18,47 @@ IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor, np.ndar # Modified from # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa class ListData(BaseDataElement): - """Data structure for instance-level annotations or predictions. + """ + Data structure for example-level data. Subclass of :class:`BaseDataElement`. All value in `data_fields` should have the same length. This design refer to - https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501 - ListData also support extra functions: ``index``, ``slice`` and ``cat`` for data field. The type of value - in data field can be base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`, `str`, `tuple`, - and can be customized data structure that has ``__len__``, ``__getitem__`` and ``cat`` attributes. + https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py + + ListData supports `index` and `slice` for data field. The type of value in data field can be either `None` or `list` of base data structures such as `torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`. Examples: - >>> # custom data structure - >>> class TmpObject: - ... def __init__(self, tmp) -> None: - ... assert isinstance(tmp, list) - ... self.tmp = tmp - ... def __len__(self): - ... return len(self.tmp) - ... def __getitem__(self, item): - ... if isinstance(item, int): - ... if item >= len(self) or item < -len(self): # type:ignore - ... raise IndexError(f'Index {item} out of range!') - ... else: - ... # keep the dimension - ... item = slice(item, None, len(self)) - ... return TmpObject(self.tmp[item]) - ... @staticmethod - ... def cat(tmp_objs): - ... assert all(isinstance(results, TmpObject) for results in tmp_objs) - ... if len(tmp_objs) == 1: - ... return tmp_objs[0] - ... tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs] - ... tmp_list = list(itertools.chain(*tmp_list)) - ... new_data = TmpObject(tmp_list) - ... return new_data - ... def __repr__(self): - ... return str(self.tmp) - >>> from mmengine.structures import ListData + >>> from abl.structures import ListData >>> import numpy as np >>> import torch - >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3)) - >>> instance_data = ListData(metainfo=img_meta) - >>> 'img_shape' in instance_data - True - >>> instance_data.det_labels = torch.LongTensor([2, 3]) - >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7]) - >>> instance_data.bboxes = torch.rand((2, 4)) - >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]]) - >>> len(instance_data) - 2 - >>> print(instance_data) - - >>> sorted_results = instance_data[instance_data.det_scores.sort().indices] - >>> sorted_results.det_scores - tensor([0.7000, 0.8000]) - >>> print(instance_data[instance_data.det_scores > 0.75]) - - >>> print(instance_data[instance_data.det_scores > 1]) + >>> data_examples = ListData() + >>> data_examples.X = [list(torch.randn(2)) for _ in range(3)] + >>> data_examples.Y = [1, 2, 3] + >>> data_examples.gt_pseudo_label = [[1, 2], [3, 4], [5, 6]] + >>> len(data_examples) + 3 + >>> print(data_examples) - >>> print(instance_data.cat([instance_data, instance_data])) + Y: [1, 2, 3] + gt_pseudo_label: [[1, 2], [3, 4], [5, 6]] + X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] + ) at 0x7f3bbf1991c0> + >>> print(data_examples[:1]) + Y: [1] + gt_pseudo_label: [[1, 2]] + X: [[tensor(1.1949), tensor(-0.9378)]] + ) at 0x7f3bbf1a3580> + >>> print(data_examples.elements_num("X")) + 6 + >>> print(data_examples.flatten("gt_pseudo_label")) + [1, 2, 3, 4, 5, 6] + >>> print(data_examples.to_tuple("Y")) + (1, 2, 3) """ def __setattr__(self, name: str, value: list): @@ -224,74 +165,52 @@ class ListData(BaseDataElement): new_data[k] = v[item] return new_data # type:ignore - @staticmethod - def cat(instances_list: List["ListData"]) -> "ListData": - """Concat the instances of all :obj:`ListData` in the list. + def flatten(self, item: str) -> List: + """ + Flatten the list of the attribute specified by ``item``. - Note: To ensure that cat returns as expected, make sure that - all elements in the list must have exactly the same keys. + Parameters + ---------- + item + Name of the attribute to be flattened. - Args: - instances_list (list[:obj:`ListData`]): A list - of :obj:`ListData`. + Returns + ------- + list + The flattened list of the attribute specified by ``item``. + """ + return flatten_list(self[item]) - Returns: - :obj:`ListData` + def elements_num(self, item: str) -> int: """ - assert all(isinstance(results, ListData) for results in instances_list) - assert len(instances_list) > 0 - if len(instances_list) == 1: - return instances_list[0] - - # metainfo and data_fields must be exactly the - # same for each element to avoid exceptions. - field_keys_list = [instances.all_keys() for instances in instances_list] - assert len({len(field_keys) for field_keys in field_keys_list}) == 1 and len( - set(itertools.chain(*field_keys_list)) - ) == len(field_keys_list[0]), ( - "There are different keys in " - "`instances_list`, which may " - "cause the cat operation " - "to fail. Please make sure all " - "elements in `instances_list` " - "have the exact same key." - ) - - new_data = instances_list[0].__class__(metainfo=instances_list[0].metainfo) - for k in instances_list[0].keys(): - values = [results[k] for results in instances_list] - v0 = values[0] - if isinstance(v0, torch.Tensor): - new_values = torch.cat(values, dim=0) - elif isinstance(v0, np.ndarray): - new_values = np.concatenate(values, axis=0) - elif isinstance(v0, (str, list, tuple)): - new_values = v0[:] - for v in values[1:]: - new_values += v - elif hasattr(v0, "cat"): - new_values = v0.cat(values) - else: - raise ValueError( - f"The type of `{k}` is `{type(v0)}` which has no " "attribute of `cat`" - ) - new_data[k] = new_values - return new_data # type:ignore + Return the number of elements in the attribute specified by ``item``. - def flatten(self, item: IndexType) -> List: - """Flatten self[item]. + Parameters + ---------- + item : str + Name of the attribute for which the number of elements is to be determined. - Returns: - list: Flattened data fields. + Returns + ------- + int + The number of elements in the attribute specified by ``item``. """ - return flatten_list(self[item]) - - def elements_num(self, item: IndexType) -> int: - """int: The number of elements in self[item].""" return len(self.flatten(item)) - def to_tuple(self, item: IndexType) -> tuple: - """tuple: The data fields in self[item] converted to tuple.""" + def to_tuple(self, item: str) -> tuple: + """ + Convert the attribute specified by ``item`` to a tuple. + + Parameters + ---------- + item : str + Name of the attribute to be converted. + + Returns + ------- + tuple + The attribute after conversion to a tuple. + """ return to_hashable(self[item]) def __len__(self) -> int: diff --git a/docs/API/abl.structures.rst b/docs/API/abl.structures.rst index 26f3973..fee74c4 100644 --- a/docs/API/abl.structures.rst +++ b/docs/API/abl.structures.rst @@ -1,7 +1,7 @@ abl.structures ================== -.. automodule:: abl.structures +.. autoclass:: abl.structures.ListData :members: :undoc-members: - :show-inheritance: \ No newline at end of file + :show-inheritance: diff --git a/docs/Intro/Datasets.rst b/docs/Intro/Datasets.rst index 20eb1ac..3ee1050 100644 --- a/docs/Intro/Datasets.rst +++ b/docs/Intro/Datasets.rst @@ -21,7 +21,7 @@ In this section, we will look at the datasets and data structures in ABL-Package Dataset ------- -ABL-Package assumes user data to be structured as a tuple, comprising the following three components: +ABL-Package assumes user data to be either structured as a tuple or a ``ListData`` which is the underlying data structure utilized in the whole package and will be introduced in the next section. Regardless of the chosen format, the data should encompass the following three essential components: - ``X``: List[List[Any]] @@ -53,9 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar Data Structure -------------- -In Abductive Learning, there are various types of data in the training and testing process, such as raw data, pseudo-label, index of the pseudo-label, abduced pseudo-label, etc. To enhance the stability and versatility, ABL-Package uses `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate various data during the implementation of the model. +Besides the user-provided dataset, various forms of data are utilized and dynamicly generate throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.structures.html>`_ to encapsulate different forms of data that will be used in the total learning process. -One of the most commonly used abstract data interface is ``ListData``. Besides orginizing data into tuple, we can also prepare data to be in the form of this data interface. +``BaseDataElement`` is the base class for all abstract data interfaces. Inherited from ``BaseDataElement``, ``ListData`` is the most commonly used abstract data interface in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, training, and abductive reasoning can utilize ``ListData`` as a unified input format. + +Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.structures.html>`_. .. code-block:: python