MNIST Addition
==============

This page shows an implementation of `MNIST
Addition <https://arxiv.org/abs/1805.10872>`__. In this task, pairs of
MNIST handwritten images and their sums are given, along with a domain
knowledge base containing information on how to perform addition
operations. The task is to recognize the digits in the handwritten images
and accurately determine their sum.

Intuitively, we first use a machine learning model (learning part) to
convert the input images to digits (we call them pseudo-labels), and
then use the knowledge base (reasoning part) to calculate the sum of
these digits. Since we do not have ground-truth labels for the digits, in
Abductive Learning, the reasoning part leverages domain knowledge
and revises the initial digits yielded by the learning part through
abductive reasoning. This process enables us to further update the
machine learning model.

.. code:: ipython3

    # Import necessary libraries and modules
    import os.path as osp

    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from torch.optim import RMSprop, lr_scheduler

    from examples.mnist_add.datasets import get_dataset
    from examples.models.nn import LeNet5
    from abl.learning import ABLModel, BasicNN
    from abl.reasoning import KBBase, Reasoner
    from abl.data.evaluation import ReasoningMetric, SymbolAccuracy
    from abl.utils import ABLLogger, print_log
    from abl.bridge import SimpleBridge

Working with Data
-----------------

First, we get the training and testing datasets:

.. code:: ipython3

    train_data = get_dataset(train=True, get_pseudo_label=True)
    test_data = get_dataset(train=False, get_pseudo_label=True)

``train_data`` and ``test_data`` share an identical structure: each is a
tuple with three components: ``X`` (a list where each element is a
list of two images), ``gt_pseudo_label`` (a list where each element
is a list of two digits, i.e., pseudo-labels) and ``Y`` (a list where
each element is the sum of the two digits). The length and structure
of the datasets are illustrated as follows.

.. note::

    ``gt_pseudo_label`` is only used to evaluate the performance of
    the learning part but not to train the model.

.. code:: ipython3

    print("Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y")
    print("\n")
    train_X, train_gt_pseudo_label, train_Y = train_data
    print("Length of X, gt_pseudo_label, Y in train_data: " +
          f"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}")
    test_X, test_gt_pseudo_label, test_Y = test_data
    print("Length of X, gt_pseudo_label, Y in test_data: " +
          f"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}")
    print("\n")
    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print(f"X is a {type(train_X).__name__}, " +
          f"with each element being a {type(X_0).__name__} " +
          f"of {len(X_0)} {type(X_0[0]).__name__}.")
    print(f"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, " +
          f"with each element being a {type(gt_pseudo_label_0).__name__} " +
          f"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.")
    print(f"Y is a {type(train_Y).__name__}, " +
          f"with each element being a {type(Y_0).__name__}.")

Out:
    .. code:: none
        :class: code-out

        Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y

        Length of X, gt_pseudo_label, Y in train_data: 30000, 30000, 30000
        Length of X, gt_pseudo_label, Y in test_data: 5000, 5000, 5000

        X is a list, with each element being a list of 2 Tensor.
        gt_pseudo_label is a list, with each element being a list of 2 int.
        Y is a list, with each element being a int.

The ith element of X, gt_pseudo_label, and Y together constitute the ith
data example. As an illustration, in the first data example of the
training set, we have:

.. code:: ipython3

    X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]
    print("X in the first data example (a list of two images):")
    plt.subplot(1, 2, 1)
    plt.axis('off')
    plt.imshow(X_0[0].numpy().transpose(1, 2, 0))
    plt.subplot(1, 2, 2)
    plt.axis('off')
    plt.imshow(X_0[1].numpy().transpose(1, 2, 0))
    plt.show()
    print(f"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}")
    print(f"Y in the first data example (their sum result): {Y_0}")

Out:
    .. code:: none
        :class: code-out

        X in the first data example (a list of two images):

    .. image:: ../img/mnist_add_datasets.png
        :width: 400px

    .. parsed-literal::

        gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): [7, 5]
        Y in the first data example (their sum result): 12

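To experiment with this data layout without loading MNIST, a toy dataset with
the same three-component structure can be built by hand. The following is a
minimal sketch: ``make_toy_dataset`` is a hypothetical helper (not part of
ABL-Package), and the "images" are placeholder nested lists of shape
1 x 28 x 28 rather than real tensors.

.. code:: ipython3

    import random

    def make_toy_dataset(num_examples, seed=0):
        """Build (X, gt_pseudo_label, Y) with the same nesting as the MNIST Addition data."""
        rng = random.Random(seed)
        X, gt_pseudo_label, Y = [], [], []
        for _ in range(num_examples):
            digits = [rng.randrange(10), rng.randrange(10)]
            # Placeholder image per digit: 1 channel x 28 rows x 28 columns
            images = [[[[d / 9.0] * 28 for _ in range(28)]] for d in digits]
            X.append(images)
            gt_pseudo_label.append(digits)
            Y.append(sum(digits))
        return X, gt_pseudo_label, Y

    toy_X, toy_gt, toy_Y = make_toy_dataset(5)
    print(f"{len(toy_X)} examples, {len(toy_X[0])} images per example")
    print(f"First example: digits {toy_gt[0]}, sum {toy_Y[0]}")

As in the real dataset, the ith entries of the three lists belong to the same
data example, and each entry of ``Y`` is the sum of the two digits in the
corresponding entry of ``gt_pseudo_label``.
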
Building the Learning Part
--------------------------

To build the learning part, we need to first build a machine learning
base model. We use a simple `LeNet-5 neural
network <https://en.wikipedia.org/wiki/LeNet>`__, and encapsulate it
within a ``BasicNN`` object to create the base model. ``BasicNN`` is a
class that encapsulates a PyTorch model, transforming it into a base
model with an sklearn-style interface.

.. code:: ipython3

    cls = LeNet5(num_classes=10)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = RMSprop(cls.parameters(), lr=0.001, alpha=0.9)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, pct_start=0.1, total_steps=100)

    base_model = BasicNN(
        cls,
        loss_fn,
        optimizer,
        scheduler=scheduler,
        device=device,
        batch_size=32,
        num_epochs=1,
    )

``BasicNN`` offers methods like ``predict`` and ``predict_proba``, which
are used to predict the class index and the probabilities of each class
for images, as shown below:

.. code:: ipython3

    data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]
    pred_idx = base_model.predict(X=data_instances)
    print(f"Predicted class index for a batch of 32 instances: np.ndarray with shape {pred_idx.shape}")
    pred_prob = base_model.predict_proba(X=data_instances)
    print(f"Predicted class probabilities for a batch of 32 instances: np.ndarray with shape {pred_prob.shape}")

Out:
    .. code:: none
        :class: code-out

        Predicted class index for a batch of 32 instances: np.ndarray with shape (32,)
        Predicted class probabilities for a batch of 32 instances: np.ndarray with shape (32, 10)

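The relationship between the two outputs can be illustrated in plain Python.
This sketch assumes the usual sklearn-style convention that the predicted class
index is the arg-max of the predicted probabilities; the functions below are
illustrative stand-ins operating on raw scores, not the ``BasicNN``
implementation.

.. code:: ipython3

    import math

    def softmax(logits):
        """Turn raw scores into probabilities that sum to 1."""
        m = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def predict_proba(batch_logits):
        """One probability vector per instance."""
        return [softmax(row) for row in batch_logits]

    def predict(batch_logits):
        """One class index per instance: the position of the largest probability."""
        return [max(range(len(p)), key=p.__getitem__) for p in predict_proba(batch_logits)]

    toy_logits = [[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]
    toy_prob = predict_proba(toy_logits)
    toy_idx = predict(toy_logits)
    print(toy_idx)  # [1, 0]
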
However, the base model built above deals with instance-level data
(i.e., individual images), and cannot directly deal with example-level
data (i.e., a pair of images). Therefore, we wrap the base model into
``ABLModel``, which enables the learning part to train, test, and
predict on example-level data.

.. code:: ipython3

    model = ABLModel(base_model)

As an illustration, consider this example of prediction on example-level
data using the ``predict`` method of ``ABLModel``. The
method accepts data examples as input and outputs the class labels and
the probabilities of each class for all instances within these data
examples.

.. code:: ipython3

    from abl.data.structures import ListData

    # ListData is a data structure provided by ABL-Package that can be used to organize data examples
    data_examples = ListData()
    # We use the first 100 data examples in the training set as an illustration
    data_examples.X = train_X[:100]
    data_examples.gt_pseudo_label = train_gt_pseudo_label[:100]
    data_examples.Y = train_Y[:100]

    # Perform prediction on the 100 data examples (predict once and reuse the result)
    pred = model.predict(data_examples)
    pred_label, pred_prob = pred['label'], pred['prob']
    print("Predicted class labels for the 100 data examples: \n" +
          f"a list of length {len(pred_label)}, and each element is " +
          f"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\n")
    print("Predicted class probabilities for the 100 data examples: \n" +
          f"a list of length {len(pred_prob)}, and each element is " +
          f"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.")

Out:
    .. code:: none
        :class: code-out

        Predicted class labels for the 100 data examples:
        a list of length 100, and each element is a ndarray of shape (2,).

        Predicted class probabilities for the 100 data examples:
        a list of length 100, and each element is a ndarray of shape (2, 10).

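One way to picture how an example-level wrapper can reuse an instance-level
model is "flatten, predict, regroup". The sketch below illustrates that idea
only; ``predict_examples`` and ``toy_instance_predict`` are hypothetical names,
and this is not claimed to be how ``ABLModel`` is implemented internally.

.. code:: ipython3

    def predict_examples(instance_predict, examples):
        """Apply an instance-level predictor to a list of examples (lists of instances)."""
        # Flatten all instances into a single batch
        flat = [inst for ex in examples for inst in ex]
        flat_pred = instance_predict(flat)
        # Regroup the predictions so each example gets its own list back
        grouped, start = [], 0
        for ex in examples:
            grouped.append(flat_pred[start:start + len(ex)])
            start += len(ex)
        return grouped

    # Hypothetical stand-in for a trained classifier: "images" are plain ints here
    def toy_instance_predict(instances):
        return [inst % 10 for inst in instances]

    toy_examples = [[7, 5], [13, 4], [0, 9]]
    toy_grouped = predict_examples(toy_instance_predict, toy_examples)
    print(toy_grouped)  # [[7, 5], [3, 4], [0, 9]]

The regrouped output has one entry per data example and one prediction per
instance inside it, which matches the "list of length 100, each of shape
``(2,)``" structure printed above.
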
Building the Reasoning Part
---------------------------

In the reasoning part, we first build a knowledge base that contains
information on how to perform addition operations. We build it by
creating a subclass of ``KBBase``. In the derived subclass, we
initialize the ``pseudo_label_list`` parameter, which specifies the list of
possible pseudo-labels, and override the ``logic_forward`` function,
which defines how to perform (deductive) reasoning.

.. code:: ipython3

    class AddKB(KBBase):
        def __init__(self, pseudo_label_list=list(range(10))):
            super().__init__(pseudo_label_list)

        # Implement the deduction function
        def logic_forward(self, nums):
            return sum(nums)

    kb = AddKB()

The knowledge base can perform logical reasoning (both deductive
reasoning and abductive reasoning). Below is an example of performing
(deductive) reasoning, and users can refer to :ref:`Performing abductive
reasoning in the knowledge base <kb-abd>` for details of abductive reasoning.

.. code:: ipython3

    pseudo_labels = [1, 2]
    reasoning_result = kb.logic_forward(pseudo_labels)
    print(f"Reasoning result of pseudo-labels {pseudo_labels} is {reasoning_result}.")

Out:
    .. code:: none
        :class: code-out

        Reasoning result of pseudo-labels [1, 2] is 3.

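Abductive reasoning runs the same rule in the opposite direction: given a
reasoning result, it looks for pseudo-labels that would produce it. The
following brute-force sketch shows the idea in plain Python. The function
``abduce_candidates`` is a hypothetical name used for illustration, and the
exhaustive enumeration is an assumption of this sketch rather than a
description of how ABL-Package performs abduction.

.. code:: ipython3

    from itertools import product

    def logic_forward(nums):
        # Same deduction rule as AddKB: the result is the sum of the digits
        return sum(nums)

    def abduce_candidates(y, num_symbols=2, pseudo_label_list=range(10)):
        """Return every combination of pseudo-labels whose deduction result equals y."""
        return [
            list(candidate)
            for candidate in product(pseudo_label_list, repeat=num_symbols)
            if logic_forward(candidate) == y
        ]

    toy_candidates = abduce_candidates(3)
    print(toy_candidates)  # [[0, 3], [1, 2], [2, 1], [3, 0]]

A single sum is usually compatible with several digit pairs, so abduction alone
does not determine the pseudo-labels uniquely.
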
.. note::

    In addition to building a knowledge base based on ``KBBase``, we
    can also establish a knowledge base with a ground KB using ``GroundKB``,
    or a knowledge base implemented based on Prolog files using
    ``PrologKB``. The corresponding code for these implementations can be
    found in the ``examples/mnist_add/main.py`` file. Those interested are
    encouraged to examine it for further insights.

Then, we create a reasoner by instantiating the class ``Reasoner``. Due
to the nondeterminism of abductive reasoning, there could be multiple
candidates compatible with the knowledge base. When this happens, the reasoner
minimizes the inconsistency between the knowledge base and the
pseudo-labels predicted by the learning part, and then returns only the
candidate that has the highest consistency.

.. code:: ipython3

    reasoner = Reasoner(kb)

.. note::

    When creating the reasoner, the definition of “consistency” can be
    customized through the ``dist_func`` parameter. In the code above, we
    employ a consistency measurement based on confidence, which calculates
    the consistency between the data example and candidates based on the
    confidence derived from the predicted probability. In ``examples/mnist_add/main.py``, we
    provide options for utilizing other forms of consistency measurement.

    Also, during the process of inconsistency minimization, we can leverage the
    `ZOOpt library <https://github.com/polixir/ZOOpt>`__ for acceleration.
    Options for this are also available in ``examples/mnist_add/main.py``. Those interested are
    encouraged to explore these features.

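To make confidence-based selection concrete, here is a minimal sketch. It
assumes that the confidence of a candidate is the product of the probabilities
the model assigns to each of its digits, and that the distance to minimize is
one minus that product. The exact formula used by ABL-Package may differ, and
the names ``confidence_distance`` and ``select_candidate`` are hypothetical.

.. code:: ipython3

    def confidence_distance(pred_prob, candidate):
        """1 - (product of the predicted probabilities of the candidate's labels)."""
        confidence = 1.0
        for probs, label in zip(pred_prob, candidate):
            confidence *= probs[label]
        return 1.0 - confidence

    def select_candidate(pred_prob, candidates):
        """Pick the candidate with the smallest distance, i.e. the highest consistency."""
        return min(candidates, key=lambda c: confidence_distance(pred_prob, c))

    # Toy predicted probabilities over digits 0..9 for the two images of one example
    toy_pred_prob = [
        [0.05, 0.60, 0.25, 0.02, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01],
        [0.02, 0.30, 0.50, 0.10, 0.02, 0.02, 0.01, 0.01, 0.01, 0.01],
    ]
    # The digit pairs compatible with a sum of 3
    toy_candidates = [[0, 3], [1, 2], [2, 1], [3, 0]]
    toy_best = select_candidate(toy_pred_prob, toy_candidates)
    print(toy_best)  # [1, 2]

Here ``[1, 2]`` wins because its confidence (0.60 x 0.50 = 0.30) is higher than
that of any other pair that sums to 3.
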
Building Evaluation Metrics
---------------------------

Next, we set up evaluation metrics. These metrics will be used to
evaluate the model performance during training and testing.
Specifically, we use ``SymbolAccuracy`` and ``ReasoningMetric``, which are
used to evaluate the accuracy of the machine learning model’s
predictions and the accuracy of the final reasoning results,
respectively.

.. code:: ipython3

    metric_list = [SymbolAccuracy(prefix="mnist_add"), ReasoningMetric(kb=kb, prefix="mnist_add")]

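The difference between the two kinds of accuracy can be seen on a toy example.
The functions below are a plain-Python sketch of what such metrics measure,
assuming that symbol accuracy is computed per digit and reasoning accuracy per
data example. They are illustrative stand-ins, not the ``SymbolAccuracy`` and
``ReasoningMetric`` classes themselves.

.. code:: ipython3

    def symbol_accuracy(pred_pseudo_labels, gt_pseudo_labels):
        """Fraction of individual pseudo-labels that match the ground truth."""
        correct = total = 0
        for pred, gt in zip(pred_pseudo_labels, gt_pseudo_labels):
            for p, g in zip(pred, gt):
                correct += int(p == g)
                total += 1
        return correct / total

    def reasoning_accuracy(pred_pseudo_labels, Y, logic_forward=sum):
        """Fraction of examples whose deduced result matches the given result."""
        hits = [logic_forward(pred) == y for pred, y in zip(pred_pseudo_labels, Y)]
        return sum(hits) / len(hits)

    toy_pred = [[7, 5], [3, 4], [1, 8]]
    toy_gt = [[7, 5], [3, 9], [8, 1]]
    toy_Y = [12, 12, 9]
    toy_symbol_acc = symbol_accuracy(toy_pred, toy_gt)
    toy_reasoning_acc = reasoning_accuracy(toy_pred, toy_Y)
    print(toy_symbol_acc)     # 0.5
    print(toy_reasoning_acc)  # 2 of 3 examples are correct

In the third example both digits are wrong, yet their sum is right, so the
reasoning result counts as correct while the symbols do not. This is why the
two metrics are reported separately.
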
Bridging Learning and Reasoning
-------------------------------

Now, the last step is to bridge the learning and reasoning parts. We
do this by creating an instance of ``SimpleBridge``.

.. code:: ipython3

    bridge = SimpleBridge(model, reasoner, metric_list)

Perform training and testing by invoking the ``train`` and ``test``
methods of ``SimpleBridge``.

.. code:: ipython3

    # Build logger
    print_log("Abductive Learning on the MNIST Addition example.", logger="current")
    log_dir = ABLLogger.get_current_instance().log_dir
    weights_dir = osp.join(log_dir, "weights")

    bridge.train(train_data, loops=1, segment_size=0.01, save_interval=1, save_dir=weights_dir)
    bridge.test(test_data)

Out:
    .. code:: none
        :class: code-out

        abl - INFO - Abductive Learning on the MNIST Addition example.
        abl - INFO - loop(train) [1/1] segment(train) [1/100]
        abl - INFO - model loss: 2.23587
        abl - INFO - loop(train) [1/1] segment(train) [2/100]
        abl - INFO - model loss: 2.23756
        abl - INFO - loop(train) [1/1] segment(train) [3/100]
        abl - INFO - model loss: 2.04475
        abl - INFO - loop(train) [1/1] segment(train) [4/100]
        abl - INFO - model loss: 2.01035
        abl - INFO - loop(train) [1/1] segment(train) [5/100]
        abl - INFO - model loss: 1.97584
        abl - INFO - loop(train) [1/1] segment(train) [6/100]
        abl - INFO - model loss: 1.91570
        abl - INFO - loop(train) [1/1] segment(train) [7/100]
        abl - INFO - model loss: 1.90268
        abl - INFO - loop(train) [1/1] segment(train) [8/100]
        abl - INFO - model loss: 1.77436
        abl - INFO - loop(train) [1/1] segment(train) [9/100]
        abl - INFO - model loss: 1.73454
        abl - INFO - loop(train) [1/1] segment(train) [10/100]
        abl - INFO - model loss: 1.62495
        abl - INFO - loop(train) [1/1] segment(train) [11/100]
        abl - INFO - model loss: 1.58456
        abl - INFO - loop(train) [1/1] segment(train) [12/100]
        abl - INFO - model loss: 1.62575
        ...
        abl - INFO - Eval start: loop(val) [1]
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973
        abl - INFO - Saving model: loop(save) [1]
        abl - INFO - Checkpoints will be saved to results/20231222_22_25_07/weights/model_checkpoint_loop_1.pth
        abl - INFO - Test start:
        abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967

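The log reports 100 segments per loop, which is consistent with reading
``segment_size=0.01`` as a fraction of the 30000 training examples. The sketch
below shows that arithmetic under this assumption; ``split_into_segments`` is a
hypothetical helper, and the way ``SimpleBridge`` actually partitions the data
may differ in detail.

.. code:: ipython3

    def split_into_segments(num_examples, segment_size):
        """Return (start, end) index pairs, one per training segment."""
        if isinstance(segment_size, float) and 0 < segment_size <= 1:
            # Assumption: a fractional segment_size is a fraction of the dataset
            per_segment = max(1, round(num_examples * segment_size))
        else:
            # Otherwise treat it as an absolute number of examples
            per_segment = int(segment_size)
        return [
            (start, min(start + per_segment, num_examples))
            for start in range(0, num_examples, per_segment)
        ]

    toy_segments = split_into_segments(30000, 0.01)
    print(len(toy_segments), toy_segments[0], toy_segments[-1])
    # 100 (0, 300) (29700, 30000)

Under this reading, each segment holds 300 data examples, that is, 600
individual digit images.
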
More concrete examples are available in ``examples/mnist_add/main.py`` and ``examples/mnist_add/mnist_add.ipynb``.
