| @@ -281,6 +281,8 @@ class ReasonerBase: | |||
| ) | |||
| if __name__ == "__main__": | |||
| from kb import KBBase, prolog_KB | |||
| @@ -312,81 +314,34 @@ if __name__ == "__main__": | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def test_add(reasoner): | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| print() | |||
| print("add_KB with GKB:") | |||
| kb = add_KB(prebuild_GKB=True) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_add(reasoner) | |||
| print("add_KB without GKB:") | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_add(reasoner) | |||
| print("add_KB without GKB:, no cache") | |||
| kb = add_KB(use_cache=False) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_add(reasoner) | |||
| print("prolog_KB with add.pl:") | |||
| kb = prolog_KB( | |||
| @@ -394,27 +349,7 @@ if __name__ == "__main__": | |||
| pl_file="examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_add(reasoner) | |||
| print("prolog_KB with add.pl using zoopt:") | |||
| kb = prolog_KB( | |||
| @@ -422,27 +357,7 @@ if __name__ == "__main__": | |||
| pl_file="examples/mnist_add/datasets/add.pl", | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_add(reasoner) | |||
| print("add_KB with multiple inputs at once:") | |||
| multiple_prob = [ | |||
| @@ -459,7 +374,6 @@ if __name__ == "__main__": | |||
| kb = add_KB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1], [1, 2]], | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| @@ -468,7 +382,6 @@ if __name__ == "__main__": | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [[1, 1], [1, 2]], | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| @@ -532,186 +445,88 @@ if __name__ == "__main__": | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| def test_hwf(reasoner): | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "2"]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "9"]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| def test_hwf_multiple(reasoner): | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=1, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print("HWF_KB with GKB, max_err=0.1") | |||
| kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf(reasoner) | |||
| print("HWF_KB without GKB, max_err=0.1") | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf(reasoner) | |||
| print("HWF_KB with GKB, max_err=1") | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf(reasoner) | |||
| print("HWF_KB without GKB, max_err=1") | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 2]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"]], | |||
| [None], | |||
| [[5, 10, 9]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "8", "8", "8", "8"]], | |||
| [None], | |||
| [[5, 8, 8, 8, 8]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf(reasoner) | |||
| print("HWF_KB with multiple inputs at once:") | |||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=1, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf_multiple(reasoner) | |||
| print("max_revision is float") | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=0.5, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [None, None], | |||
| [[5, 10, 2], [5, 10, 9]], | |||
| [3, 64], | |||
| max_revision=0.9, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| test_hwf_multiple(reasoner) | |||
| class HED_prolog_KB(prolog_KB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||