From 819c2739a8327f73538246737752f982b07de1ef Mon Sep 17 00:00:00 2001 From: troyyyyy Date: Sun, 15 Oct 2023 20:22:13 +0800 Subject: [PATCH] [MNT] Modify test case for reasoning --- abl/reasoning/reasoner.py | 347 +++++++++----------------------------- 1 file changed, 81 insertions(+), 266 deletions(-) diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index 4ff0a75..dbfb968 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -281,6 +281,8 @@ class ReasonerBase: ) + + if __name__ == "__main__": from kb import KBBase, prolog_KB @@ -312,81 +314,34 @@ if __name__ == "__main__": def logic_forward(self, nums): return sum(nums) + + def test_add(reasoner): + res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) + print(res) + res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) + print(res) + print() print("add_KB with GKB:") kb = add_KB(prebuild_GKB=True) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 - ) - print(res) - print() + test_add(reasoner) print("add_KB without GKB:") kb = add_KB() reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 - ) - print(res) - print() + test_add(reasoner) print("add_KB without GKB:, no cache") kb = add_KB(use_cache=False) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 - ) - print(res) - print() + test_add(reasoner) print("prolog_KB with add.pl:") kb = prolog_KB( @@ -394,27 +349,7 @@ if __name__ == "__main__": pl_file="examples/mnist_add/datasets/add.pl", ) reasoner = ReasonerBase(kb, "confidence") - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 - ) - print(res) - print() + test_add(reasoner) print("prolog_KB with add.pl using zoopt:") kb = prolog_KB( @@ -422,27 +357,7 @@ if __name__ == "__main__": pl_file="examples/mnist_add/datasets/add.pl", ) reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 - ) - print(res) - res = reasoner.batch_abduce( - [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 - ) - print(res) - print() + test_add(reasoner) print("add_KB with multiple inputs at once:") multiple_prob = [ @@ -459,7 +374,6 @@ if __name__ == "__main__": kb = add_KB() reasoner = ReasonerBase(kb, "confidence") res = reasoner.batch_abduce( - [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], @@ -468,7 +382,6 @@ if __name__ == "__main__": ) print(res) res = reasoner.batch_abduce( - [[1, 1], [1, 2]], multiple_prob, [[1, 1], [1, 2]], [4, 8], @@ -532,186 +445,88 @@ if __name__ == "__main__": mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) formula = [mapping[f] for f in formula] return eval("".join(formula)) + + def test_hwf(reasoner): + res = reasoner.batch_abduce( + [None], + [["5", "+", "2"]], + [3], + max_revision=2, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None], + [["5", "+", "9"]], + [65], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None], + [["5", "8", "8", "8", "8"]], + [3.17], + max_revision=5, + require_more_revision=3, + ) + print(res) + print() + + def test_hwf_multiple(reasoner): + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 64], + max_revision=1, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 64], + max_revision=3, + require_more_revision=0, + ) + print(res) + res = reasoner.batch_abduce( + [None, None], + [["5", "+", "2"], ["5", "+", "9"]], + [3, 65], + max_revision=3, + require_more_revision=0, + ) + print(res) + print() print("HWF_KB with GKB, max_err=0.1") kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 2]], - [3], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 9]], - [65], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], - [None], - [[5, 8, 8, 8, 8]], - [3.17], - max_revision=5, - require_more_revision=3, - ) - print(res) - print() + test_hwf(reasoner) print("HWF_KB without GKB, max_err=0.1") kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 2]], - [3], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 9]], - [65], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], - [None], - [[5, 8, 8, 8, 8]], - [3.17], - max_revision=5, - require_more_revision=3, - ) - print(res) - print() + test_hwf(reasoner) print("HWF_KB with GKB, max_err=1") kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) reasoner = ReasonerBase(kb, "hamming") - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 2]], - [3], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 9]], - [65], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], - [None], - [[5, 8, 8, 8, 8]], - [3.17], - max_revision=5, - require_more_revision=3, - ) - print(res) - print() + test_hwf(reasoner) print("HWF_KB without GKB, max_err=1") kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) reasoner = ReasonerBase(kb, "hamming") - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 2]], - [3], - max_revision=2, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"]], - [None], - [[5, 10, 9]], - [65], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "8", "8", "8", "8"]], - [None], - [[5, 8, 8, 8, 8]], - [3.17], - max_revision=5, - require_more_revision=3, - ) - print(res) - print() + test_hwf(reasoner) print("HWF_KB with multiple inputs at once:") kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) reasoner = ReasonerBase(kb, "hamming") - res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], - [None, None], - [[5, 10, 2], [5, 10, 9]], - [3, 64], - max_revision=1, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], - [None, None], - [[5, 10, 2], [5, 10, 9]], - [3, 64], - max_revision=3, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], - [None, None], - [[5, 10, 2], [5, 10, 9]], - [3, 65], - max_revision=3, - require_more_revision=0, - ) - print(res) - print() + test_hwf_multiple(reasoner) + print("max_revision is float") - res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], - [None, None], - [[5, 10, 2], [5, 10, 9]], - [3, 64], - max_revision=0.5, - require_more_revision=0, - ) - print(res) - res = reasoner.batch_abduce( - [["5", "+", "2"], ["5", "+", "9"]], - [None, None], - [[5, 10, 2], [5, 10, 9]], - [3, 64], - max_revision=0.9, - require_more_revision=0, - ) - print(res) - print() + test_hwf_multiple(reasoner) class HED_prolog_KB(prolog_KB): def __init__(self, pseudo_label_list, pl_file):