diff --git a/abl/abducer/kb.py b/abl/abducer/kb.py index fb606f3..8bb7854 100644 --- a/abl/abducer/kb.py +++ b/abl/abducer/kb.py @@ -26,6 +26,10 @@ import pyswip class KBBase(ABC): def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): + # TODO:添加一下类型检查,比如 + # if not isinstance(X, (np.ndarray, spmatrix)): + # raise TypeError("X should be numpy array or sparse matrix") + self.pseudo_label_list = pseudo_label_list self.len_list = len_list self.GKB_flag = GKB_flag @@ -161,6 +165,7 @@ class KBBase(ABC): new_candidates += candidates return new_candidates + # TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache) @lru_cache(maxsize=100) def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): pred_res = hashable_to_list(pred_res)