Browse Source

[DOC] refine API

pull/1/head
troyyyyy 2 years ago
parent
commit
e1ff6d8508
8 changed files with 54 additions and 58 deletions
  1. +7
    -8
      abl/data/structures/list_data.py
  2. +11
    -11
      abl/learning/abl_model.py
  3. +2
    -2
      abl/learning/basic_nn.py
  4. +24
    -24
      abl/reasoning/kb.py
  5. +2
    -2
      abl/reasoning/reasoner.py
  6. +3
    -3
      docs/API/abl.data.rst
  7. +2
    -5
      docs/API/abl.learning.rst
  8. +3
    -3
      docs/Intro/Datasets.rst

+ 7
- 8
abl/data/structures/list_data.py View File

@@ -20,12 +20,12 @@ class ListData(BaseDataElement):
"""
Abstract Data Interface used throughout the ABL-Package.

`ListData` is the underlying data structure used in the ABL-Package,
``ListData`` is the underlying data structure used in the ABL-Package,
designed to manage diverse forms of data dynamically generated throughout the
Abductive Learning (ABL) framework. This includes handling raw data, predicted
pseudo-labels, abduced pseudo-labels, pseudo-label indices, etc.

As a fundamental data structure in ABL, `ListData` is essential for the smooth
As a fundamental data structure in ABL, ``ListData`` is essential for the smooth
transfer and manipulation of data across various components of the ABL framework,
such as prediction, abductive reasoning, and training phases. It provides a
unified data format across these stages, ensuring compatibility and flexibility
@@ -48,13 +48,12 @@ class ListData(BaseDataElement):
methods to all :obj:`torch.Tensor` in the ``data_fields``, such as ``.cuda()``,
``.cpu()``, ``.numpy()``, ``.to()``, ``to_tensor()``, ``.detach()``.

ListData supports `index` and `slice` for data field. The type of value in
data field can be either `None` or `list` of base data structures such as
`torch.Tensor`, `numpy.ndarray`, `list`, `str` and `tuple`.
ListData supports ``index`` and ``slice`` for data field. The type of value in
data field can be either ``None`` or ``list`` of base data structures such as
``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.

This design is inspired by and extends the functionalities of the `BaseDataElement`
class implemented in MMEngine.
https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py # noqa E501
This design is inspired by and extends the functionalities of the ``BaseDataElement``
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_.

Examples:
>>> from abl.data.structures import ListData


+ 11
- 11
abl/learning/abl_model.py View File

@@ -13,9 +13,9 @@ class ABLModel:
----------
base_model : Machine Learning Model
The machine learning base model used for training and prediction. This model should
implement the 'fit' and 'predict' methods. It's recommended, but not required,for the
model to also implement the 'predict_proba' method for generating probabilistic
predictions.
implement the ``fit`` and ``predict`` methods. It's recommended, but not required, for the
model to also implement the ``predict_proba`` method for generating
predictions on the probabilities.
"""

def __init__(self, base_model: Any) -> None:
@@ -61,8 +61,8 @@ class ABLModel:
Parameters
----------
data_examples : ListData
A batch of data to train on, which typically contains the data, `X`, and the
corresponding labels, `abduced_idx`.
A batch of data to train on, which typically contains the data, ``X``, and the
corresponding labels, ``abduced_idx``.

Returns
-------
@@ -80,8 +80,8 @@ class ABLModel:
Parameters
----------
data_examples : ListData
A batch of data to train on, which typically contains the data, `X`,
and the corresponding labels, `abduced_idx`.
A batch of data to train on, which typically contains the data, ``X``,
and the corresponding labels, ``abduced_idx``.

Returns
-------
@@ -119,8 +119,8 @@ class ABLModel:
"""
Save the model to a file.

This method delegates to the 'save' method of self.base_model. The arguments passed to
this method should match those expected by the 'save' method of self.base_model.
This method delegates to the ``save`` method of self.base_model. The arguments passed to
this method should match those expected by the ``save`` method of self.base_model.
"""
self._model_operation("save", *args, **kwargs)

@@ -128,7 +128,7 @@ class ABLModel:
"""
Load the model from a file.

This method delegates to the 'load' method of self.base_model. The arguments passed to
this method should match those expected by the 'load' method of self.base_model.
This method delegates to the ``load`` method of self.base_model. The arguments passed to
this method should match those expected by the ``load`` method of self.base_model.
"""
self._model_operation("load", *args, **kwargs)

+ 2
- 2
abl/learning/basic_nn.py View File

@@ -119,7 +119,7 @@ class BasicNN:

