`Learn the Basics <Basics.html>`_ ||
`Quick Start <Quick-Start.html>`_ ||
`Dataset & Data Structure <Datasets.html>`_ ||
`Learning Part <Learning.html>`_ ||
`Reasoning Part <Reasoning.html>`_ ||
`Evaluation Metrics <Evaluation.html>`_ ||
**Bridge**

Bridge
======

In this section, we look at how to bridge the learning and reasoning parts to train the model, which is the fundamental idea of Abductive Learning. ABLkit implements a set of bridge classes for this purpose.

.. code:: python

   from ablkit.bridge import BaseBridge, SimpleBridge

``BaseBridge`` is an abstract class with the following initialization parameters:

- ``model`` is an object of type ``ABLModel``. The learning part is wrapped in this object.
- ``reasoner`` is an object of type ``Reasoner``. The reasoning part is wrapped in this object.

``BaseBridge`` has the following important methods, which must be overridden in subclasses:

+---------------------------------------+----------------------------------------------------+
| Method Signature                      | Description                                        |
+=======================================+====================================================+
| ``predict(data_examples)``            | Predicts class probabilities and indices           |
|                                       | for the given data examples.                       |
+---------------------------------------+----------------------------------------------------+
| ``abduce_pseudo_label(data_examples)``| Abduces pseudo-labels for the given data examples. |
+---------------------------------------+----------------------------------------------------+
| ``idx_to_pseudo_label(data_examples)``| Converts indices to pseudo-labels using            |
|                                       | the provided or default mapping.                   |
+---------------------------------------+----------------------------------------------------+
| ``pseudo_label_to_idx(data_examples)``| Converts pseudo-labels to indices                  |
|                                       | using the provided or default remapping.           |
+---------------------------------------+----------------------------------------------------+
| ``train(train_data)``                 | Trains the model.                                  |
+---------------------------------------+----------------------------------------------------+
| ``test(test_data)``                   | Tests the model.                                   |
+---------------------------------------+----------------------------------------------------+

Here, ``train_data`` and ``test_data`` can each be either a tuple or a `ListData <../API/ablkit.data.html#structures.ListData>`_. In either form, they must contain three components: ``X``, ``gt_pseudo_label``, and ``Y``. Since ``ListData`` is the underlying data structure used throughout ABLkit, tuple-formed data is first converted into ``ListData`` in the ``train`` and ``test`` methods; such ``ListData`` instances are referred to as ``data_examples``. More details can be found in `preparing datasets <Datasets.html>`_.
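
The contract described above can be mimicked in plain Python to see how the pieces fit. The sketch below is hypothetical and is not ABLkit's implementation: ``ToyListData`` is a simple dataclass standing in for ``ListData``, ``ToyBaseBridge`` only mirrors the idea of an abstract bridge that wraps a model and a reasoner, and its ``data_preprocess`` and ``to_data_examples`` helpers are illustrative assumptions showing how tuple-formed data could be normalized into a single container.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union


@dataclass
class ToyListData:
    """Hypothetical stand-in for ABLkit's ListData; only holds the three components."""

    X: List[List[Any]]
    gt_pseudo_label: Optional[List[List[Any]]]
    Y: List[Any]

    def __len__(self) -> int:
        return len(self.X)


# Either a ready-made container or an (X, gt_pseudo_label, Y) tuple.
RawData = Union[ToyListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]


class ToyBaseBridge(ABC):
    """Hypothetical mirror of the BaseBridge contract, not ABLkit's class."""

    def __init__(self, model: Any, reasoner: Any) -> None:
        self.model = model        # learning part (an ABLModel in ABLkit)
        self.reasoner = reasoner  # reasoning part (a Reasoner in ABLkit)

    @staticmethod
    def data_preprocess(X, gt_pseudo_label, Y) -> ToyListData:
        # Bundle tuple-formed data into one container; gt_pseudo_label may be None.
        if len(X) != len(Y):
            raise ValueError("X and Y must have the same length")
        if gt_pseudo_label is not None and len(gt_pseudo_label) != len(X):
            raise ValueError("gt_pseudo_label must have the same length as X")
        return ToyListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)

    def to_data_examples(self, data: RawData) -> ToyListData:
        # Containers pass through unchanged; tuples are unpacked and converted.
        if isinstance(data, ToyListData):
            return data
        return self.data_preprocess(*data)

    # Subclasses must override every method below.
    @abstractmethod
    def predict(self, data_examples: ToyListData) -> None: ...

    @abstractmethod
    def abduce_pseudo_label(self, data_examples: ToyListData) -> None: ...

    @abstractmethod
    def idx_to_pseudo_label(self, data_examples: ToyListData) -> None: ...

    @abstractmethod
    def pseudo_label_to_idx(self, data_examples: ToyListData) -> None: ...

    @abstractmethod
    def train(self, train_data: RawData) -> None: ...

    @abstractmethod
    def test(self, test_data: RawData) -> None: ...
```

Because every listed method is abstract, instantiating ``ToyBaseBridge`` directly raises ``TypeError``; only a subclass that implements all six methods can be created, which is the same division of labour the table describes.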

``SimpleBridge`` inherits from ``BaseBridge`` and provides a basic implementation. Besides ``model`` and ``reasoner``, ``SimpleBridge`` takes an extra initialization argument, ``metric_list``, which is used to evaluate model performance. Its training process consists of several Abductive Learning loops, and each loop comprises the following five steps:

1. Predict class probabilities and indices for the given data examples.
2. Transform indices into pseudo-labels.
3. Revise pseudo-labels based on abductive reasoning.
4. Transform the revised pseudo-labels into indices.
5. Train the model.

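
To make the five steps concrete, the following self-contained sketch runs them on a toy task. Everything in it is hypothetical and independent of ABLkit: ``ToyModel`` is a count-based classifier standing in for ``ABLModel``, ``abduce`` is a brute-force reasoner whose knowledge base states that the pseudo-labels of an example sum to ``Y``, and the rule of preferring the fewest changed labels (ties broken by the model's joint probability) is a choice made for this sketch only. Unlike ``SimpleBridge``, it treats all data as a single segment and computes no metrics.

```python
from itertools import product
from typing import List, Tuple

LABELS = [0, 1, 2]                                  # pseudo-labels known to the reasoner
IDX_TO_LABEL = dict(enumerate(LABELS))              # class index -> pseudo-label
LABEL_TO_IDX = {lab: i for i, lab in IDX_TO_LABEL.items()}


class ToyModel:
    """Hypothetical count-based classifier over raw symbols (stands in for ABLModel)."""

    def __init__(self, symbols: List[str]) -> None:
        self.counts = {s: [1.0] * len(LABELS) for s in symbols}

    def predict(self, x: str) -> Tuple[List[float], int]:
        row = self.counts[x]
        total = sum(row)
        prob = [c / total for c in row]
        return prob, max(range(len(prob)), key=prob.__getitem__)

    def train(self, X: List[List[str]], idx: List[List[int]]) -> None:
        # Reinforce each symbol's count for the index it was (re)labelled with.
        for xs, ids in zip(X, idx):
            for x, i in zip(xs, ids):
                self.counts[x][i] += 1.0


def abduce(pseudo: List[int], y: int, prob: List[List[float]]) -> List[int]:
    """Hypothetical reasoner: the knowledge base says the pseudo-labels sum to y.

    Returns the consistent candidate with the fewest changed labels, breaking ties by
    the model's joint probability; returns pseudo unchanged if nothing is consistent.
    """
    best, best_key = list(pseudo), None
    for cand in product(LABELS, repeat=len(pseudo)):
        if sum(cand) != y:
            continue
        changes = sum(a != b for a, b in zip(cand, pseudo))
        joint = 1.0
        for p, lab in zip(prob, cand):
            joint *= p[LABEL_TO_IDX[lab]]
        key = (changes, -joint)
        if best_key is None or key < best_key:
            best, best_key = list(cand), key
    return best


def abl_loop(model: ToyModel, X: List[List[str]], Y: List[int]) -> List[List[int]]:
    """One loop over all examples, following the five steps."""
    revised_all, idx_all = [], []
    for xs, y in zip(X, Y):
        preds = [model.predict(x) for x in xs]                  # 1. probabilities, indices
        prob = [p for p, _ in preds]
        pseudo = [IDX_TO_LABEL[i] for _, i in preds]             # 2. indices -> pseudo-labels
        revised = abduce(pseudo, y, prob)                        # 3. abductive revision
        idx_all.append([LABEL_TO_IDX[lab] for lab in revised])   # 4. pseudo-labels -> indices
        revised_all.append(revised)
    model.train(X, idx_all)                                      # 5. train on revised indices
    return revised_all


# Toy data: symbols "a", "b", "c" secretly mean 0, 1, 2; Y is the sum of each pair.
X = [["a", "b"], ["b", "c"], ["c", "c"], ["a", "a"], ["b", "b"]]
Y = [1, 3, 4, 0, 2]
model = ToyModel(["a", "b", "c"])
for _ in range(5):
    revised = abl_loop(model, X, Y)
```

The model never sees the true meaning of the symbols; it is trained only on pseudo-labels that the reasoner made consistent with ``Y``, which is the feedback cycle that ``SimpleBridge`` automates.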
The core of the ``train`` method is as follows:

.. code-block:: python

   def train(self, train_data, loops=50, segment_size=10000):
       """
       Parameters
       ----------
       train_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], List[Any]]]
           Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
           object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.
           - ``X`` is a list of sublists representing the input data.
           - ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel``,
             not to train it. ``gt_pseudo_label`` can be ``None``.
           - ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``.
       loops : int
           The learning part and the reasoning part will be iteratively optimized for ``loops`` loops.
       segment_size : Union[int, float]
           Data will be split into segments of this size, and the data in each segment
           will be used together to train the model. A float is interpreted as a
           fraction of the data.
       """
       # Convert tuple-formed data into ListData if necessary
       if isinstance(train_data, ListData):
           data_examples = train_data
       else:
           data_examples = self.data_preprocess(*train_data)

       # A float segment_size is a fraction of the total number of examples
       if isinstance(segment_size, float):
           segment_size = int(segment_size * len(data_examples))

       for loop in range(loops):
           # Iterate over ceil(len(data_examples) / segment_size) segments
           for seg_idx in range((len(data_examples) - 1) // segment_size + 1):
               sub_data_examples = data_examples[
                   seg_idx * segment_size : (seg_idx + 1) * segment_size
               ]
               self.predict(sub_data_examples)  # 1. predict probabilities and indices
               self.idx_to_pseudo_label(sub_data_examples)  # 2. indices -> pseudo-labels
               self.abduce_pseudo_label(sub_data_examples)  # 3. revise pseudo-labels by abduction
               self.pseudo_label_to_idx(sub_data_examples)  # 4. pseudo-labels -> indices
               loss = self.model.train(sub_data_examples)  # 5. train; self.model is an ABLModel

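
The segment arithmetic in the loop above is easy to misread, so the standalone helper below reproduces it. ``segment_bounds`` is a hypothetical name, not part of ABLkit: a float ``segment_size`` is treated as a fraction of the data, and ``(n - 1) // segment_size + 1`` is ceiling division, so the last segment may be shorter than the others. The ``ValueError`` guard is an addition for this sketch; without it, a float that rounds down to ``0`` would cause a division by zero.

```python
from typing import List, Tuple, Union


def segment_bounds(n: int, segment_size: Union[int, float]) -> List[Tuple[int, int]]:
    """Return (start, stop) bounds of the segments that the train loop slices out."""
    if isinstance(segment_size, float):
        segment_size = int(segment_size * n)  # a float is a fraction of the data
    if segment_size < 1:
        # Guard added for this sketch only; a size of 0 would divide by zero below.
        raise ValueError("segment_size must resolve to at least one example")
    num_segments = (n - 1) // segment_size + 1  # equals ceil(n / segment_size) for n >= 0
    # Python slicing clips stop to n automatically; min() just makes it explicit here.
    return [(i * segment_size, min((i + 1) * segment_size, n)) for i in range(num_segments)]


# 25 examples in segments of 10: the last segment holds only 5 examples.
print(segment_bounds(25, 10))  # [(0, 10), (10, 20), (20, 25)]
```

With the default ``segment_size=10000``, any dataset of at most 10000 examples is processed as a single segment in each loop.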

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.