| @@ -275,12 +275,11 @@ class ReasonerBase: | |||
| if __name__ == "__main__": | |||
| from kb import KBBase, GroundKB, PrologKB | |||
| prob1 = [[[0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| from abl.structures import ListData | |||
| prob2 = [[[0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]] | |||
| ################################ | |||
| # Test for MNIST Add reasoning # | |||
| ################################ | |||
| class AddKB(KBBase): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| @@ -290,38 +289,54 @@ if __name__ == "__main__": | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| class AddGroundKB(GroundKB): | |||
| class AddGroundKB(GroundKB, AddKB): | |||
| def __init__(self, pseudo_label_list=list(range(10)), | |||
| GKB_len_list=[2]): | |||
| super().__init__(pseudo_label_list, GKB_len_list) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def logic_forward(self, nums): | |||
| return sum(nums) | |||
| def test_add(reasoner): | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob2, [[1, 1]], [8], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [17], max_revision=1, require_more_revision=0) | |||
| # favor 1 in first one | |||
| prob1 = [[0, 0.99, 0, 0, 0, 0, 0, 0.01, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| # favor 7 in first one | |||
| prob2 = [[0, 0.01, 0, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]] | |||
| data_samples_add = ListData() | |||
| data_samples_add.pred_pseudo_label = [[1, 1], [1, 1], [1, 1], [1, 1]] | |||
| data_samples_add.pred_prob = [prob1, prob2, prob1, prob2] | |||
| data_samples_add.Y = [8, 8, 17, 10] | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(prob1, [[1, 1]], [20], max_revision=2, require_more_revision=0) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=1, require_more_revision=1) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_samples_add, max_revision=2, require_more_revision=1) | |||
| print(res) # due to more revision allowed, for the 4th, it will favor [7,3] over [1,9] | |||
| print() | |||
| print("AddKB with GKB:") | |||
| print("AddGroundKB:") | |||
| kb = AddGroundKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| print("AddKB without GKB:") | |||
| print("AddKB:") | |||
| kb = AddKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| print("AddKB without GKB, no cache") | |||
| print("AddKB, no cache") | |||
| kb = AddKB(use_cache=False) | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| test_add(reasoner) | |||
| @@ -339,45 +354,20 @@ if __name__ == "__main__": | |||
| ) | |||
| reasoner = ReasonerBase(kb, "confidence", use_zoopt=True) | |||
| test_add(reasoner) | |||
| print("AddKB with multiple inputs at once:") | |||
| multiple_prob = [[ | |||
| [0, 0.99, 0.01, 0, 0, 0, 0, 0, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ], | |||
| [ | |||
| [0, 0, 0.01, 0, 0, 0, 0, 0.99, 0, 0], | |||
| [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], | |||
| ]] | |||
| kb = AddKB() | |||
| reasoner = ReasonerBase(kb, "confidence") | |||
| res = reasoner.batch_abduce( | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| multiple_prob, | |||
| [[1, 1], [1, 2]], | |||
| [4, 8], | |||
| max_revision=2, | |||
| require_more_revision=1, | |||
| ) | |||
| print(res) | |||
| print() | |||
| ################################ | |||
| #### Test for HWF reasoning #### | |||
| ################################ | |||
| class HwfKB(KBBase): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| "+", "-", "times", "div"], | |||
| max_err=1e-3, | |||
| use_cache=False, | |||
| ): | |||
| super().__init__(pseudo_label_list, max_err) | |||
| super().__init__(pseudo_label_list, max_err, use_cache) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| @@ -397,7 +387,7 @@ if __name__ == "__main__": | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| class HwfGroundKB(GroundKB): | |||
| class HwfGroundKB(GroundKB, HwfKB): | |||
| def __init__( | |||
| self, | |||
| pseudo_label_list=["1", "2", "3", "4", "5", "6", "7", "8", "9", | |||
| @@ -407,6 +397,17 @@ if __name__ == "__main__": | |||
| ): | |||
| super().__init__(pseudo_label_list, GKB_len_list, max_err) | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| for i in range(len(formula)): | |||
| if i % 2 == 0 and formula[i] not in ["1", "2", "3", "4", "5", "6", "7", "8", "9"]: | |||
| return False | |||
| if i % 2 != 0 and formula[i] not in ["+", "-", "times", "div"]: | |||
| return False | |||
| return True | |||
| def _valid_candidate(self, formula): | |||
| if len(formula) % 2 == 0: | |||
| return False | |||
| @@ -417,6 +418,16 @@ if __name__ == "__main__": | |||
| return False | |||
| return True | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return None | |||
| mapping = {str(i): str(i) for i in range(1, 10)} | |||
| mapping.update({"+": "+", "-": "-", "times": "*", "div": "/"}) | |||
| formula = [mapping[f] for f in formula] | |||
| return eval("".join(formula)) | |||
| def logic_forward(self, formula): | |||
| if not self._valid_candidate(formula): | |||
| return None | |||
| @@ -426,87 +437,46 @@ if __name__ == "__main__": | |||
| return eval("".join(formula)) | |||
| def test_hwf(reasoner): | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "2"]], | |||
| [3], | |||
| max_revision=2, | |||
| require_more_revision=0, | |||
| ) | |||
| data_samples_hwf = ListData() | |||
| data_samples_hwf.pred_pseudo_label = [["5", "+", "2"], ["5", "+", "9"], ["5", "+", "9"], ["5", "-", "8", "8", "8"]] | |||
| data_samples_hwf.pred_prob = [None, None, None, None] | |||
| data_samples_hwf.Y = [3, 64, 65, 3.17] | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=3, require_more_revision=0) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "+", "9"]], | |||
| [65], | |||
| max_revision=3, | |||
| require_more_revision=0, | |||
| ) | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.5, require_more_revision=3) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None], | |||
| [["5", "8", "8", "8", "8"]], | |||
| [3.17], | |||
| max_revision=5, | |||
| require_more_revision=3, | |||
| ) | |||
| res = reasoner.batch_abduce(data_samples_hwf, max_revision=0.9, require_more_revision=0) | |||
| print(res) | |||
| print() | |||
| def test_hwf_multiple(reasoner, max_revisions): | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=max_revisions[0], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 64], | |||
| max_revision=max_revisions[1], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce( | |||
| [None, None], | |||
| [["5", "+", "2"], ["5", "+", "9"]], | |||
| [3, 65], | |||
| max_revision=max_revisions[2], | |||
| require_more_revision=0, | |||
| ) | |||
| print(res) | |||
| print() | |||
| print("HwfKB with GKB, max_err=0.1") | |||
| print("HwfGroundKB, max_err=0.1:") | |||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB without GKB, max_err=0.1") | |||
| print("HwfKB, max_err=0.1:") | |||
| kb = HwfKB(max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB with GKB, max_err=1") | |||
| print("HwfGroundKB, max_err=1:") | |||
| kb = HwfGroundKB(GKB_len_list=[1, 3, 5], max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB without GKB, max_err=1") | |||
| print("HwfKB, max_err=1:") | |||
| kb = HwfKB(max_err=1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf(reasoner) | |||
| print("HwfKB with multiple inputs at once:") | |||
| kb = HwfKB(max_err=0.1) | |||
| reasoner = ReasonerBase(kb, "hamming") | |||
| test_hwf_multiple(reasoner, max_revisions=[1,3,3]) | |||
| print("max_revision is float") | |||
| test_hwf_multiple(reasoner, max_revisions=[0.5,0.9,0.9]) | |||
| ################################ | |||
| #### Test for HED reasoning #### | |||
| ################################ | |||
| class HedKB(PrologKB): | |||
| def __init__(self, pseudo_label_list, pl_file): | |||
| super().__init__(pseudo_label_list, pl_file) | |||
| @@ -599,28 +569,24 @@ if __name__ == "__main__": | |||
| inconsist_exs2 = [[1, "+", 0, "=", 0], [1, "=", 1, "=", 0], [0, "=", 0, "=", 1, 1]] | |||
| rules = ["my_op([0], [0], [0])", "my_op([1], [1], [1, 0])"] | |||
| print("HedKB logic forward") | |||
| print(kb.logic_forward(consist_exs)) | |||
| print("HedKB logic forward:") | |||
| print(kb.logic_forward(consist_exs), end=" ") | |||
| print(kb.logic_forward(inconsist_exs1), kb.logic_forward(inconsist_exs2)) | |||
| print() | |||
| print("HedKB consist rule") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules)) | |||
| print("HedKB consist rule:") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 0], rules), end=" ") | |||
| print(kb.consist_rule([1, "+", 1, "=", 1, 1], rules)) | |||
| print() | |||
| data_sample_hed = ListData() | |||
| data_sample_hed.pred_pseudo_label = [consist_exs, inconsist_exs1, inconsist_exs2] | |||
| data_sample_hed.pred_prob = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||
| data_sample_hed.Y = [[None] * len(consist_exs), [None] * len(inconsist_exs1), [None] * len(inconsist_exs2)] | |||
| print("HedReasoner abduce") | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(consist_exs), consist_exs, [None] * len(consist_exs) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(inconsist_exs1), inconsist_exs1, [None] * len(inconsist_exs1) | |||
| ) | |||
| print(res) | |||
| res = reasoner.abduce( | |||
| [[[None]]] * len(inconsist_exs2), inconsist_exs2, [None] * len(inconsist_exs2) | |||
| ) | |||
| print(res) | |||
| res = reasoner.batch_abduce(data_sample_hed) | |||
| for r in res: | |||
| print(r) | |||
| print() | |||
| print("HedReasoner abduce rules") | |||