def _fit(self, data_loader: DataLoader) -> BasicNN:
"""
Internal method to fit the model on data for self.num_epochs times,
Internal method to fit the model on data for ``self.num_epochs`` times,
with early stopping.

Parameters
@@ -458,7 +458,7 @@ class BasicNN:
the epoch_id at which the model and optimizer is saved. if both save_path and
epoch_id are provided, save_path will be used. If only epoch_id is specified,
model and optimizer will be saved to the path f"model_checkpoint_epoch_{epoch_id}.pth"
under self.save_dir. save_path and epoch_id can not be None simultaneously.
under ``self.save_dir``. save_path and epoch_id can not be None simultaneously.

Parameters
----------


+ 24
- 24
abl/reasoning/kb.py View File

@@ -45,7 +45,7 @@ class KBBase(ABC):
-----
Users should derive from this base class to build their own knowledge base. For the
user-build KB (a derived subclass), it's only required for the user to provide the
`pseudo_label_list` and override the `logic_forward` function (specifying how to
``pseudo_label_list`` and override the ``logic_forward`` function (specifying how to
perform logical reasoning). After that, other operations (e.g. how to perform abductive
reasoning) will be automatically set up.
"""
@@ -88,7 +88,7 @@ class KBBase(ABC):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example.
Pseudo-label example.
x : Optional[List[Any]]
The corresponding input example. If deductive logical reasoning does not require any
information from the input, the overridden function provided by the user can omit
@@ -114,7 +114,7 @@ class KBBase(ABC):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised by abductive reasoning).
Pseudo-label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
@@ -167,7 +167,7 @@ class KBBase(ABC):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised).
Pseudo-label example (to be revised).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
@@ -231,7 +231,7 @@ class KBBase(ABC):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised).
Pseudo-label example (to be revised).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
@@ -285,22 +285,22 @@ class GroundKB(KBBase):
"""
Knowledge base with a ground KB (GKB). Ground KB is a knowledge base prebuilt upon
class initialization, storing all potential candidates along with their respective
reasoning result. Ground KB can accelerate abductive reasoning in `abduce_candidates`.
reasoning result. Ground KB can accelerate abductive reasoning in ``abduce_candidates``.

Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
Refer to class ``KBBase``.
GKB_len_list : list
List of possible lengths for a pseudo-label example.
max_err : float, optional
Refer to class `KBBase`.
Refer to class ``KBBase``.

Notes
-----
Users can also inherit from this class to build their own knowledge base. Similar
to `KBBase`, users are only required to provide the `pseudo_label_list` and override
the `logic_forward` function. Additionally, users should provide the `GKB_len_list`.
to ``KBBase``, users are only required to provide the ``pseudo_label_list`` and override
the ``logic_forward`` function. Additionally, users should provide the ``GKB_len_list``.
After that, other operations (e.g. auto-construction of GKB, and how to perform
abductive reasoning) will be automatically set up.
"""
@@ -329,7 +329,7 @@ class GroundKB(KBBase):

def _get_GKB(self):
"""
Prebuild the GKB according to `pseudo_label_list` and `GKB_len_list`.
Prebuild the GKB according to ``pseudo_label_list`` and ``GKB_len_list``.
"""
X, Y = [], []
for length in self.GKB_len_list:
@@ -365,7 +365,7 @@ class GroundKB(KBBase):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised by abductive reasoning).
Pseudo-label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
@@ -447,20 +447,20 @@ class PrologKB(KBBase):
Parameters
----------
pseudo_label_list : list
Refer to class `KBBase`.
Refer to class ``KBBase``.
pl_file :
Prolog file containing the KB.
max_err : float, optional
Refer to class `KBBase`.
Refer to class ``KBBase``.

Notes
-----
Users can instantiate this class to build their own knowledge base. During the
instantiation, users are only required to provide the `pseudo_label_list` and `pl_file`.
instantiation, users are only required to provide the ``pseudo_label_list`` and ``pl_file``.
To use the default logic forward and abductive reasoning methods in this class, in the
Prolog (.pl) file, there needs to be a rule which is strictly formatted as
`logic_forward(Pseudo_labels, Res).`, e.g., `logic_forward([A,B], C) :- C is A+B`.
For specifics, refer to the `logic_forward` and `get_query_string` functions in this
``logic_forward(Pseudo_labels, Res).``, e.g., ``logic_forward([A,B], C) :- C is A+B``.
For specifics, refer to the ``logic_forward`` and ``get_query_string`` functions in this
class. Users are also welcome to override related functions for more flexible support.
"""

@@ -475,15 +475,15 @@ class PrologKB(KBBase):

def logic_forward(self, pseudo_label: List[Any]) -> Any:
"""
Consult prolog with the query `logic_forward(pseudo_labels, Res).`, and set the
returned `Res` as the reasoning results. To use this default function, there must be
a `logic_forward` method in the pl file to perform reasoning.
Consult prolog with the query ``logic_forward(pseudo_labels, Res).``, and set the
returned ``Res`` as the reasoning results. To use this default function, there must be
a ``logic_forward`` method in the pl file to perform reasoning.
Otherwise, users would override this function.

