| @@ -4,8 +4,7 @@ import numpy as np | |||
| from zoopt import Dimension, Objective, Opt, Parameter, Solution | |||
| from ..structures import ListData | |||
| from ..utils.utils import (calculate_revision_num, confidence_dist, | |||
| hamming_dist, reform_idx) | |||
| from ..utils.utils import calculate_revision_num, confidence_dist, hamming_dist, reform_idx | |||
| from .base_kb import BaseKB | |||
| @@ -13,7 +12,7 @@ class ReasonerBase: | |||
| def __init__( | |||
| self, | |||
| kb: BaseKB, | |||
| dist_func: str = "hamming", | |||
| dist_func: str = "confidence", | |||
| mapping: Mapping = None, | |||
| use_zoopt: bool = False, | |||
| ): | |||
| @@ -25,7 +24,7 @@ class ReasonerBase: | |||
| kb : BaseKB | |||
| The knowledge base to be used for reasoning. | |||
| dist_func : str, optional | |||
| The distance function to be used. Can be "hamming" or "confidence". Default is "hamming". | |||
| The distance function to be used. Can be "hamming" or "confidence". Default is "confidence". | |||
| mapping : dict, optional | |||
| A mapping of indices to labels. If None, a default mapping is generated. | |||
| use_zoopt : bool, optional | |||
| @@ -37,8 +36,8 @@ class ReasonerBase: | |||
| If the specified distance function is neither "hamming" nor "confidence". | |||
| """ | |||
| if not (dist_func == "hamming" or dist_func == "confidence"): | |||
| raise NotImplementedError # Only hamming or confidence distance is available. | |||
| if dist_func not in ["hamming", "confidence"]: | |||
| raise NotImplementedError(f"The distance function '{dist_func}' is not implemented.") | |||
| self.kb = kb | |||
| self.dist_func = dist_func | |||
| @@ -46,7 +45,18 @@ class ReasonerBase: | |||
| if mapping is None: | |||
| self.mapping = {index: label for index, label in enumerate(self.kb.pseudo_label_list)} | |||
| else: | |||
| if not isinstance(mapping, dict): | |||
| raise ValueError("mapping must be of type dict") | |||
| for key, value in mapping.items(): | |||
| if not isinstance(key, int): | |||
| raise ValueError("All keys in the mapping must be integers") | |||
| if value not in self.kb.pseudo_label_list: | |||
| raise ValueError("All values in the mapping must be in the pseudo_label_list") | |||
| self.mapping = mapping | |||
| self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) | |||
| def _get_cost_list(self, data_sample: ListData, candidates: List[List[Any]]): | |||