From 1a38478bda69a51e23136a04eaada3fb439e93e1 Mon Sep 17 00:00:00 2001 From: troyyyyy <49091847+troyyyyy@users.noreply.github.com> Date: Mon, 21 Nov 2022 09:30:30 +0800 Subject: [PATCH] Update kb.py --- abducer/kb.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/abducer/kb.py b/abducer/kb.py index 54a367e..a934ad8 100644 --- a/abducer/kb.py +++ b/abducer/kb.py @@ -77,9 +77,9 @@ class ClsKB(KBBase): if key is None: return self.get_all_candidates() - length = self._length(length) - if(max(self.len_list) < min(length)): + if (type(length) is int and length not in self.len_list): return [] + length = self._length(length) return sum([self.base[l][key] for l in length], []) def get_all_candidates(self): @@ -94,7 +94,7 @@ class ClsKB(KBBase): class add_KB(ClsKB): - def __init__(self, len_list = None): + def __init__(self, len_list = [2]): self.pseudo_label_list = list(range(10)) super().__init__(self.pseudo_label_list, len_list) @@ -103,7 +103,7 @@ class add_KB(ClsKB): class hwf_KB(ClsKB): - def __init__(self, len_list = None): + def __init__(self, len_list = [1, 3, 5, 7]): self.pseudo_label_list = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '-', '*', '/'] super().__init__(self.pseudo_label_list, len_list)