MNIST Addition
==============

This example shows a simple implementation of MNIST Addition, which was
first introduced in `Manhaeve et al., 2018 <https://arxiv.org/abs/1805.10872>`__.
In this task, the inputs are pairs of MNIST handwritten images, and the
outputs are their sums. In Abductive Learning, we first use the learning
part to map the input images to their digits (which we call pseudo
labels), and then use the reasoning part to calculate the sum of these
pseudo labels to obtain the final result.
.. code:: ipython3

    import os.path as osp

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt

    from abl.bridge import SimpleBridge
    from abl.evaluation import ReasoningMetric, SymbolMetric
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.utils import ABLLogger, print_log

    from examples.mnist_add.datasets import get_mnist_add
    from examples.models.nn import LeNet5
Load Datasets
-------------

First, we get the training and testing data:

.. code:: ipython3

    train_data = get_mnist_add(train=True, get_pseudo_label=True)
    test_data = get_mnist_add(train=False, get_pseudo_label=True)
The datasets are illustrated as follows:

.. code:: ipython3

    print(f"There are {len(train_data[0])} data examples in the training set and {len(test_data[0])} data examples in the test set")
    print(f"Each data example has {len(train_data)} components: X, gt_pseudo_label, and Y.")
    print("As an illustration, in the first data example of the training set, we have:")
    print(f"X ({len(train_data[0][0])} images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(train_data[0][0][0].numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(train_data[0][0][1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label ({len(train_data[1][0])} ground truth pseudo labels): {train_data[1][0][0]}, {train_data[1][0][1]}")
    print(f"Y (their sum): {train_data[2][0]}")
Out:

.. code:: none
    :class: code-out

    There are 30000 data examples in the training set and 5000 data examples in the test set
    Each data example has 3 components: X, gt_pseudo_label, and Y.
    As an illustration, in the first data example of the training set, we have:
    X (2 images):

.. image:: ../img/mnist_add_datasets.png
    :width: 400px

.. code:: none
    :class: code-out

    gt_pseudo_label (2 ground truth pseudo labels): 7, 5
    Y (their sum): 12
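
To make this layout concrete, a dataset with the same structure can be
built by hand. Below is a minimal sketch of a toy dataset in the
``(X, gt_pseudo_label, Y)`` format; the random tensors are hypothetical
stand-ins for real MNIST images:

.. code:: ipython3

    # Toy dataset in the same (X, gt_pseudo_label, Y) format. The random
    # tensors are hypothetical stand-ins for real 1x28x28 MNIST images.
    X = [[torch.randn(1, 28, 28), torch.randn(1, 28, 28)] for _ in range(4)]
    gt_pseudo_label = [[7, 5], [1, 2], [3, 3], [0, 9]]
    Y = [sum(pair) for pair in gt_pseudo_label]
    toy_data = (X, gt_pseudo_label, Y)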
Learning Part
-------------

First, we build the basic learning model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__ to complete this task.
.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001, betas=(0.9, 0.99))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device,
        batch_size=32,
        num_epochs=1,
    )
The base model can predict the class index and the class probabilities
for an image, as shown below:
.. code:: ipython3

    pred_idx = base_model.predict(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_idx for a batch of 32 samples: {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=[torch.randn(1, 28, 28).to(device) for _ in range(32)])
    print(f"Shape of pred_prob for a batch of 32 samples: {pred_prob.shape}")

Out:

.. code:: none
    :class: code-out

    Shape of pred_idx for a batch of 32 samples: (32,)
    Shape of pred_prob for a batch of 32 samples: (32, 10)
Then, we build an instance of ``ABLModel``. The main function of
``ABLModel`` is to serialize data and provide a unified interface for
different base machine learning models.

.. code:: ipython3

    model = ABLModel(base_model)
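
Since ``ABLModel`` abstracts over the base model, other learners can in
principle be plugged in as well. The following is only a hypothetical
sketch, assuming ``ABLModel`` accepts any model that exposes scikit-learn
style ``fit``/``predict`` methods; consult the API documentation for the
exact requirements.

.. code:: ipython3

    # Hypothetical sketch: wrapping a scikit-learn classifier. This assumes
    # ABLModel only requires fit/predict(_proba)-style methods.
    from sklearn.ensemble import RandomForestClassifier

    rf = RandomForestClassifier(n_estimators=100)
    sklearn_model = ABLModel(rf)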
Logic Part
----------

In the logic part, we first build a knowledge base.

.. code:: ipython3

    # Build the knowledge base
    class AddKB(KBBase):
        def __init__(self, pseudo_label_list):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB(pseudo_label_list=list(range(10)))
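
Note that only ``logic_forward`` encodes the task. As a sketch, a
hypothetical "MNIST Multiplication" knowledge base would differ in a
single method:

.. code:: ipython3

    # Hypothetical variant: the same KB pattern for a multiplication task.
    from math import prod

    class MultKB(KBBase):
        def __init__(self, pseudo_label_list):
            super().__init__(pseudo_label_list)

        def logic_forward(self, nums):
            return prod(nums)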
The knowledge base can perform logical reasoning. Below is an example of
performing (deductive) reasoning:

.. code:: ipython3

    pseudo_label_sample = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_label_sample)
    print(f"Reasoning result of pseudo label sample {pseudo_label_sample} is {reasoning_result}.")

Out:

.. code:: none
    :class: code-out

    Reasoning result of pseudo label sample [1, 2] is 3.
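
Abductive Learning also relies on the reverse direction: given a final
result, infer which pseudo labels could have produced it. The brute-force
sketch below illustrates this abduction (it assumes ``KBBase`` stores
``pseudo_label_list`` as an attribute):

.. code:: ipython3

    # Brute-force abduction sketch: enumerate all pseudo-label pairs whose
    # deduced result equals y. Assumes kb.pseudo_label_list is available.
    from itertools import product

    def abduce_candidates(kb, y, num_digits=2):
        return [list(c)
                for c in product(kb.pseudo_label_list, repeat=num_digits)
                if kb.logic_forward(list(c)) == y]

    print(abduce_candidates(kb, y=3))  # [[0, 3], [1, 2], [2, 1], [3, 0]]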
Then, we create a reasoner. It can help minimize inconsistencies between
the knowledge base and the pseudo labels predicted by the learning part.

.. code:: ipython3

    reasoner = Reasoner(kb, dist_func="confidence")
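
With ``dist_func="confidence"``, the distance between a candidate and the
model prediction is measured on the predicted probabilities. The
simplified sketch below, which is not the library's actual
implementation, conveys the idea: among the abduced candidates (here
obtained with ``abduce_candidates`` from the sketch above), select the
one the model is jointly most confident in.

.. code:: ipython3

    # Simplified illustration of confidence-based candidate selection;
    # not the library's actual implementation.
    import numpy as np

    def select_candidate(pred_prob, candidates):
        # pred_prob: array of shape (num_digits, num_classes)
        scores = [np.prod([pred_prob[i][label] for i, label in enumerate(cand)])
                  for cand in candidates]
        return candidates[int(np.argmax(scores))]

    probs = np.full((2, 10), 0.05)
    probs[0, 1], probs[1, 2] = 0.55, 0.55  # the model leans towards [1, 2]
    print(select_candidate(probs, abduce_candidates(kb, y=3)))  # [1, 2]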
Evaluation Metrics
------------------

Set up evaluation metrics. These metrics will be used to evaluate the
model performance during training and testing.

.. code:: ipython3

    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
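
Here, ``SymbolMetric`` measures the per-digit accuracy of the predicted
pseudo labels (reported below as ``character_accuracy``), while
``ReasoningMetric`` measures how often the result deduced from them
matches ``Y`` (reported as ``reasoning_accuracy``). A hand-worked sketch
on hypothetical predictions:

.. code:: ipython3

    # Hand-worked sketch of what the two metrics compute, on hypothetical
    # predictions (the real classes additionally handle batching/prefixes).
    gt_labels = [[7, 5], [1, 2]]
    pred_labels = [[7, 3], [1, 2]]
    Y_true = [12, 3]

    # Per-digit (character) accuracy of the pseudo labels: 3 of 4 correct.
    symbol_acc = sum(
        p == g for ps, gs in zip(pred_labels, gt_labels) for p, g in zip(ps, gs)
    ) / sum(len(gs) for gs in gt_labels)

    # Accuracy of the deduced results against Y: 1 of 2 correct.
    reasoning_acc = sum(
        kb.logic_forward(ps) == y for ps, y in zip(pred_labels, Y_true)
    ) / len(Y_true)

    print(symbol_acc, reasoning_acc)  # 0.75 0.5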
Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)
Perform training and testing. Here, ``loops=5`` runs five cycles of the
abduction and retraining process, ``segment_size=1/3`` updates the model
three times per loop (each time on one third of the training data), and
``save_interval=1`` saves the model weights after every loop.
.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")

    # Retrieve the directory of the log file and define the directory for
    # saving the model weights
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)
Out:

.. code:: none
    :class: code-out

    abl - INFO - Abductive Learning on the MNIST Addition example.
    abl - INFO - loop(train) [1/5] segment(train) [1/3]
    abl - INFO - model loss: 1.81231
    abl - INFO - loop(train) [1/5] segment(train) [2/3]
    abl - INFO - model loss: 1.37639
    abl - INFO - loop(train) [1/5] segment(train) [3/3]
    abl - INFO - model loss: 1.14446
    abl - INFO - Evaluation start: loop(val) [1]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.207 mnist_add/reasoning_accuracy: 0.245
    abl - INFO - Saving model: loop(save) [1]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_1.pth
    abl - INFO - loop(train) [2/5] segment(train) [1/3]
    abl - INFO - model loss: 0.97430
    abl - INFO - loop(train) [2/5] segment(train) [2/3]
    abl - INFO - model loss: 0.91448
    abl - INFO - loop(train) [2/5] segment(train) [3/3]
    abl - INFO - model loss: 0.83089
    abl - INFO - Evaluation start: loop(val) [2]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.191 mnist_add/reasoning_accuracy: 0.353
    abl - INFO - Saving model: loop(save) [2]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_2.pth
    abl - INFO - loop(train) [3/5] segment(train) [1/3]
    abl - INFO - model loss: 0.79906
    abl - INFO - loop(train) [3/5] segment(train) [2/3]
    abl - INFO - model loss: 0.77949
    abl - INFO - loop(train) [3/5] segment(train) [3/3]
    abl - INFO - model loss: 0.75007
    abl - INFO - Evaluation start: loop(val) [3]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.148 mnist_add/reasoning_accuracy: 0.385
    abl - INFO - Saving model: loop(save) [3]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_3.pth
    abl - INFO - loop(train) [4/5] segment(train) [1/3]
    abl - INFO - model loss: 0.72659
    abl - INFO - loop(train) [4/5] segment(train) [2/3]
    abl - INFO - model loss: 0.70985
    abl - INFO - loop(train) [4/5] segment(train) [3/3]
    abl - INFO - model loss: 0.66337
    abl - INFO - Evaluation start: loop(val) [4]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.016 mnist_add/reasoning_accuracy: 0.494
    abl - INFO - Saving model: loop(save) [4]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_4.pth
    abl - INFO - loop(train) [5/5] segment(train) [1/3]
    abl - INFO - model loss: 0.61140
    abl - INFO - loop(train) [5/5] segment(train) [2/3]
    abl - INFO - model loss: 0.57534
    abl - INFO - loop(train) [5/5] segment(train) [3/3]
    abl - INFO - model loss: 0.57018
    abl - INFO - Evaluation start: loop(val) [5]
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.507
    abl - INFO - Saving model: loop(save) [5]
    abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_5.pth
    abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.002 mnist_add/reasoning_accuracy: 0.482
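
The saved checkpoints can be reloaded afterwards. The snippet below is
only a hypothetical sketch that assumes each ``.pth`` file stores the
network's ``state_dict``; adjust it to the actual checkpoint format.

.. code:: ipython3

    # Hypothetical sketch for reloading saved weights; assumes the .pth
    # checkpoint stores the network's state_dict.
    checkpoint_path = osp.join(weights_dir, "model_checkpoint_loop_5.pth")
    cls.load_state_dict(torch.load(checkpoint_path))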
More concrete examples are available in the ``examples/mnist_add`` folder.
