| @@ -281,6 +281,8 @@ class ReasonerBase: | |||||
| ) | ) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| from kb import KBBase, prolog_KB | from kb import KBBase, prolog_KB | ||||
| @@ -312,81 +314,34 @@ if __name__ == "__main__": | |||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| def test_add(reasoner): | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| print() | |||||
| print("add_KB with GKB:") | print("add_KB with GKB:") | ||||
| kb = add_KB(prebuild_GKB=True) | kb = add_KB(prebuild_GKB=True) | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_add(reasoner) | |||||
| print("add_KB without GKB:") | print("add_KB without GKB:") | ||||
| kb = add_KB() | kb = add_KB() | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_add(reasoner) | |||||
| print("add_KB without GKB:, no cache") | print("add_KB without GKB:, no cache") | ||||
| kb = add_KB(use_cache=False) | kb = add_KB(use_cache=False) | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_add(reasoner) | |||||
| print("prolog_KB with add.pl:") | print("prolog_KB with add.pl:") | ||||
| kb = prolog_KB( | kb = prolog_KB( | ||||
| @@ -394,27 +349,7 @@ if __name__ == "__main__": | |||||
| pl_file="examples/mnist_add/datasets/add.pl", | pl_file="examples/mnist_add/datasets/add.pl", | ||||
| ) | ) | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_add(reasoner) | |||||
| print("prolog_KB with add.pl using zoopt:") | print("prolog_KB with add.pl using zoopt:") | ||||
| kb = prolog_KB( | kb = prolog_KB( | ||||
| @@ -422,27 +357,7 @@ if __name__ == "__main__": | |||||
| pl_file="examples/mnist_add/datasets/add.pl", | pl_file="examples/mnist_add/datasets/add.pl", | ||||
| ) | ) | ||||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | ||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [[1, 1]], prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0 | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_add(reasoner) | |||||
| print("add_KB with multiple inputs at once:") | print("add_KB with multiple inputs at once:") | ||||
| multiple_prob = [ | multiple_prob = [ | ||||
| @@ -459,7 +374,6 @@ if __name__ == "__main__": | |||||
| kb = add_KB() | kb = add_KB() | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| res = reasoner.batch_abduce( | res = reasoner.batch_abduce( | ||||
| [[1, 1], [1, 2]], | |||||
| multiple_prob, | multiple_prob, | ||||
| [[1, 1], [1, 2]], | [[1, 1], [1, 2]], | ||||
| [4, 8], | [4, 8], | ||||
| @@ -468,7 +382,6 @@ if __name__ == "__main__": | |||||
| ) | ) | ||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce( | res = reasoner.batch_abduce( | ||||
| [[1, 1], [1, 2]], | |||||
| multiple_prob, | multiple_prob, | ||||
| [[1, 1], [1, 2]], | [[1, 1], [1, 2]], | ||||
| [4, 8], | [4, 8], | ||||
| @@ -532,186 +445,88 @@ if __name__ == "__main__": | |||||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | ||||
| formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
| return eval("".join(formula)) | return eval("".join(formula)) | ||||
| def test_hwf(reasoner): | |||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "+", "2"]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "+", "9"]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| def test_hwf_multiple(reasoner): | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 64], | |||||
| max_revision=1, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 64], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| print("HWF_KB with GKB, max_err=0.1") | print("HWF_KB with GKB, max_err=0.1") | ||||
| kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) | kb = HWF_KB(prebuild_GKB=True, GKB_len_list=[1, 3, 5], max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 2]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 9]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [None], | |||||
| [[5, 8, 8, 8, 8]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf(reasoner) | |||||
| print("HWF_KB without GKB, max_err=0.1") | print("HWF_KB without GKB, max_err=0.1") | ||||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 2]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 9]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [None], | |||||
| [[5, 8, 8, 8, 8]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf(reasoner) | |||||
| print("HWF_KB with GKB, max_err=1") | print("HWF_KB with GKB, max_err=1") | ||||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) | kb = HWF_KB(GKB_len_list=[1, 3, 5], prebuild_GKB=True, max_err=1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 2]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 9]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [None], | |||||
| [[5, 8, 8, 8, 8]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf(reasoner) | |||||
| print("HWF_KB without GKB, max_err=1") | print("HWF_KB without GKB, max_err=1") | ||||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) | kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 2]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"]], | |||||
| [None], | |||||
| [[5, 10, 9]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [None], | |||||
| [[5, 8, 8, 8, 8]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf(reasoner) | |||||
| print("HWF_KB with multiple inputs at once:") | print("HWF_KB with multiple inputs at once:") | ||||
| kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | kb = HWF_KB(GKB_len_list=[1, 3, 5], max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [None, None], | |||||
| [[5, 10, 2], [5, 10, 9]], | |||||
| [3, 64], | |||||
| max_revision=1, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [None, None], | |||||
| [[5, 10, 2], [5, 10, 9]], | |||||
| [3, 64], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [None, None], | |||||
| [[5, 10, 2], [5, 10, 9]], | |||||
| [3, 65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf_multiple(reasoner) | |||||
| print("max_revision is float") | print("max_revision is float") | ||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [None, None], | |||||
| [[5, 10, 2], [5, 10, 9]], | |||||
| [3, 64], | |||||
| max_revision=0.5, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [None, None], | |||||
| [[5, 10, 2], [5, 10, 9]], | |||||
| [3, 64], | |||||
| max_revision=0.9, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| test_hwf_multiple(reasoner) | |||||
| class HED_prolog_KB(prolog_KB): | class HED_prolog_KB(prolog_KB): | ||||
| def __init__(self, pseudo_label_list, pl_file): | def __init__(self, pseudo_label_list, pl_file): | ||||