# Handwritten Formula

This example shows a simple implementation of the [Handwritten Formula](https://arxiv.org/abs/2006.06649) (HWF) task. The input consists of handwritten images of decimal formulas and their computed results, along with a domain knowledge base that describes how to compute a decimal formula. The task is to recognize the symbols in the handwritten images (digits or the operators '+', '-', '×', '÷') and accurately determine the results of the formulas.
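
To make the role of the knowledge base more concrete, here is a minimal sketch of how a formula could be evaluated from recognized symbols, and how candidate symbol sequences consistent with a given result could be enumerated. It is not the code used in this example: the names `is_valid`, `evaluate`, and `abduce` are hypothetical, the digit set 1 to 9 is an assumption (chosen so that division by zero cannot occur), and the `max_err` parameter only loosely mirrors the `--max-err` option listed under Usage.

```python
from itertools import product

# Assumption: formulas use the digits 1-9 and the four operators below.
DIGITS = [str(d) for d in range(1, 10)]
OPERATORS = ["+", "-", "×", "÷"]
# Map the handwritten operator symbols to Python operators.
TO_PYTHON = {"×": "*", "÷": "/"}


def is_valid(symbols):
    """Digits and operators must alternate, starting and ending with a digit."""
    if len(symbols) % 2 == 0:
        return False
    for i, s in enumerate(symbols):
        if i % 2 == 0 and s not in DIGITS:
            return False
        if i % 2 == 1 and s not in OPERATORS:
            return False
    return True


def evaluate(symbols):
    """Return the value of a symbol sequence, or None if it is not a valid formula."""
    if not is_valid(symbols):
        return None
    expr = "".join(TO_PYTHON.get(s, s) for s in symbols)
    # expr contains only the characters 1-9 and + - * /, so eval is safe here.
    return eval(expr, {"__builtins__": {}})


def abduce(length, target, max_err=1e-10):
    """Enumerate all valid formulas of the given length whose value is within max_err of target."""
    if length % 2 == 0:
        return []  # no valid formula has an even number of symbols
    pools = [DIGITS if i % 2 == 0 else OPERATORS for i in range(length)]
    return [
        list(candidate)
        for candidate in product(*pools)
        if abs(evaluate(candidate) - target) <= max_err
    ]
```

In this sketch, `evaluate` plays the part of the rule "how to compute the decimal formula", while `abduce` shows the reverse direction: given only the final result, it lists the symbol sequences that would explain it. Brute-force enumeration is used purely for clarity and grows exponentially with the formula length.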

## Run

```bash
pip install -r requirements.txt
python main.py
```

## Usage

```bash
usage: main.py [-h] [--no-cuda] [--epochs EPOCHS]
               [--label_smoothing LABEL_SMOOTHING] [--lr LR]
               [--batch-size BATCH_SIZE]
               [--loops LOOPS] [--segment_size SEGMENT_SIZE]
               [--save_interval SAVE_INTERVAL] [--max-revision MAX_REVISION]
               [--require-more-revision REQUIRE_MORE_REVISION]
               [--ground] [--max-err MAX_ERR]

Handwritten Formula example

optional arguments:
  -h, --help            show this help message and exit
  --no-cuda             disables CUDA training
  --epochs EPOCHS       number of epochs in each learning loop iteration
                        (default : 1)
  --label_smoothing LABEL_SMOOTHING
                        label smoothing in cross entropy loss (default : 0.2)
  --lr LR               base model learning rate (default : 0.001)
  --batch-size BATCH_SIZE
                        base model batch size (default : 32)
  --loops LOOPS         number of loop iterations (default : 5)
  --segment_size SEGMENT_SIZE
                        segment size (default : 1/3)
  --save_interval SAVE_INTERVAL
                        save interval (default : 1)
  --max-revision MAX_REVISION
                        maximum revision in reasoner (default : -1)
  --require-more-revision REQUIRE_MORE_REVISION
                        require more revision in reasoner (default : 0)
  --ground              use GroundKB (default: False)
  --max-err MAX_ERR     max tolerance during abductive reasoning (default : 1e-10)
```
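
For readers who want to see how these options map onto parsed values, the following is a rough sketch of an `argparse` parser that would produce a help message of this shape. It is an illustration only, not the contents of `main.py`: the function name `build_parser` is hypothetical, and the argument types (for example, treating `--segment_size` as a float) are assumptions inferred from the defaults shown above.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented options; types are assumed."""
    p = argparse.ArgumentParser(description="Handwritten Formula example")
    p.add_argument("--no-cuda", action="store_true", default=False,
                   help="disables CUDA training")
    p.add_argument("--epochs", type=int, default=1,
                   help="number of epochs in each learning loop iteration")
    p.add_argument("--label_smoothing", type=float, default=0.2,
                   help="label smoothing in cross entropy loss")
    p.add_argument("--lr", type=float, default=0.001,
                   help="base model learning rate")
    p.add_argument("--batch-size", type=int, default=32,
                   help="base model batch size")
    p.add_argument("--loops", type=int, default=5,
                   help="number of loop iterations")
    # Assumption: the segment size is parsed as a float fraction.
    p.add_argument("--segment_size", type=float, default=1 / 3,
                   help="segment size")
    p.add_argument("--save_interval", type=int, default=1,
                   help="save interval")
    p.add_argument("--max-revision", type=int, default=-1,
                   help="maximum revision in reasoner")
    p.add_argument("--require-more-revision", type=int, default=0,
                   help="require more revision in reasoner")
    p.add_argument("--ground", action="store_true", default=False,
                   help="use GroundKB")
    p.add_argument("--max-err", type=float, default=1e-10,
                   help="max tolerance during abductive reasoning")
    return p


# Equivalent to running: python main.py --loops 3 --lr 0.01 --ground
args = build_parser().parse_args(["--loops", "3", "--lr", "0.01", "--ground"])
```

Note that `argparse` converts dashes in option names to underscores, so `--batch-size` is read as `args.batch_size` and `--max-err` as `args.max_err`; options not given on the command line keep the defaults listed in the help text.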

## Environment

All experiments were run on a single Linux server. Its specifications are listed in the table below.

<table class="tg" style="margin-left: auto; margin-right: auto;">
<thead>
  <tr>
    <th>CPU</th>
    <th>GPU</th>
    <th>Memory</th>
    <th>OS</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td>2 * Xeon Platinum 8358, 32 Cores, 2.6 GHz Base Frequency</td>
    <td>A100 80GB</td>
    <td>512GB</td>
    <td>Ubuntu 20.04</td>
  </tr>
</tbody>
</table>

## Performance

The table below reports the results of ABL: reasoning accuracy (for different equation lengths in the HWF dataset), training time (to achieve that accuracy using all equation lengths), and average memory usage (using all equation lengths). ABL is compared with the following methods:

- [**NGS**](https://github.com/liqing-ustc/NGS): A neural-symbolic framework that uses a grammar model and a back-search algorithm to improve its computing process;
- [**DeepProbLog**](https://github.com/ML-KULeuven/deepproblog/tree/master): An extension of ProbLog that introduces neural predicates into probabilistic logic programming;
- [**DeepStochLog**](https://github.com/ML-KULeuven/deepstochlog/tree/main): A neural-symbolic framework based on stochastic logic programs.

<table class="tg" style="margin-left: auto; margin-right: auto;">
<thead>
  <tr>
    <th rowspan="2"></th>
    <th colspan="5">Reasoning Accuracy<br><span style="font-weight: normal; font-size: smaller;">(for different equation lengths)</span></th>
    <th rowspan="2">Training Time (s)<br><span style="font-weight: normal; font-size: smaller;">(to achieve the Acc. using all lengths)</span></th>
    <th rowspan="2">Average Memory Usage (MB)<br><span style="font-weight: normal; font-size: smaller;">(using all lengths)</span></th>
  </tr>
  <tr>
    <th>1</th>
    <th>3</th>
    <th>5</th>
    <th>7</th>
    <th>All</th>
  </tr>
</thead>
<tbody>
  <tr>
    <td>NGS</td>
    <td>91.2</td>
    <td>89.1</td>
    <td>92.7</td>
    <td>5.2</td>
    <td>98.4</td>
    <td>426.2</td>
    <td>3705</td>
  </tr>
  <tr>
    <td>DeepProbLog</td>
    <td>90.8</td>
    <td>85.6</td>
    <td>timeout*</td>
    <td>timeout</td>
    <td>timeout</td>
    <td>timeout</td>
    <td>4315</td>
  </tr>
  <tr>
    <td>DeepStochLog</td>
    <td>92.8</td>
    <td>87.5</td>
    <td>92.1</td>
    <td>timeout</td>
    <td>timeout</td>
    <td>timeout</td>
    <td>4355</td>
  </tr>
  <tr>
    <td>ABL</td>
    <td><span style="font-weight:bold">94.0</span></td>
    <td><span style="font-weight:bold">89.7</span></td>
    <td><span style="font-weight:bold">96.5</span></td>
    <td><span style="font-weight:bold">97.2</span></td>
    <td><span style="font-weight:bold">99.2</span></td>
    <td><span style="font-weight:bold">77.3</span></td>
    <td><span style="font-weight:bold">3074</span></td>
  </tr>
</tbody>
</table>
<p style="font-size: 13px;">* timeout: need more than 1 hour to execute</p>
