MNIST Addition
==============

This notebook shows an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits of handwritten images
and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have ground-truth labels for the digits, in
abductive learning, the reasoning part will leverage domain knowledge
and revise the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.
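
The revision step described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the ABL-Package implementation: the helper name ``revise_pseudo_labels`` is hypothetical, candidates are enumerated by brute force, and "closest" is measured by the number of changed positions (Hamming distance), with ties broken by enumeration order.

```python
from itertools import product


def revise_pseudo_labels(predicted, target_sum, digits=range(10)):
    """Return the digit tuple that sums to target_sum and changes the
    fewest positions of `predicted`; None if no tuple is consistent."""
    # Abduction by enumeration: keep every digit tuple consistent with the sum.
    candidates = [
        c for c in product(digits, repeat=len(predicted))
        if sum(c) == target_sum
    ]
    if not candidates:
        return None
    # Prefer the candidate that disagrees with the prediction least often.
    return min(
        candidates,
        key=lambda c: sum(a != b for a, b in zip(c, predicted)),
    )


predicted = (7, 3)          # hypothetical output of the learning part
revised = revise_pseudo_labels(predicted, target_sum=12)
print(revised)              # (7, 5)
```

Here the prediction ``(7, 3)`` contradicts the given sum 12, so it is revised to ``(7, 5)``. The candidate ``(9, 3)`` also differs in only one position, which is why a finer notion of consistency, such as prediction confidence, is useful in practice.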

.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from examples.mnist_add.datasets import get_dataset
    from examples.models.nn import LeNet5
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.evaluation import ReasoningMetric, SymbolMetric
    from abl.utils import ABLLogger, print_log
    from abl.bridge import SimpleBridge

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

Both ``train_data`` and ``test_data`` have the same structure: a tuple
with three components: X (a list where each element is a pair of images),
gt_pseudo_label (a list where each element is a pair of digits) and Y
(a list where each element is the sum of the corresponding digit pair).
The length and structure of the datasets are illustrated as follows.

.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part but not to train the model.

.. code:: ipython3

    def describe_structure(lst):
        if not isinstance(lst, list):
            return type(lst).__name__
        return [describe_structure(item) for item in lst]

    print(f"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print("\n")
    train_X, train_gt_pseudo_label, train_Y = train_data
    print(f"Length of X, gt_pseudo_label, Y in train_data: {len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print(f"Length of X, gt_pseudo_label, Y in test_data: {len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print("\n")

    structure_X = describe_structure(train_X[0])
    structure_gt_pseudo_label = describe_structure(train_gt_pseudo_label[0])
    structure_Y = describe_structure(train_Y[0])
    print(f"Type of X: {type(train_X).__name__}, and type of each element in X: {structure_X}.")
    print(f"Type of gt_pseudo_label: {type(train_gt_pseudo_label).__name__}, and type of each element in gt_pseudo_label: {structure_gt_pseudo_label}.")
    print(f"Type of Y: {type(train_Y).__name__}, and type of each element in Y: {structure_Y}.")

Out:
    .. code:: none
        :class: code-out

        Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y
        Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
        Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000
        Type of X: list, and type of each element in X: ['Tensor', 'Tensor'].
        Type of gt_pseudo_label: list, and type of each element in gt_pseudo_label: ['int', 'int'].
        Type of Y: list, and type of each element in Y: int.

The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. As an illustration, in the first data example of the
training set, we have:

.. code:: ipython3

    first_X, first_gt_pseudo_label, first_Y = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X in the first data example (a pair of images):")
    plt.subplot(1,2,1)
    plt.axis('off')
    plt.imshow(first_X[0].numpy().transpose(1, 2, 0))
    plt.subplot(1,2,2)
    plt.axis('off')
    plt.imshow(first_X[1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label in the first data example (a pair of ground truth pseudo-labels): {first_gt_pseudo_label[0]}, {first_gt_pseudo_label[1]}")
    print(f"Y in the first data example (their sum result): {first_Y}")

Out:
    .. code:: none
        :class: code-out

        X in the first data example (a pair of images):

    .. image:: ../img/mnist_add_datasets.png
        :width: 400px

    .. parsed-literal::

        gt_pseudo_label in the first data example (a pair of ground truth pseudo-labels): 7, 5
        Y in the first data example (their sum result): 12
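
To make the three-part structure concrete, here is a minimal sketch of how such a dataset could be assembled from labeled single-digit instances. It is an assumption-laden illustration, not a description of what ``get_dataset`` does: the function ``make_addition_examples`` is hypothetical, and strings stand in for image tensors.

```python
import random


def make_addition_examples(labeled_instances, num_examples, seed=0):
    """Pair up (image, digit) instances into X, gt_pseudo_label and Y lists."""
    rng = random.Random(seed)
    X, gt_pseudo_label, Y = [], [], []
    for _ in range(num_examples):
        # Draw two distinct labeled instances to form one data example.
        (img_a, digit_a), (img_b, digit_b) = rng.sample(labeled_instances, 2)
        X.append([img_a, img_b])
        gt_pseudo_label.append([digit_a, digit_b])
        Y.append(digit_a + digit_b)
    return X, gt_pseudo_label, Y


# Stand-in "images": strings instead of tensors, labeled with digits 0..9.
toy_instances = [(f"img_{i}", i % 10) for i in range(20)]
X, gt_pseudo_label, Y = make_addition_examples(toy_instances, num_examples=5)
print(len(X), len(gt_pseudo_label), len(Y))
```

Each position ``i`` of the three lists describes the same data example, which mirrors the layout printed above.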

Building the Learning Part
--------------------------

To build the learning part, we need to first build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__, and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that encapsulates a PyTorch model, transforming it into a base
model with an sklearn-style interface.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cls.parameters(), lr=0.001)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        device,
        batch_size=32,
        num_epochs=1,
    )
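
For readers unfamiliar with the term, an "sklearn-style interface" roughly means that a model exposes ``fit``, ``predict`` and ``predict_proba`` methods. The toy class below sketches that contract in plain Python. It is a hypothetical stand-in (``ToyCentroidClassifier`` is not part of ABL-Package) and says nothing about how ``BasicNN`` is implemented internally.

```python
import math


class ToyCentroidClassifier:
    """A tiny 1-D classifier exposing fit / predict_proba / predict."""

    def fit(self, X, y):
        # Store the mean feature value (centroid) of each class.
        sums, counts = {}, {}
        for x, label in zip(X, y):
            sums[label] = sums.get(label, 0.0) + x
            counts[label] = counts.get(label, 0) + 1
        self.classes_ = sorted(sums)
        self.centroids_ = [sums[c] / counts[c] for c in self.classes_]
        return self

    def predict_proba(self, X):
        # Turn distances to the centroids into normalized scores per class.
        probs = []
        for x in X:
            scores = [math.exp(-abs(x - c)) for c in self.centroids_]
            total = sum(scores)
            probs.append([s / total for s in scores])
        return probs

    def predict(self, X):
        # The predicted class is the one with the highest probability.
        return [
            self.classes_[max(range(len(p)), key=p.__getitem__)]
            for p in self.predict_proba(X)
        ]


clf = ToyCentroidClassifier().fit([0.0, 0.2, 1.0, 1.2], [0, 0, 1, 1])
print(clf.predict([0.1, 0.9]))  # [0, 1]
```

The point of the shared interface is that downstream code can call ``predict`` or ``predict_proba`` without knowing which kind of model sits underneath.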

``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the probabilities of each class
for images, as shown below:

.. code:: ipython3

    data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")

Out:
    .. code:: none
        :class: code-out

        Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
        Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)

However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

As an illustration, consider this example of prediction on example-level
data using the ``predict`` method of ``ABLModel``. The method accepts
data examples as input and outputs the class labels and the
probabilities of each class for all instances within these data
examples.

.. code:: ipython3

    from abl.structures import ListData
    # ListData is a data structure provided by ABL-Package that can be used to organize data examples
    data_example = ListData()
    data_example.X = first_X
    data_example.gt_pseudo_label = first_gt_pseudo_label
    data_example.Y = first_Y

    # Perform prediction on the first data example
    prediction_result = model.predict(data_example)
    print(f"Predicted class labels for the first data example: np.array with shape {prediction_result['label'].shape}")
    print(f"Predicted class probabilities for the first data example: np.array with shape {prediction_result['prob'].shape}")

Out:
    .. code:: none
        :class: code-out

        Predicted class labels for the first data example: np.array with shape (2,)
        Predicted class probabilities for the first data example: np.array with shape (2, 10)
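
The gap between instance-level and example-level data can be bridged by flattening the examples, predicting once over all instances, and regrouping the results. The sketch below shows that idea in plain Python. It is an assumption about the general technique, not a copy of ``ABLModel``; ``predict_examples`` and ``toy_predict`` are hypothetical names.

```python
def predict_examples(examples, predict_instances):
    """Apply an instance-level predictor to a list of examples."""
    # Flatten: one long list of instances across all examples.
    flat = [instance for example in examples for instance in example]
    flat_labels = predict_instances(flat)
    # Regroup: slice the flat predictions back into per-example lists.
    grouped, start = [], 0
    for example in examples:
        end = start + len(example)
        grouped.append(flat_labels[start:end])
        start = end
    return grouped


def toy_predict(instances):
    # Stand-in predictor: the "image" is a string whose last character is the digit.
    return [int(name[-1]) for name in instances]


examples = [["img_of_7", "img_of_5"], ["img_of_1", "img_of_2"]]
print(predict_examples(examples, toy_predict))  # [[7, 5], [1, 2]]
```

Predicting on the flattened list lets the base model process instances in batches regardless of how many instances each example contains.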

Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base which contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter specifying the list of
possible pseudo-labels, and override the ``logic_forward`` function
defining how to perform (deductive) reasoning.

.. code:: ipython3

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning, and users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.

.. code:: ipython3

    pseudo_label_example = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_label_example)
    print(f"Reasoning result of pseudo-label example {pseudo_label_example} is {reasoning_result}.")

Out:
    .. code:: none
        :class: code-out

        Reasoning result of pseudo-label example [1, 2] is 3.

.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``main.py`` file. Those interested are encouraged to
    examine it for further insights.
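
The idea behind a ground knowledge base can be sketched as a lookup table computed once in advance: every possible pseudo-label combination is passed through the deduction rule, and the combinations are indexed by their reasoning result. The code below is an illustrative assumption of that idea only; ``build_ground_table`` is a hypothetical helper and does not reflect the actual ``GroundKB`` API or implementation.

```python
from collections import defaultdict
from itertools import product


def logic_forward(nums):
    # Same deduction rule as in the knowledge base above: add the digits.
    return sum(nums)


def build_ground_table(pseudo_label_list, example_length):
    """Map each reasoning result to all pseudo-label lists that produce it."""
    table = defaultdict(list)
    for candidate in product(pseudo_label_list, repeat=example_length):
        table[logic_forward(candidate)].append(list(candidate))
    return dict(table)


ground_table = build_ground_table(range(10), example_length=2)
print(ground_table[3])        # [[0, 3], [1, 2], [2, 1], [3, 0]]
print(len(ground_table[9]))   # 10
```

With such a table, abducing the candidates for a given sum becomes a dictionary lookup, at the cost of enumerating all combinations up front.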

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the indeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the
reasoner minimizes the inconsistency between the knowledge base and the
pseudo-labels predicted by the learning part, and returns only the
candidate that has the highest consistency.

.. code:: ipython3

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized through the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between the data example and candidates based on the
    confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
    provide options for utilizing other forms of consistency measurement.

    Also, during the process of inconsistency minimization, one can leverage the
    `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
    Options for this are also available in ``examples/mnist_add/main.py``. Those interested are
    encouraged to explore these features.
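
As a rough illustration of a confidence-based consistency measure, the sketch below scores each candidate by one minus the product of the probabilities that the learning part assigned to the candidate's labels, and keeps the candidate with the smallest score. This formula is an assumption chosen for illustration and may differ from what ``Reasoner`` computes internally; the helper names and probability values are hypothetical.

```python
import math


def confidence_distance(candidate, pred_prob):
    """1 - product of the predicted probabilities of the candidate's labels."""
    return 1.0 - math.prod(pred_prob[i][label] for i, label in enumerate(candidate))


def select_candidate(candidates, pred_prob):
    # The most consistent candidate is the one with the smallest distance.
    return min(candidates, key=lambda c: confidence_distance(c, pred_prob))


def prob_row(weights, num_classes=10):
    # Build a probability row over digits 0..9 from a sparse {digit: prob} dict.
    row = [0.0] * num_classes
    for digit, p in weights.items():
        row[digit] = p
    return row


# Hypothetical predicted probabilities for the two images of one example.
pred_prob = [prob_row({7: 0.9, 9: 0.1}), prob_row({3: 0.6, 5: 0.4})]
candidates = [(7, 5), (9, 3)]   # both sum to 12
best = select_candidate(candidates, pred_prob)
print(best)  # (7, 5)
```

Both candidates are compatible with the sum 12, but ``(7, 5)`` keeps the confidently predicted 7 and only replaces the less certain second digit, so it is preferred.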

Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolMetric`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolMetric(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]
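
The difference between the two kinds of accuracy can be shown with a small worked example. The functions below are hypothetical plain-Python stand-ins written for this illustration, not the ``SymbolMetric`` or ``ReasoningMetric`` classes; they assume symbol accuracy is counted per instance and reasoning accuracy per example.

```python
def character_accuracy(pred_pseudo_labels, gt_pseudo_labels):
    """Fraction of individual pseudo-labels that match the ground truth."""
    correct = total = 0
    for pred, gt in zip(pred_pseudo_labels, gt_pseudo_labels):
        for p, g in zip(pred, gt):
            correct += int(p == g)
            total += 1
    return correct / total


def reasoning_accuracy(pred_pseudo_labels, Y, logic_forward=sum):
    """Fraction of examples whose deduced result equals the given result."""
    hits = [logic_forward(pred) == y for pred, y in zip(pred_pseudo_labels, Y)]
    return sum(hits) / len(hits)


pred = [[7, 5], [1, 3], [4, 4]]   # hypothetical predictions
gt = [[7, 5], [1, 2], [3, 5]]     # ground-truth pseudo-labels
Y = [12, 3, 8]                    # ground-truth sums

print(character_accuracy(pred, gt))   # 0.5
print(reasoning_accuracy(pred, Y))
```

In the third example both digits are wrong, yet their sum is still 8, so the two metrics can disagree. This is one reason to track them separately.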

Bridge Learning and Reasoning
-----------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=5, segment_size=1/3, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:
    .. code:: none
        :class: code-out

        abl - INFO - Abductive Learning on the MNIST Addition example.
        abl - INFO - loop(train) [1/5] segment(train) [1/3]
        abl - INFO - model loss: 1.49104
        abl - INFO - loop(train) [1/5] segment(train) [2/3]
        abl - INFO - model loss: 1.24945
        abl - INFO - loop(train) [1/5] segment(train) [3/3]
        abl - INFO - model loss: 0.87861
        abl - INFO - Evaluation start: loop(val) [1]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.818 mnist_add/reasoning_accuracy: 0.672
        abl - INFO - Saving model: loop(save) [1]
        abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_1.pth
        abl - INFO - loop(train) [2/5] segment(train) [1/3]
        abl - INFO - model loss: 0.31148
        abl - INFO - loop(train) [2/5] segment(train) [2/3]
        abl - INFO - model loss: 0.09520
        abl - INFO - loop(train) [2/5] segment(train) [3/3]
        abl - INFO - model loss: 0.07402
        abl - INFO - Evaluation start: loop(val) [2]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.982 mnist_add/reasoning_accuracy: 0.964
        abl - INFO - Saving model: loop(save) [2]
        abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_2.pth
        abl - INFO - loop(train) [3/5] segment(train) [1/3]
        abl - INFO - model loss: 0.06027
        abl - INFO - loop(train) [3/5] segment(train) [2/3]
        abl - INFO - model loss: 0.05341
        abl - INFO - loop(train) [3/5] segment(train) [3/3]
        abl - INFO - model loss: 0.04915
        abl - INFO - Evaluation start: loop(val) [3]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.975
        abl - INFO - Saving model: loop(save) [3]
        abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_3.pth
        abl - INFO - loop(train) [4/5] segment(train) [1/3]
        abl - INFO - model loss: 0.04413
        abl - INFO - loop(train) [4/5] segment(train) [2/3]
        abl - INFO - model loss: 0.04181
        abl - INFO - loop(train) [4/5] segment(train) [3/3]
        abl - INFO - model loss: 0.04127
        abl - INFO - Evaluation start: loop(val) [4]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.990 mnist_add/reasoning_accuracy: 0.980
        abl - INFO - Saving model: loop(save) [4]
        abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_4.pth
        abl - INFO - loop(train) [5/5] segment(train) [1/3]
        abl - INFO - model loss: 0.03544
        abl - INFO - loop(train) [5/5] segment(train) [2/3]
        abl - INFO - model loss: 0.03092
        abl - INFO - loop(train) [5/5] segment(train) [3/3]
        abl - INFO - model loss: 0.03663
        abl - INFO - Evaluation start: loop(val) [5]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.982
        abl - INFO - Saving model: loop(save) [5]
        abl - INFO - Checkpoints will be saved to weights_dir/model_checkpoint_loop_5.pth
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.987 mnist_add/reasoning_accuracy: 0.974
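
The log shows three segments per loop, which is consistent with reading ``segment_size=1/3`` as a fraction of the 30000 training examples. The sketch below illustrates that reading in plain Python. It is an assumption made for illustration; ``iter_segments`` is a hypothetical helper and the actual segmentation logic in ``SimpleBridge`` may differ in its details.

```python
def iter_segments(num_examples, segment_size):
    """Yield (start, end) index pairs covering the dataset segment by segment.

    A segment_size strictly between 0 and 1 is treated as a fraction of the
    dataset; any other value is treated as an absolute number of examples.
    """
    if 0 < segment_size < 1:
        size = round(num_examples * segment_size)
    else:
        size = int(segment_size)
    size = max(size, 1)  # guard against empty segments
    for start in range(0, num_examples, size):
        yield start, min(start + size, num_examples)


segments = list(iter_segments(30000, 1 / 3))
print(segments)  # [(0, 10000), (10000, 20000), (20000, 30000)]
```

Under this reading, each loop abduces and trains on 10000 examples at a time, so the model is updated three times before every evaluation.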

More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.
