| @@ -20,14 +20,13 @@ class ReasonerBase: | |||
| dist_func : str, optional | |||
| The distance function to be used when determining the cost list between each | |||
| candidate and the given prediction. Valid options include: `"hamming"` (default) | |||
| | `"confidence"`. Any other options will raise a `NotImplementedError`. | |||
| For detailed explanations of these options, refer to `_get_cost_list`. | |||
| mapping : dictt, optional | |||
| A mapping from label to index. If not provided, a default | |||
| order-based mapping is created. | |||
| | `"confidence"`. Any other options will raise a `NotImplementedError`. For | |||
| detailed explanations of these options, refer to `_get_cost_list`. | |||
| mapping : dict, optional | |||
| A mapping from label to index. If not provided, a default order-based mapping is | |||
| created. | |||
| use_zoopt : bool, optional | |||
| Whether to use the Zoopt library during abductive reasoning. | |||
| Default is False. | |||
| Whether to use the Zoopt library during abductive reasoning. Default to False. | |||
| """ | |||
| def __init__(self, kb, dist_func="hamming", mapping=None, use_zoopt=False): | |||
| @@ -42,12 +41,14 @@ class ReasonerBase: | |||
| label: index for index, label in enumerate(self.kb.pseudo_label_list) | |||
| } | |||
| else: | |||
| if not isinstance(mapping, dict): | |||
| raise TypeError("mapping should be dict") | |||
| self.mapping = mapping | |||
| def _get_one_candidate(self, pred_pseudo_label, pred_prob, candidates): | |||
| """ | |||
| Due to the nondeterminism of abductive reasoning, there could be multiple candidates | |||
| satisfying the knowledge base. If this happens, return one candidate that has the | |||
| satisfying the knowledge base. When this happens, return one candidate that has the | |||
| minimum cost. If no candidates are provided, an empty list is returned. | |||
| Parameters | |||
| @@ -58,7 +59,7 @@ class ReasonerBase: | |||
| Predicted probabilities of the prediction (Each sublist contains the probability | |||
| values of all pseudo labels). | |||
| candidates : List[List[Any]] | |||
| Several candidate abduction results. | |||
| Multiple candidate abduction results. | |||
| """ | |||
| if len(candidates) == 0: | |||
| return [] | |||
| @@ -86,7 +87,7 @@ class ReasonerBase: | |||
| Predicted probabilities of the prediction (Each sublist contains the probability | |||
| values of all pseudo labels). Used when distance function is "confidence". | |||
| candidates : List[List[Any]] | |||
| Several candidate abduction results. | |||
| Multiple candidate abduction results. | |||
| """ | |||
| if self.dist_func == "hamming": | |||
| return hamming_dist(pred_pseudo_label, candidates) | |||