|
|
|
@@ -30,19 +30,19 @@ class ReasonerBase(): |
|
|
|
self.mapping = mapping |
|
|
|
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys())) |
|
|
|
|
|
|
|
def _get_cost_list(self, pseudo_label, pred_res_prob, candidates): |
|
|
|
def _get_cost_list(self, pseudo_label, pred_prob, candidates): |
|
|
|
""" |
|
|
|
Get the list consisting of costs between each pseudo label and candidate. |
|
|
|
Parameter `pred_res_prob` is needed while using confidence distance. |
|
|
|
Parameter `pred_prob` is needed while using confidence distance. |
|
|
|
""" |
|
|
|
if self.dist_func == "hamming": |
|
|
|
return hamming_dist(pseudo_label, candidates) |
|
|
|
|
|
|
|
elif self.dist_func == "confidence": |
|
|
|
candidates = [[self.remapping[x] for x in c] for c in candidates] |
|
|
|
return confidence_dist(pred_res_prob, candidates) |
|
|
|
return confidence_dist(pred_prob, candidates) |
|
|
|
|
|
|
|
def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates): |
|
|
|
def _get_one_candidate(self, pseudo_label, pred_prob, candidates): |
|
|
|
""" |
|
|
|
Get one candidate. If multiple candidates exist, return the one with minimum cost. |
|
|
|
""" |
|
|
|
@@ -51,11 +51,11 @@ class ReasonerBase(): |
|
|
|
elif len(candidates) == 1: |
|
|
|
return candidates[0] |
|
|
|
else: |
|
|
|
cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates) |
|
|
|
cost_list = self._get_cost_list(pseudo_label, pred_prob, candidates) |
|
|
|
candidate = candidates[np.argmin(cost_list)] |
|
|
|
return candidate |
|
|
|
|
|
|
|
def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol): |
|
|
|
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol): |
|
|
|
""" |
|
|
|
Get the revision score for a single solution. |
|
|
|
|
|
|
|
@@ -65,9 +65,9 @@ class ReasonerBase(): |
|
|
|
Solution to evaluate. |
|
|
|
pred_res : list |
|
|
|
List of predicted results. |
|
|
|
pseudo_label : list |
|
|
|
pred_pseudo_label : list |
|
|
|
List of predicted pseudo labels. |
|
|
|
pred_res_prob : list |
|
|
|
pred_prob : list |
|
|
|
List of probabilities for predicted results. |
|
|
|
y : str |
|
|
|
Ground truth for the predicted results. |
|
|
|
@@ -78,26 +78,26 @@ class ReasonerBase(): |
|
|
|
The revision score for the given solution. |
|
|
|
""" |
|
|
|
revision_idx = np.where(sol.get_x() != 0)[0] |
|
|
|
candidates = self.revise_by_idx(pseudo_label, y, revision_idx) |
|
|
|
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) |
|
|
|
if len(candidates) > 0: |
|
|
|
return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates)) |
|
|
|
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates)) |
|
|
|
else: |
|
|
|
return len(pred_res) |
|
|
|
return symbol_num |
|
|
|
|
|
|
|
def _constrain_revision_num(self, solution, max_revision_num): |
|
|
|
x = solution.get_x() |
|
|
|
return max_revision_num - x.sum() |
|
|
|
|
|
|
|
def zoopt_get_solution(self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num): |
|
|
|
def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num): |
|
|
|
"""Get the optimal solution using the Zoopt library. |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
pred_res : list |
|
|
|
List of predicted results. |
|
|
|
pseudo_label : list |
|
|
|
pred_pseudo_label : list |
|
|
|
List of predicted pseudo labels. |
|
|
|
pred_res_prob : list |
|
|
|
pred_prob : list |
|
|
|
List of probabilities for predicted results. |
|
|
|
y : str |
|
|
|
Ground truth for the predicted results. |
|
|
|
@@ -109,10 +109,9 @@ class ReasonerBase(): |
|
|
|
array-like |
|
|
|
The optimal solution. |
|
|
|
""" |
|
|
|
length = len(flatten(pred_res)) |
|
|
|
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length) |
|
|
|
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num) |
|
|
|
objective = Objective( |
|
|
|
lambda sol: self.zoopt_revision_score(pred_res, pseudo_label, pred_res_prob, y, sol), |
|
|
|
lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol), |
|
|
|
dim=dimension, |
|
|
|
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num), |
|
|
|
) |
|
|
|
@@ -140,7 +139,7 @@ class ReasonerBase(): |
|
|
|
""" |
|
|
|
return self.kb.revise_by_idx(pseudo_label, y, revision_idx) |
|
|
|
|
|
|
|
def abduce(self, data, max_revision=-1, require_more_revision=0): |
|
|
|
def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): |
|
|
|
""" |
|
|
|
Perform abduction on the given data. |
|
|
|
|
|
|
|
@@ -159,11 +158,11 @@ class ReasonerBase(): |
|
|
|
list |
|
|
|
The abduced revisions. |
|
|
|
""" |
|
|
|
pred_label, pred_prob, pred_pseudo_label, y = data |
|
|
|
max_revision_num = float_parameter(max_revision, len(flatten(pred_label))) |
|
|
|
symbol_num = len(flatten(pred_pseudo_label)) |
|
|
|
max_revision_num = float_parameter(max_revision, symbol_num) |
|
|
|
|
|
|
|
if self.use_zoopt: |
|
|
|
solution = self.zoopt_get_solution(pred_label, pred_pseudo_label, pred_prob, y, max_revision_num) |
|
|
|
solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num) |
|
|
|
revision_idx = np.where(solution != 0)[0] |
|
|
|
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx) |
|
|
|
else: |
|
|
|
@@ -172,7 +171,7 @@ class ReasonerBase(): |
|
|
|
candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def batch_abduce(self, pred_label, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): |
|
|
|
def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): |
|
|
|
"""Perform abduction on the given data in batches. |
|
|
|
|
|
|
|
Parameters |
|
|
|
@@ -192,10 +191,9 @@ class ReasonerBase(): |
|
|
|
list |
|
|
|
The abduced revisions. |
|
|
|
""" |
|
|
|
return [ |
|
|
|
self.abduce(data, max_revision, require_more_revision) |
|
|
|
for data in zip(pred_label, pred_prob, pred_pseudo_label, Y) |
|
|
|
] |
|
|
|
return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision) |
|
|
|
for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)] |
|
|
|
|
|
|
|
|
|
|
|
# def _batch_abduce_helper(self, args): |
|
|
|
# z, prob, y, max_revision, require_more_revision = args |
|
|
|
@@ -206,8 +204,8 @@ class ReasonerBase(): |
|
|
|
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)]) |
|
|
|
# return results |
|
|
|
|
|
|
|
def __call__(self, Z, Y, max_revision=-1, require_more_revision=0): |
|
|
|
return self.batch_abduce(Z, Y, max_revision, require_more_revision) |
|
|
|
def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0): |
|
|
|
return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
@@ -570,7 +568,7 @@ if __name__ == "__main__": |
|
|
|
candidate = self.revise_by_idx(pred, k, revision_idx) |
|
|
|
return candidate |
|
|
|
|
|
|
|
def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol): |
|
|
|
def zoopt_revision_score(self, pred_res, pred_prob, y, sol): |
|
|
|
all_revision_flag = reform_idx(sol.get_x(), pred_res) |
|
|
|
lefted_idxs = [i for i in range(len(pred_res))] |
|
|
|
candidate_size = [] |
|
|
|
|