|
|
|
@@ -26,6 +26,10 @@ import pyswip |
|
|
|
|
|
|
|
class KBBase(ABC): |
|
|
|
def __init__(self, pseudo_label_list, len_list=None, GKB_flag=False, max_err=0): |
|
|
|
# TODO:添加一下类型检查,比如 |
|
|
|
# if not isinstance(X, (np.ndarray, spmatrix)): |
|
|
|
# raise TypeError("X should be numpy array or sparse matrix") |
|
|
|
|
|
|
|
self.pseudo_label_list = pseudo_label_list |
|
|
|
self.len_list = len_list |
|
|
|
self.GKB_flag = GKB_flag |
|
|
|
@@ -161,6 +165,7 @@ class KBBase(ABC): |
|
|
|
new_candidates += candidates |
|
|
|
return new_candidates |
|
|
|
|
|
|
|
# TODO:在类初始化时应该有一个cache(默认True)的参数,用户可以指定是否用cache(若KB会变,那不能用cache) |
|
|
|
@lru_cache(maxsize=100) |
|
|
|
def _abduce_by_search(self, pred_res, key, max_address_num, require_more_address, multiple_predictions): |
|
|
|
pred_res = hashable_to_list(pred_res) |
|
|
|
|