| @@ -275,12 +275,11 @@ class ReasonerBase: | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| from kb import KBBase, GroundKB, PrologKB | from kb import KBBase, GroundKB, PrologKB | ||||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
| from abl.structures import ListData | |||||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||||
| ################################ | |||||
| # Test for MNIST Add reasoning # | |||||
| ################################ | |||||
| class AddKB(KBBase): | class AddKB(KBBase): | ||||
| def __init__(self, pseudo_label_list=list(range(10)), | def __init__(self, pseudo_label_list=list(range(10)), | ||||
| @@ -290,38 +289,54 @@ if __name__ == "__main__": | |||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| class AddGroundKB(GroundKB): | |||||
| class AddGroundKB(GroundKB, AddKB): | |||||
| def __init__(self, pseudo_label_list=list(range(10)), | def __init__(self, pseudo_label_list=list(range(10)), | ||||
| GKB_len_list=[2]): | GKB_len_list=[2]): | ||||
| super().__init__(pseudo_label_list, GKB_len_list) | super().__init__(pseudo_label_list, GKB_len_list) | ||||
| def logic_forward(self, nums): | |||||
| return sum(nums) | |||||
| def logic_forward(self, nums): | def logic_forward(self, nums): | ||||
| return sum(nums) | return sum(nums) | ||||
| def test_add(reasoner): | def test_add(reasoner): | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||||
| # favor 1 in first one | |||||
| prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
| # favor 7 in first one | |||||
| prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||||
| data_samples_add = ListData() | |||||
| data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||||
| data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||||
| data_samples_add.Y = [8, 8, 17, 10] | |||||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) | |||||
| print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] | |||||
| print() | print() | ||||
| print("AddKB with GKB:") | |||||
| print("AddGroundKB:") | |||||
| kb = AddGroundKB() | kb = AddGroundKB() | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| test_add(reasoner) | test_add(reasoner) | ||||
| print("AddKB without GKB:") | |||||
| print("AddKB:") | |||||
| kb = AddKB() | kb = AddKB() | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| test_add(reasoner) | test_add(reasoner) | ||||
| print("AddKB without GKB, no cache") | |||||
| print("AddKB, no cache") | |||||
| kb = AddKB(use_cache=False) | kb = AddKB(use_cache=False) | ||||
| reasoner = ReasonerBase(kb, "confidence") | reasoner = ReasonerBase(kb, "confidence") | ||||
| test_add(reasoner) | test_add(reasoner) | ||||
| @@ -339,45 +354,20 @@ if __name__ == "__main__": | |||||
| ) | ) | ||||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | ||||
| test_add(reasoner) | test_add(reasoner) | ||||
| print("AddKB with multiple inputs at once:") | |||||
| multiple_prob = [[ | |||||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ], | |||||
| [ | |||||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||||
| ]] | |||||
| kb = AddKB() | |||||
| reasoner = ReasonerBase(kb, "confidence") | |||||
| res = reasoner.batch_abduce( | |||||
| multiple_prob, | |||||
| [[1, 1], [1, 2]], | |||||
| [4, 8], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| multiple_prob, | |||||
| [[1, 1], [1, 2]], | |||||
| [4, 8], | |||||
| max_revision=2, | |||||
| require_more_revision=1, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| ################################ | |||||
| #### Test for HWF reasoning #### | |||||
| ################################ | |||||
| class HwfKB(KBBase): | class HwfKB(KBBase): | ||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | ||||
| "+", "-", "times", "div"], | "+", "-", "times", "div"], | ||||
| max_err=1e-3, | max_err=1e-3, | ||||
| use_cache=False, | |||||
| ): | ): | ||||
| super().__init__(pseudo_label_list, max_err) | |||||
| super().__init__(pseudo_label_list, max_err, use_cache) | |||||
| def _valid_candidate(self, formula): | def _valid_candidate(self, formula): | ||||
| if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
| @@ -397,7 +387,7 @@ if __name__ == "__main__": | |||||
| formula = [mapping[f] for f in formula] | formula = [mapping[f] for f in formula] | ||||
| return eval("".join(formula)) | return eval("".join(formula)) | ||||
| class HwfGroundKB(GroundKB): | |||||
| class HwfGroundKB(GroundKB, HwfKB): | |||||
| def __init__( | def __init__( | ||||
| self, | self, | ||||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | ||||
| @@ -407,6 +397,17 @@ if __name__ == "__main__": | |||||
| ): | ): | ||||
| super().__init__(pseudo_label_list, GKB_len_list, max_err) | super().__init__(pseudo_label_list, GKB_len_list, max_err) | ||||
| def _valid_candidate(self, formula): | |||||
| if len(formula) % 2 == 0: | |||||
| return False | |||||
| for i in range(len(formula)): | |||||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||||
| return False | |||||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||||
| return False | |||||
| return True | |||||
| def _valid_candidate(self, formula): | def _valid_candidate(self, formula): | ||||
| if len(formula) % 2 == 0: | if len(formula) % 2 == 0: | ||||
| return False | return False | ||||
| @@ -417,6 +418,16 @@ if __name__ == "__main__": | |||||
| return False | return False | ||||
| return True | return True | ||||
| def logic_forward(self, formula): | |||||
| if not self._valid_candidate(formula): | |||||
| return None | |||||
| mapping = {str(i): str(i) for i in range(1, 10)} | |||||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||||
| formula = [mapping[f] for f in formula] | |||||
| return eval("".join(formula)) | |||||
| def logic_forward(self, formula): | def logic_forward(self, formula): | ||||
| if not self._valid_candidate(formula): | if not self._valid_candidate(formula): | ||||
| return None | return None | ||||
| @@ -426,87 +437,46 @@ if __name__ == "__main__": | |||||
| return eval("".join(formula)) | return eval("".join(formula)) | ||||
| def test_hwf(reasoner): | def test_hwf(reasoner): | ||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "+", "2"]], | |||||
| [3], | |||||
| max_revision=2, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| data_samples_hwf = ListData() | |||||
| data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] | |||||
| data_samples_hwf.pred_prob = [None, None, None, None] | |||||
| data_samples_hwf.Y = [3, 64, 65, 3.17] | |||||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "+", "9"]], | |||||
| [65], | |||||
| max_revision=3, | |||||
| require_more_revision=0, | |||||
| ) | |||||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) | |||||
| print(res) | print(res) | ||||
| res = reasoner.batch_abduce( | |||||
| [None], | |||||
| [["5", "8", "8", "8", "8"]], | |||||
| [3.17], | |||||
| max_revision=5, | |||||
| require_more_revision=3, | |||||
| ) | |||||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) | |||||
| print(res) | print(res) | ||||
| print() | print() | ||||
| def test_hwf_multiple(reasoner, max_revisions): | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 64], | |||||
| max_revision=max_revisions[0], | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 64], | |||||
| max_revision=max_revisions[1], | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce( | |||||
| [None, None], | |||||
| [["5", "+", "2"], ["5", "+", "9"]], | |||||
| [3, 65], | |||||
| max_revision=max_revisions[2], | |||||
| require_more_revision=0, | |||||
| ) | |||||
| print(res) | |||||
| print() | |||||
| print("HwfKB with GKB, max_err=0.1") | |||||
| print("HwfGroundKB, max_err=0.1:") | |||||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| test_hwf(reasoner) | test_hwf(reasoner) | ||||
| print("HwfKB without GKB, max_err=0.1") | |||||
| print("HwfKB, max_err=0.1:") | |||||
| kb = HwfKB(max_err=0.1) | kb = HwfKB(max_err=0.1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| test_hwf(reasoner) | test_hwf(reasoner) | ||||
| print("HwfKB with GKB, max_err=1") | |||||
| print("HwfGroundKB, max_err=1:") | |||||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) | kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| test_hwf(reasoner) | test_hwf(reasoner) | ||||
| print("HwfKB without GKB, max_err=1") | |||||
| print("HwfKB, max_err=1:") | |||||
| kb = HwfKB(max_err=1) | kb = HwfKB(max_err=1) | ||||
| reasoner = ReasonerBase(kb, "hamming") | reasoner = ReasonerBase(kb, "hamming") | ||||
| test_hwf(reasoner) | test_hwf(reasoner) | ||||
| print("HwfKB with multiple inputs at once:") | |||||
| kb = HwfKB(max_err=0.1) | |||||
| reasoner = ReasonerBase(kb, "hamming") | |||||
| test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||||
| print("max_revision is float") | |||||
| test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||||
| ################################ | |||||
| #### Test for HED reasoning #### | |||||
| ################################ | |||||
| class HedKB(PrologKB): | class HedKB(PrologKB): | ||||
| def __init__(self, pseudo_label_list, pl_file): | def __init__(self, pseudo_label_list, pl_file): | ||||
| super().__init__(pseudo_label_list, pl_file) | super().__init__(pseudo_label_list, pl_file) | ||||
| @@ -599,28 +569,24 @@ if __name__ == "__main__": | |||||
| inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | ||||
| rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | ||||
| print("HedKB logic forward") | |||||
| print(kb.logic_forward(consist_exs)) | |||||
| print("HedKB logic forward:") | |||||
| print(kb.logic_forward(consist_exs), end=" ") | |||||
| print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | ||||
| print() | print() | ||||
| print("HedKB consist rule") | |||||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||||
| print("HedKB consist rule:") | |||||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") | |||||
| print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | ||||
| print() | print() | ||||
| data_sample_hed = ListData() | |||||
| data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] | |||||
| data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||||
| data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||||
| print("HedReasoner abduce") | print("HedReasoner abduce") | ||||
| res = reasoner.abduce( | |||||
| [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.abduce( | |||||
| [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.abduce( | |||||
| [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||||
| ) | |||||
| print(res) | |||||
| res = reasoner.batch_abduce(data_sample_hed) | |||||
| for r in res: | |||||
| print(r) | |||||
| print() | print() | ||||
| print("HedReasoner abduce rules") | print("HedReasoner abduce rules") | ||||