Browse Source

[MNT] delete redundant parameter "pred_label"

pull/3/head
Gao Enhao 2 years ago
parent
commit
75c72bedf5
2 changed files with 29 additions and 31 deletions
  1. +2
    -2
      abl/bridge/simple_bridge.py
  2. +27
    -29
      abl/reasoning/reasoner.py

+ 2
- 2
abl/bridge/simple_bridge.py View File

@@ -15,7 +15,7 @@ class SimpleBridge(BaseBridge):
self,
model: ABLModel,
abducer: ReasonerBase,
metric_list: BaseMetric,
metric_list: List[BaseMetric],
) -> None:
super().__init__(model, abducer)
self.metric_list = metric_list
@@ -34,7 +34,7 @@ class SimpleBridge(BaseBridge):
max_revision: int = -1,
require_more_revision: int = 0,
) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_label, pred_prob, pseudo_label, Y, max_revision, require_more_revision)
return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision)

def label_to_pseudo_label(
self, label: List[List[Any]], mapping: Dict = None


+ 27
- 29
abl/reasoning/reasoner.py View File

@@ -30,19 +30,19 @@ class ReasonerBase():
self.mapping = mapping
self.remapping = dict(zip(self.mapping.values(), self.mapping.keys()))

def _get_cost_list(self, pseudo_label, pred_res_prob, candidates):
def _get_cost_list(self, pseudo_label, pred_prob, candidates):
"""
Get the list consisting of costs between each pseudo label and candidate.
Parameter `pred_res_prob` is needed while using confidence distance.
Parameter `pred_prob` is needed while using confidence distance.
"""
if self.dist_func == "hamming":
return hamming_dist(pseudo_label, candidates)

elif self.dist_func == "confidence":
candidates = [[self.remapping[x] for x in c] for c in candidates]
return confidence_dist(pred_res_prob, candidates)
return confidence_dist(pred_prob, candidates)

def _get_one_candidate(self, pseudo_label, pred_res_prob, candidates):
def _get_one_candidate(self, pseudo_label, pred_prob, candidates):
"""
Get one candidate. If multiple candidates exist, return the one with minimum cost.
"""
@@ -51,11 +51,11 @@ class ReasonerBase():
elif len(candidates) == 1:
return candidates[0]
else:
cost_list = self._get_cost_list(pseudo_label, pred_res_prob, candidates)
cost_list = self._get_cost_list(pseudo_label, pred_prob, candidates)
candidate = candidates[np.argmin(cost_list)]
return candidate

def zoopt_revision_score(self, pred_res, pseudo_label, pred_res_prob, y, sol):
def zoopt_revision_score(self, symbol_num, pred_pseudo_label, pred_prob, y, sol):
"""
Get the revision score for a single solution.

@@ -65,9 +65,9 @@ class ReasonerBase():
Solution to evaluate.
pred_res : list
List of predicted results.
pseudo_label : list
pred_pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
pred_prob : list
List of probabilities for predicted results.
y : str
Ground truth for the predicted results.
@@ -78,26 +78,26 @@ class ReasonerBase():
The revision score for the given solution.
"""
revision_idx = np.where(sol.get_x() != 0)[0]
candidates = self.revise_by_idx(pseudo_label, y, revision_idx)
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
if len(candidates) > 0:
return np.min(self._get_cost_list(pseudo_label, pred_res_prob, candidates))
return np.min(self._get_cost_list(pred_pseudo_label, pred_prob, candidates))
else:
return len(pred_res)
return symbol_num

def _constrain_revision_num(self, solution, max_revision_num):
x = solution.get_x()
return max_revision_num - x.sum()

def zoopt_get_solution(self, pred_res, pseudo_label, pred_res_prob, y, max_revision_num):
def zoopt_get_solution(self, symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num):
"""Get the optimal solution using the Zoopt library.

Parameters
----------
pred_res : list
List of predicted results.
pseudo_label : list
pred_pseudo_label : list
List of predicted pseudo labels.
pred_res_prob : list
pred_prob : list
List of probabilities for predicted results.
y : str
Ground truth for the predicted results.
@@ -109,10 +109,9 @@ class ReasonerBase():
array-like
The optimal solution.
"""
length = len(flatten(pred_res))
dimension = Dimension(size=length, regs=[[0, 1]] * length, tys=[False] * length)
dimension = Dimension(size=symbol_num, regs=[[0, 1]] * symbol_num, tys=[False] * symbol_num)
objective = Objective(
lambda sol: self.zoopt_revision_score(pred_res, pseudo_label, pred_res_prob, y, sol),
lambda sol: self.zoopt_revision_score(symbol_num, pred_pseudo_label, pred_prob, y, sol),
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
@@ -140,7 +139,7 @@ class ReasonerBase():
"""
return self.kb.revise_by_idx(pseudo_label, y, revision_idx)

def abduce(self, data, max_revision=-1, require_more_revision=0):
def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0):
"""
Perform abduction on the given data.

@@ -159,11 +158,11 @@ class ReasonerBase():
list
The abduced revisions.
"""
pred_label, pred_prob, pred_pseudo_label, y = data
max_revision_num = float_parameter(max_revision, len(flatten(pred_label)))
symbol_num = len(flatten(pred_pseudo_label))
max_revision_num = float_parameter(max_revision, symbol_num)

if self.use_zoopt:
solution = self.zoopt_get_solution(pred_label, pred_pseudo_label, pred_prob, y, max_revision_num)
solution = self.zoopt_get_solution(symbol_num, pred_pseudo_label, pred_prob, y, max_revision_num)
revision_idx = np.where(solution != 0)[0]
candidates = self.revise_by_idx(pred_pseudo_label, y, revision_idx)
else:
@@ -172,7 +171,7 @@ class ReasonerBase():
candidate = self._get_one_candidate(pred_pseudo_label, pred_prob, candidates)
return candidate

def batch_abduce(self, pred_label, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0):
def batch_abduce(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0):
"""Perform abduction on the given data in batches.

Parameters
@@ -192,10 +191,9 @@ class ReasonerBase():
list
The abduced revisions.
"""
return [
self.abduce(data, max_revision, require_more_revision)
for data in zip(pred_label, pred_prob, pred_pseudo_label, Y)
]
return [self.abduce(_pred_prob, _pred_pseudo_label, _Y, max_revision, require_more_revision)
for _pred_prob, _pred_pseudo_label, _Y in zip(pred_prob, pred_pseudo_label, Y)]


# def _batch_abduce_helper(self, args):
# z, prob, y, max_revision, require_more_revision = args
@@ -206,8 +204,8 @@ class ReasonerBase():
# results = pool.map(self._batch_abduce_helper, [(z, prob, y, max_revision, require_more_revision) for z, prob, y in zip(Z['cls'], Z['prob'], Y)])
# return results

def __call__(self, Z, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(Z, Y, max_revision, require_more_revision)
def __call__(self, pred_prob, pred_pseudo_label, Y, max_revision=-1, require_more_revision=0):
return self.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)


if __name__ == "__main__":
@@ -570,7 +568,7 @@ if __name__ == "__main__":
candidate = self.revise_by_idx(pred, k, revision_idx)
return candidate

def zoopt_revision_score(self, pred_res, pred_res_prob, y, sol):
def zoopt_revision_score(self, pred_res, pred_prob, y, sol):
all_revision_flag = reform_idx(sol.get_x(), pred_res)
lefted_idxs = [i for i in range(len(pred_res))]
candidate_size = []


Loading…
Cancel
Save