diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 0fc9a5c..63a4d29 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -88,7 +88,11 @@ class KBBase(ABC): """ if logic_result == None: return False - return abs(logic_result - y) <= self.max_err + + if isinstance(logic_result, (int, float)) and isinstance(y, (int, float)): + return abs(logic_result - y) <= self.max_err + else: + return logic_result == y def revise_at_idx(self, pred_pseudo_label, y, revision_idx): """