Browse Source

[MNT] unify variable names

pull/3/head
Gao Enhao 2 years ago
parent
commit
6ac0bb9378
3 changed files with 21 additions and 22 deletions
  1. +2
    -2
      abl/bridge/base_bridge.py
  2. +16
    -17
      abl/bridge/simple_bridge.py
  3. +3
    -3
      abl/reasoning/reasoner.py

+ 2
- 2
abl/bridge/base_bridge.py View File

@@ -26,12 +26,12 @@ class BaseBridge(metaclass=ABCMeta):
"""Placeholder for abduce pseudo labels.""" """Placeholder for abduce pseudo labels."""


@abstractmethod @abstractmethod
def label_to_pseudo_label(self, label: List[List[Any]]) -> List[List[Any]]:
def idx_to_pseudo_label(self, idx: List[List[Any]]) -> List[List[Any]]:
"""Placeholder for map label space to symbol space.""" """Placeholder for map label space to symbol space."""
pass pass


@abstractmethod @abstractmethod
def pseudo_label_to_label(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
def pseudo_label_to_idx(self, pseudo_label: List[List[Any]]) -> List[List[Any]]:
"""Placeholder for map symbol space to label space.""" """Placeholder for map symbol space to label space."""
pass pass


+ 16
- 17
abl/bridge/simple_bridge.py View File

@@ -22,28 +22,27 @@ class SimpleBridge(BaseBridge):


def predict(self, X) -> Tuple[List[List[Any]], ndarray]: def predict(self, X) -> Tuple[List[List[Any]], ndarray]:
pred_res = self.model.predict(X) pred_res = self.model.predict(X)
pred_label, pred_prob = pred_res["label"], pred_res["prob"]
return pred_label, pred_prob
pred_idx, pred_prob = pred_res["label"], pred_res["prob"]
return pred_idx, pred_prob
def abduce_pseudo_label( def abduce_pseudo_label(
self, self,
pred_label: List[List[Any]],
pred_prob: ndarray, pred_prob: ndarray,
pseudo_label: List[List[Any]],
Y: List[List[Any]],
pred_pseudo_label: List[List[Any]],
Y: List[Any],
max_revision: int = -1, max_revision: int = -1,
require_more_revision: int = 0, require_more_revision: int = 0,
) -> List[List[Any]]: ) -> List[List[Any]]:
return self.abducer.batch_abduce(pred_prob, pseudo_label, Y, max_revision, require_more_revision)
return self.abducer.batch_abduce(pred_prob, pred_pseudo_label, Y, max_revision, require_more_revision)


def label_to_pseudo_label(
self, label: List[List[Any]], mapping: Dict = None
def idx_to_pseudo_label(
self, idx: List[List[Any]], mapping: Dict = None
) -> List[List[Any]]: ) -> List[List[Any]]:
if mapping is None: if mapping is None:
mapping = self.abducer.mapping mapping = self.abducer.mapping
return [[mapping[_label] for _label in sub_list] for sub_list in label]
return [[mapping[_idx] for _idx in sub_list] for sub_list in idx]


def pseudo_label_to_label(
def pseudo_label_to_idx(
self, pseudo_label: List[List[Any]], mapping: Dict = None self, pseudo_label: List[List[Any]], mapping: Dict = None
) -> List[List[Any]]: ) -> List[List[Any]]:
if mapping is None: if mapping is None:
@@ -69,12 +68,12 @@ class SimpleBridge(BaseBridge):


for epoch in range(epochs): for epoch in range(epochs):
for seg_idx, (X, Z, Y) in enumerate(data_loader): for seg_idx, (X, Z, Y) in enumerate(data_loader):
pred_label, pred_prob = self.predict(X)
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
abduced_pseudo_label = self.abduce_pseudo_label( abduced_pseudo_label = self.abduce_pseudo_label(
pred_label, pred_prob, pred_pseudo_label, Y
pred_prob, pred_pseudo_label, Y
) )
abduced_label = self.pseudo_label_to_label(abduced_pseudo_label)
abduced_label = self.pseudo_label_to_idx(abduced_pseudo_label)
min_loss = self.model.train(X, abduced_label) min_loss = self.model.train(X, abduced_label)


print_log( print_log(
@@ -88,10 +87,10 @@ class SimpleBridge(BaseBridge):


def _valid(self, data_loader): def _valid(self, data_loader):
for X, Z, Y in data_loader: for X, Z, Y in data_loader:
pred_label, pred_prob = self.predict(X)
pred_pseudo_label = self.label_to_pseudo_label(pred_label)
pred_idx, pred_prob = self.predict(X)
pred_pseudo_label = self.idx_to_pseudo_label(pred_idx)
data_samples = dict( data_samples = dict(
pred_label=pred_label,
pred_idx=pred_idx,
pred_prob=pred_prob, pred_prob=pred_prob,
pred_pseudo_label=pred_pseudo_label, pred_pseudo_label=pred_pseudo_label,
gt_pseudo_label=Z, gt_pseudo_label=Z,


+ 3
- 3
abl/reasoning/reasoner.py View File

@@ -119,13 +119,13 @@ class ReasonerBase():
solution = Opt.min(objective, parameter).get_x() solution = Opt.min(objective, parameter).get_x()
return solution return solution


def revise_by_idx(self, pseudo_label, y, revision_idx):
def revise_by_idx(self, pred_pseudo_label, y, revision_idx):
""" """
Get the revisions corresponding to the given indices. Get the revisions corresponding to the given indices.


Parameters Parameters
---------- ----------
pseudo_label : list
pred_pseudo_label : list
List of predicted pseudo labels. List of predicted pseudo labels.
y : str y : str
Ground truth for the predicted results. Ground truth for the predicted results.
@@ -137,7 +137,7 @@ class ReasonerBase():
list list
The revisions corresponding to the given indices. The revisions corresponding to the given indices.
""" """
return self.kb.revise_by_idx(pseudo_label, y, revision_idx)
return self.kb.revise_by_idx(pred_pseudo_label, y, revision_idx)


def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0): def abduce(self, pred_prob, pred_pseudo_label, y, max_revision=-1, require_more_revision=0):
""" """


Loading…
Cancel
Save