Bridge
======
In this section, we will look at how to bridge the learning and reasoning parts to train the model, which is the fundamental idea of Abductive Learning. ABLkit implements a set of bridge classes to achieve this.

.. code:: python

    from ablkit.bridge import BaseBridge, SimpleBridge

``BaseBridge`` is an abstract class with the following initialization parameters:
- ``model`` is an object of type ``ABLModel``. The learning part is wrapped in this object.
- ``reasoner`` is an object of type ``Reasoner``. The reasoning part is wrapped in this object.
``BaseBridge`` has the following important methods that need to be overridden in subclasses:

+----------------------------------------+----------------------------------------------------+
| Method Signature                       | Description                                        |
+========================================+====================================================+
| ``predict(data_examples)``             | Predicts class probabilities and indices           |
|                                        | for the given data examples.                       |
+----------------------------------------+----------------------------------------------------+
| ``abduce_pseudo_label(data_examples)`` | Abduces pseudo-labels for the given data examples. |
+----------------------------------------+----------------------------------------------------+
| ``idx_to_pseudo_label(data_examples)`` | Converts indices to pseudo-labels using            |
|                                        | the provided or default mapping.                   |
+----------------------------------------+----------------------------------------------------+
| ``pseudo_label_to_idx(data_examples)`` | Converts pseudo-labels to indices                  |
|                                        | using the provided or default remapping.           |
+----------------------------------------+----------------------------------------------------+
| ``train(train_data)``                  | Train the model.                                   |
+----------------------------------------+----------------------------------------------------+
| ``test(test_data)``                    | Test the model.                                    |
+----------------------------------------+----------------------------------------------------+

where ``train_data`` and ``test_data`` are each either a tuple or a `ListData <../API/ablkit.data.html#structures.ListData>`_. In either form, they must include three components: ``X``, ``gt_pseudo_label`` and ``Y``. Since ``ListData`` is the underlying data structure used throughout ABLkit, tuple-formed data is first converted into ``ListData`` in the ``train`` and ``test`` methods, and these ``ListData`` instances are referred to as ``data_examples``. More details can be found in `preparing datasets `_.
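
As a rough illustration of the two accepted forms, the sketch below builds tuple-formed data and converts it into a container with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. ``ToyListData`` and ``to_toy_list_data`` are hypothetical stand-ins written for this example only; they are not ABLkit classes, and the real ``ListData`` offers more than this.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class ToyListData:
    """Hypothetical stand-in for ListData (NOT the ABLkit class)."""

    X: List[List[Any]]
    gt_pseudo_label: Optional[List[List[Any]]]
    Y: List[Any]

    def __len__(self) -> int:
        return len(self.X)


def to_toy_list_data(data):
    """Accept a ToyListData or an (X, gt_pseudo_label, Y) tuple."""
    if isinstance(data, ToyListData):
        return data
    X, gt_pseudo_label, Y = data
    if len(X) != len(Y):
        raise ValueError("X and Y must contain the same number of examples")
    if gt_pseudo_label is not None and len(gt_pseudo_label) != len(X):
        raise ValueError("gt_pseudo_label must match X in length or be None")
    return ToyListData(X, gt_pseudo_label, Y)


# Two examples; strings stand in for raw inputs such as images.
X = [["img_a", "img_b"], ["img_c", "img_d"]]
gt_pseudo_label = [[1, 2], [3, 4]]  # only used for evaluation; may be None
Y = [3, 7]                          # reasoning result for each sublist in X

data_examples = to_toy_list_data((X, gt_pseudo_label, Y))
print(len(data_examples))  # 2
```

Passing ``None`` as the second tuple element also works in this sketch, mirroring the note that ``gt_pseudo_label`` is optional.
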
``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides ``model`` and ``reasoner``, ``SimpleBridge`` takes an extra initialization argument, ``metric_list``, which is used to evaluate model performance. Its training process runs several Abductive Learning loops, and each loop consists of the following five steps:
1. Predict class probabilities and indices for the given data examples.
2. Transform indices into pseudo-labels.
3. Revise pseudo-labels based on abductive reasoning.
4. Transform the revised pseudo-labels to indices.
5. Train the model.
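
To make the five steps concrete, here is a self-contained toy sketch that does not use ABLkit at all. Everything in it is an assumption made for illustration: ``ToyModel`` simply memorizes one class index per input symbol, the knowledge base is "the digits in an example sum to ``Y``", and ``abduce`` picks the consistent labeling that differs least from the prediction. The real learning and reasoning parts are far richer.

```python
from itertools import product

# Hypothetical mapping: class index -> pseudo-label (a digit), and its inverse.
IDX_TO_LABEL = {0: 1, 1: 2, 2: 3}
LABEL_TO_IDX = {label: idx for idx, label in IDX_TO_LABEL.items()}


class ToyModel:
    """Stand-in for the learning part: remembers one class index per symbol."""

    def __init__(self):
        self.table = {}

    def predict(self, xs):
        # Unseen symbols default to class index 0.
        return [self.table.get(x, 0) for x in xs]

    def train(self, xs, idxs):
        for x, idx in zip(xs, idxs):
            self.table[x] = idx


def abduce(pred_labels, y):
    """Return the labeling that sums to y and differs least from pred_labels.

    Ties are broken by enumeration order. Raises ValueError if no labeling
    is consistent with y.
    """
    candidates = [
        c
        for c in product(IDX_TO_LABEL.values(), repeat=len(pred_labels))
        if sum(c) == y
    ]
    best = min(candidates, key=lambda c: sum(a != b for a, b in zip(c, pred_labels)))
    return list(best)


def train_one_loop(model, X, Y):
    for xs, y in zip(X, Y):
        pred_idx = model.predict(xs)                              # 1
        pred_labels = [IDX_TO_LABEL[i] for i in pred_idx]         # 2
        revised = abduce(pred_labels, y)                          # 3
        revised_idx = [LABEL_TO_IDX[label] for label in revised]  # 4
        model.train(xs, revised_idx)                              # 5


# Hidden truth in this toy task: a=1, b=2, c=3.
X = [["a", "a"], ["a", "b"], ["b", "c"]]
Y = [2, 3, 5]

model = ToyModel()
for _ in range(3):
    train_one_loop(model, X, Y)

print({x: IDX_TO_LABEL[i] for x, i in sorted(model.table.items())})
# {'a': 1, 'b': 2, 'c': 3}
```

Note that ``gt_pseudo_label`` never appears in the loop: the model is supervised only by labels revised to agree with ``Y``.
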
The fundamental part of the ``train`` method is as follows:

.. code-block:: python

    def train(self, train_data, loops=50, segment_size=10000):
        """
        Parameters
        ----------
        train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
            Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
            object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.
            - ``X`` is a list of sublists representing the input data.
            - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but not
              to train. ``gt_pseudo_label`` can be ``None``.
            - ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``.
        loops : int
            Learning part and Reasoning part will be iteratively optimized for ``loops`` times.
        segment_size : Union[int, float]
            Data will be split into segments of this size and data in each segment
            will be used together to train the model.
        """
        if isinstance(train_data, ListData):
            data_examples = train_data
        else:
            data_examples = self.data_preprocess(*train_data)

        if isinstance(segment_size, float):
            segment_size = int(segment_size * len(data_examples))

        for loop in range(loops):
            for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
                sub_data_examples = data_examples[
                    seg_idx * segment_size : (seg_idx + 1) * segment_size
                ]
                self.predict(sub_data_examples)                  # 1
                self.idx_to_pseudo_label(sub_data_examples)      # 2
                self.abduce_pseudo_label(sub_data_examples)      # 3
                self.pseudo_label_to_idx(sub_data_examples)      # 4
                loss = self.model.train(sub_data_examples)       # 5, self.model is an ABLModel object
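
The segment arithmetic in the loop above can be checked in isolation. The helper below, ``segment_bounds``, is a hypothetical function written for this illustration (it is not part of ABLkit); it reproduces the same ceiling division and the float-to-size conversion and returns the slice boundaries each segment would cover.

```python
def segment_bounds(num_examples, segment_size):
    """Return (start, stop) pairs using the same arithmetic as the loop above."""
    if isinstance(segment_size, float):
        # A float is treated as a fraction of the dataset size.
        segment_size = int(segment_size * num_examples)
    # Ceiling division: the last segment may be shorter than the others.
    num_segments = (num_examples - 1) // segment_size + 1
    return [
        (seg_idx * segment_size, min((seg_idx + 1) * segment_size, num_examples))
        for seg_idx in range(num_segments)
    ]


print(segment_bounds(25, 10))   # [(0, 10), (10, 20), (20, 25)]
print(segment_bounds(25, 0.5))  # [(0, 12), (12, 24), (24, 25)]
```

In this sketch, a fraction small enough that ``int(segment_size * num_examples)`` becomes ``0`` raises ``ZeroDivisionError``, so a float ``segment_size`` should be large enough to cover at least one example.
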