From fd7d33f031d60e516829bd82b150028cf87c0c80 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Sun, 7 Jan 2024 21:01:58 +0800 Subject: [PATCH] [ENH] add mnist performance to docs --- docs/Examples/MNISTAdd.rst | 83 +++++++++++++++++++----------- examples/mnist_add/mnist_add.ipynb | 19 +++++++ 2 files changed, 73 insertions(+), 29 deletions(-) diff --git a/docs/Examples/MNISTAdd.rst b/docs/Examples/MNISTAdd.rst index bedb86a..3491c9c 100644 --- a/docs/Examples/MNISTAdd.rst +++ b/docs/Examples/MNISTAdd.rst @@ -348,34 +348,59 @@ Out: :class: code-out abl - INFO - Abductive Learning on the MNIST Addition example. - abl - INFO - loop(train) [1/1] segment(train) [1/100] - abl - INFO - model loss: 2.23587 - abl - INFO - loop(train) [1/1] segment(train) [2/100] - abl - INFO - model loss: 2.23756 - abl - INFO - loop(train) [1/1] segment(train) [3/100] - abl - INFO - model loss: 2.04475 - abl - INFO - loop(train) [1/1] segment(train) [4/100] - abl - INFO - model loss: 2.01035 - abl - INFO - loop(train) [1/1] segment(train) [5/100] - abl - INFO - model loss: 1.97584 - abl - INFO - loop(train) [1/1] segment(train) [6/100] - abl - INFO - model loss: 1.91570 - abl - INFO - loop(train) [1/1] segment(train) [7/100] - abl - INFO - model loss: 1.90268 - abl - INFO - loop(train) [1/1] segment(train) [8/100] - abl - INFO - model loss: 1.77436 - abl - INFO - loop(train) [1/1] segment(train) [9/100] - abl - INFO - model loss: 1.73454 - abl - INFO - loop(train) [1/1] segment(train) [10/100] - abl - INFO - model loss: 1.62495 - abl - INFO - loop(train) [1/1] segment(train) [11/100] - abl - INFO - model loss: 1.58456 - abl - INFO - loop(train) [1/1] segment(train) [12/100] - abl - INFO - model loss: 1.62575 + abl - INFO - Working with Data. + abl - INFO - Building the Learning Part. + abl - INFO - Building the Reasoning Part. + abl - INFO - Building Evaluation Metrics. + abl - INFO - Bridge Learning and Reasoning. + abl - INFO - loop(train) [1/2] segment(train) [1/100] + abl - INFO - model loss: 2.25980 + abl - INFO - loop(train) [1/2] segment(train) [2/100] + abl - INFO - model loss: 2.14168 + abl - INFO - loop(train) [1/2] segment(train) [3/100] + abl - INFO - model loss: 2.02010 + abl - INFO - loop(train) [1/2] segment(train) [4/100] + abl - INFO - model loss: 2.01933 + abl - INFO - loop(train) [1/2] segment(train) [5/100] + abl - INFO - model loss: 2.02230 + abl - INFO - loop(train) [1/2] segment(train) [6/100] + abl - INFO - model loss: 1.97602 + abl - INFO - loop(train) [1/2] segment(train) [7/100] + abl - INFO - model loss: 1.97942 + abl - INFO - loop(train) [1/2] segment(train) [8/100] + abl - INFO - model loss: 1.92800 + abl - INFO - loop(train) [1/2] segment(train) [9/100] + abl - INFO - model loss: 1.91512 + abl - INFO - loop(train) [1/2] segment(train) [10/100] + abl - INFO - model loss: 1.83939 + abl - INFO - loop(train) [1/2] segment(train) [11/100] + abl - INFO - model loss: 1.78795 + abl - INFO - loop(train) [1/2] segment(train) [12/100] + abl - INFO - model loss: 1.78538 ... - abl - INFO - Eval start: loop(val) [1] - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.986 mnist_add/reasoning_accuracy: 0.973 - abl - INFO - Saving model: loop(save) [1] - abl - INFO - Checkpoints will be saved to log_dir/weights/model_checkpoint_loop_1.pth + abl - INFO - Eval start: loop(val) [2] + abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.993 mnist_add/reasoning_accuracy: 0.986 + abl - INFO - Saving model: loop(save) [2] + abl - INFO - Checkpoints will be saved to results/20240107_20_56_09/weights/model_checkpoint_loop_2.pth abl - INFO - Test start: - abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.983 mnist_add/reasoning_accuracy: 0.967 + abl - INFO - Evaluation ended, mnist_add/character_accuracy: 0.991 mnist_add/reasoning_accuracy: 0.980 + + + +Performance +----------- + +.. table:: + :class: centered + + +--------------+----------+------------------------------+ + | Method | Accuracy | Time to achieve the Acc. (s) | + +==============+==========+==============================+ + | NeurASP | 0.964 | 354 | + +--------------+----------+------------------------------+ + | DeepProbLog | 0.965 | 1965 | + +--------------+----------+------------------------------+ + | DeepStochLog | 0.975 | 727 | + +--------------+----------+------------------------------+ + | ABL | 0.980 | 42 | + +--------------+----------+------------------------------+ diff --git a/examples/mnist_add/mnist_add.ipynb b/examples/mnist_add/mnist_add.ipynb index 67aee51..0c947b3 100644 --- a/examples/mnist_add/mnist_add.ipynb +++ b/examples/mnist_add/mnist_add.ipynb @@ -457,6 +457,25 @@ "bridge.train(train_data, loops=2, segment_size=0.01, save_interval=1, save_dir=weights_dir)\n", "bridge.test(test_data)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| Method | Accuracy | Time to achieve the Acc. (s) |\n", + "| :----------: | :------: | :--------------------------: |\n", + "| NeurASP | 0.964 | 354 |\n", + "| DeepProbLog | 0.965 | 1965 |\n", + "| DeepStochLog | 0.975 | 727 |\n", + "| ABL | 0.980 | 42 |" + ] } ], "metadata": {