Parameters
----------
pseudo_label : List[Any]
Pseudo label example.
Pseudo-label example.
"""
result = list(self.prolog.query("logic_forward(%s, Res)." % pseudo_label))[0]["Res"]
if result == "true":
@@ -520,12 +520,12 @@ class PrologKB(KBBase):
Get the query to be used for consulting Prolog.
This is a default function for demo, users would override this function to adapt to
their own Prolog file. In this demo function, return query
`logic_forward([kept_labels, Revise_labels], Res).`.
``logic_forward([kept_labels, Revise_labels], Res).``.

Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised by abductive reasoning).
Pseudo-label example (to be revised by abductive reasoning).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
@@ -559,7 +559,7 @@ class PrologKB(KBBase):
Parameters
----------
pseudo_label : List[Any]
Pseudo label example (to be revised).
Pseudo-label example (to be revised).
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]


+ 2
- 2
abl/reasoning/reasoner.py View File

@@ -251,7 +251,7 @@ class Reasoner:

def _get_max_revision_num(self, max_revision: Union[int, float], symbol_num: int) -> int:
"""
Get the maximum revision number according to input `max_revision`.
Get the maximum revision number according to input ``max_revision``.
"""
if not isinstance(max_revision, (int, float)):
raise TypeError(f"Parameter must be of type int or float, but got {type(max_revision)}")
@@ -313,7 +313,7 @@ class Reasoner:
def batch_abduce(self, data_examples: ListData) -> List[List[Any]]:
"""
Perform abductive reasoning on the given prediction data examples.
For detailed information, refer to `abduce`.
For detailed information, refer to ``abduce``.
"""
abduced_pseudo_label = [self.abduce(data_example) for data_example in data_examples]
data_examples.abduced_pseudo_label = abduced_pseudo_label


+ 3
- 3
docs/API/abl.data.rst View File

@@ -1,7 +1,7 @@
abl.data
===================

Data Structure
``structures``
--------------

.. autoclass:: abl.data.structures.ListData
@@ -9,8 +9,8 @@ Data Structure
:undoc-members:
:show-inheritance:

Evaluation Metric
-----------------
``evaluation``
--------------

.. automodule:: abl.data.evaluation
:members:


+ 2
- 5
docs/API/abl.learning.rst View File

@@ -1,9 +1,6 @@
abl.learning
==================

Learning
--------

.. autoclass:: abl.learning.ABLModel
:members:
:undoc-members:
@@ -14,8 +11,8 @@ Learning
:undoc-members:
:show-inheritance:

Torch Dataset
-------------
``torch_dataset``
-----------------

.. automodule:: abl.learning.torch_dataset
:members:


+ 3
- 3
docs/Intro/Datasets.rst View File

@@ -44,7 +44,7 @@ ABL-Package assumes user data to be either structured as a tuple or a ``ListData

The length of ``X``, ``gt_pseudo_label`` (if not ``None``) and ``Y`` should be the same. Also, each sublist in ``gt_pseudo_label`` should have the same length as the sublist in ``X``.

As an illustration, in the MNIST Addition example, the data used for training are organized as follows:
As an illustration, in the MNIST Addition task, the data are organized as follows:

.. image:: ../img/Datasets_1.png
:width: 350px
@@ -53,11 +53,11 @@ As an illustration, in the MNIST Addition example, the data used for training ar
Data Structure
--------------

Besides the user-provided dataset, various forms of data are utilized and dynamicly generated throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, and so on. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html#data-structure>`_ to encapsulate different forms of data that will be used in the total learning process.
Besides the user-provided dataset, various forms of data are utilized and dynamicly generated throughout the training and testing process of Abductive Learning framework. Examples include raw data, predicted pseudo-label, abduced pseudo-label, pseudo-label indices, etc. To manage this diversity and ensure a stable, versatile interface, ABL-Package employs `abstract data interfaces <../API/abl.data.html#structure>`_ to encapsulate different forms of data that will be used in the total learning process.

``ListData`` is the underlying abstract data interface utilized in ABL-Package. As the fundamental data structure, ``ListData`` implements commonly used data manipulation methods and is responsible for transferring data between various components of ABL, ensuring that stages such as prediction, abductive reasoning, and training can utilize ``ListData`` as a unified input format.

Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html#data-structure>`_.
Before proceeding to other stages, user-provided datasets are firstly converted into ``ListData``. For flexibility, ABL-Package also allows user to directly supply data in ``ListData`` format, which similarly requires the inclusion of three attributes: ``X``, ``gt_pseudo_label``, and ``Y``. The following code shows the basic usage of ``ListData``. More information can be found in the `API documentation <../API/abl.data.html#structure>`_.

.. code-block:: python



Loading…
Cancel
Save