Browse Source

[FIX] pass flake8

pull/1/head
troyyyyy 2 years ago
parent
commit
d125509f78
35 changed files with 629 additions and 405 deletions
  1. +1
    -1
      .github/workflows/lint.yaml
  2. +14
    -12
      abl/bridge/simple_bridge.py
  3. +86
    -45
      abl/data/data_converter.py
  4. +2
    -1
      abl/data/evaluation/reasoning_metric.py
  5. +2
    -2
      abl/data/structures/list_data.py
  6. +14
    -13
      abl/learning/basic_nn.py
  7. +64
    -22
      abl/learning/model_converter.py
  8. +88
    -76
      abl/reasoning/kb.py
  9. +26
    -21
      abl/reasoning/reasoner.py
  10. +8
    -1
      abl/utils/__init__.py
  11. +1
    -1
      abl/utils/cache.py
  12. +11
    -8
      abl/utils/logger.py
  13. +15
    -14
      abl/utils/utils.py
  14. +12
    -10
      docs/conf.py
  15. +1
    -1
      examples/hed/bridge.py
  16. +13
    -10
      examples/hed/datasets/get_dataset.py
  17. +31
    -23
      examples/hed/hed.ipynb
  18. +4
    -4
      examples/hed/main.py
  19. +1
    -1
      examples/hed/reasoning/__init__.py
  20. +7
    -4
      examples/hed/reasoning/reasoning.py
  21. +1
    -1
      examples/hwf/datasets/__init__.py
  22. +12
    -7
      examples/hwf/datasets/get_dataset.py
  23. +56
    -33
      examples/hwf/hwf.ipynb
  24. +4
    -4
      examples/hwf/main.py
  25. +1
    -1
      examples/mnist_add/datasets/__init__.py
  26. +1
    -0
      examples/mnist_add/datasets/get_dataset.py
  27. +2
    -2
      examples/mnist_add/main.py
  28. +19
    -5
      examples/mnist_add/main_with_model_converter.py
  29. +49
    -30
      examples/mnist_add/mnist_add.ipynb
  30. +10
    -7
      examples/zoo/get_dataset.py
  31. +20
    -9
      examples/zoo/kb.py
  32. +16
    -10
      examples/zoo/main.py
  33. +11
    -4
      examples/zoo/zoo.ipynb
  34. +2
    -2
      tests/conftest.py
  35. +24
    -20
      tests/test_reasoning.py

+ 1
- 1
.github/workflows/lint.yaml View File

@@ -21,4 +21,4 @@ jobs:
uses: py-actions/flake8@v2
with:
max-line-length: "100"
args: --ignore=E203,W503
args: --ignore=E203,W503,F821,E266

+ 14
- 12
abl/bridge/simple_bridge.py View File

@@ -164,7 +164,8 @@ class SimpleBridge(BaseBridge):
self, unlabel_data_examples: ListData, label_data_examples: Optional[ListData]
) -> ListData:
"""
Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model.
Concatenate unlabeled and labeled data examples. ``abduced_pseudo_label`` of unlabeled data
examples and ``gt_pseudo_label`` of labeled data examples will be used to train the model.

Parameters
----------
@@ -212,18 +213,19 @@ class SimpleBridge(BaseBridge):
Training data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData``
object with ``X``, ``gt_pseudo_label`` and ``Y`` attributes.
- ``X`` is a list of sublists representing the input data.
- ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but not
to train. ``gt_pseudo_label`` can be ``None``.
- ``Y`` is a list representing the ground truth reasoning result for each sublist in ``X``.
- ``gt_pseudo_label`` is only used to evaluate the performance of the ``ABLModel`` but
not to train. ``gt_pseudo_label`` can be ``None``.
- ``Y`` is a list representing the ground truth reasoning result for each sublist
in ``X``.
label_data : Union[ListData, Tuple[List[List[Any]], List[List[Any]], List[Any]]], optional
Labeled data should be in the same format as ``train_data``. The only difference is
that the ``gt_pseudo_label`` in ``label_data`` should not be ``None`` and will be
utilized to train the model. Defaults to None.
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]], optional # noqa: E501
Validation data should be in the same format as ``train_data``. Both ``gt_pseudo_label``
and ``Y`` can be either None or not, which depends on the evaluation metrics in
``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate the
model during training time. Defaults to None.
``self.metric_list``. If ``val_data`` is None, ``train_data`` will be used to validate
the model during training time. Defaults to None.
loops : int
Machine Learning part and Reasoning part will be iteratively optimized
for ``loops`` times, by default 50.
@@ -325,7 +327,7 @@ class SimpleBridge(BaseBridge):

Parameters
----------
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
val_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
Validation data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be
either None or not, which depends on the evaluation metrics in ``self.metric_list``.
@@ -344,10 +346,10 @@ class SimpleBridge(BaseBridge):

Parameters
----------
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]]
Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object with ``X``,
``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y`` can be either None or
not, which depends on the evaluation metircs in ``self.metric_list``.
test_data : Union[ListData, Tuple[List[List[Any]], Optional[List[List[Any]]], Optional[List[Any]]]] # noqa: E501
Test data should be in the form of ``(X, gt_pseudo_label, Y)`` or a ``ListData`` object
with ``X``, ``gt_pseudo_label`` and ``Y`` attributes. Both ``gt_pseudo_label`` and ``Y``
can be either None or not, which depends on the evaluation metrics in ``self.metric_list``.
"""
print_log("Test start:", logger="current")
test_data_examples = self.data_preprocess("test", test_data)


+ 86
- 45
abl/data/data_converter.py View File

@@ -4,21 +4,22 @@ from abl.utils import tab_data_to_tuple
from .structures.list_data import ListData
from lambdaLearn.Base.TabularMixin import TabularMixin


class DataConverter:
'''
"""
This class provides functionality to convert LambdaLearn data to ABL-Package data.
'''
"""

def __init__(self) -> None:
pass
def convert_lambdalearn_to_tuple(
self,
dataset: TabularMixin,
reasoning_result: Any
self, dataset: TabularMixin, reasoning_result: Any
) -> Tuple[Tuple, Tuple, Tuple, Tuple]:
'''
Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), each containing (data, label, reasoning_result).
"""
Convert a lambdalearn dataset to a tuple of tuples (label_data, train_data, valid_data, test_data), # noqa: E501
each containing (data, label, reasoning_result).

Parameters
----------
dataset : TabularMixin
@@ -28,27 +29,38 @@ class DataConverter:
Returns
-------
Tuple[Tuple, Tuple, Tuple, Tuple]
A tuple of (label_data, train_data, valid_data, test_data), where each element is a tuple of (data, label, reasoning_result).
'''
A tuple of (label_data, train_data, valid_data, test_data), where each element is
a tuple of (data, label, reasoning_result).
"""

if not isinstance(dataset, TabularMixin):
raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a tuple. ")
label_data = tab_data_to_tuple(dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result)
train_data = tab_data_to_tuple(dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result)
valid_data = tab_data_to_tuple(dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result)
test_data = tab_data_to_tuple(dataset.test_X, dataset.test_y, reasoning_result=reasoning_result)
raise NotImplementedError(
"Only support converting the datasets that are instances of TabularMixin. "
+ "Please refer to the documentation and manually convert the dataset into a tuple."
)

label_data = tab_data_to_tuple(
dataset.labeled_X, dataset.labeled_y, reasoning_result=reasoning_result
)
train_data = tab_data_to_tuple(
dataset.unlabeled_X, dataset.unlabeled_y, reasoning_result=reasoning_result
)
valid_data = tab_data_to_tuple(
dataset.valid_X, dataset.valid_y, reasoning_result=reasoning_result
)
test_data = tab_data_to_tuple(
dataset.test_X, dataset.test_y, reasoning_result=reasoning_result
)

return label_data, train_data, valid_data, test_data
def convert_lambdalearn_to_listdata(
self,
dataset: TabularMixin,
reasoning_result: Any
self, dataset: TabularMixin, reasoning_result: Any
) -> Tuple[ListData, ListData, ListData, ListData]:
'''
Convert a lambdalearn dataset to a tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples).
"""
Convert a lambdalearn dataset to a tuple of ListData
(label_data_examples, train_data_examples, valid_data_examples, test_data_examples).

Parameters
----------
dataset : TabularMixin
@@ -58,14 +70,20 @@ class DataConverter:
Returns
-------
Tuple[ListData, ListData, ListData, ListData]
A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples)
'''
A tuple of ListData (label_data_examples, train_data_examples, valid_data_examples, test_data_examples) # noqa: E501
"""
if not isinstance(dataset, TabularMixin):
raise NotImplementedError("Only support converting the datasets that are instances of TabularMixin. Please refer to the documentation and manually convert the dataset into a ListData. ")
label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple(dataset, reasoning_result)
raise NotImplementedError(
"Only support converting the datasets that are instances of TabularMixin. "
+ "Please refer to the documentation and manually convert the dataset "
+ "into a ListData."
)

label_data, train_data, valid_data, test_data = self.convert_lambdalearn_to_tuple(
dataset, reasoning_result
)

if label_data is not None:
X, gt_pseudo_label, Y = label_data
label_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
@@ -78,23 +96,46 @@ class DataConverter:
if test_data is not None:
X, gt_pseudo_label, Y = test_data
test_data_examples = ListData(X=X, gt_pseudo_label=gt_pseudo_label, Y=Y)
return label_data_examples, train_data_examples, valid_data_examples, test_data_examples
if __name__ == '__main__':


if __name__ == "__main__":
from lambdaLearn.Dataset.Tabular.BreastCancer import BreastCancer
breast_dataset=BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)

breast_dataset = BreastCancer(labeled_size=0.1, stratified=True, shuffle=True)
dataconverter = DataConverter()
label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple(breast_dataset, 0)
print(type(label_data).__name__, type(train_data).__name__, type(valid_data).__name__, type(test_data).__name__)

label_data, train_data, valid_data, test_data = dataconverter.convert_lambdalearn_to_tuple(
breast_dataset, 0
)
print(
type(label_data).__name__,
type(train_data).__name__,
type(valid_data).__name__,
type(test_data).__name__,
)
print(len(label_data))
print(len(label_data[0]), len(label_data[1]), len(label_data[2]))
print(label_data[0][0], label_data[1][0], label_data[2][0])
print()
label_data_examples, train_data_examples, valid_data_examples, test_data_examples = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0)
print(type(label_data_examples).__name__, type(train_data_examples).__name__, type(valid_data_examples).__name__, type(test_data_examples).__name__)
print(len(label_data_examples.X), len(label_data_examples.gt_pseudo_label), len(label_data_examples.Y))

(
label_data_examples,
train_data_examples,
valid_data_examples,
test_data_examples,
) = dataconverter.convert_lambdalearn_to_listdata(breast_dataset, 0)
print(
type(label_data_examples).__name__,
type(train_data_examples).__name__,
type(valid_data_examples).__name__,
type(test_data_examples).__name__,
)
print(
len(label_data_examples.X),
len(label_data_examples.gt_pseudo_label),
len(label_data_examples.Y),
)
label_data_example = label_data_examples[0]
print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y)
print(label_data_example.X, label_data_example.gt_pseudo_label, label_data_example.Y)

+ 2
- 1
abl/data/evaluation/reasoning_metric.py View File

@@ -38,7 +38,8 @@ class ReasoningMetric(BaseMetric):
"""
Process a batch of data examples.

This method takes in a batch of data examples, each containing predicted pseudo-labels(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It
This method takes in a batch of data examples, each containing predicted pseudo-labels
(pred_pseudo_label), ground truth of reasoning results (Y), and input data (X). It
evaluates the reasoning accuracy of each example by comparing the logical reasoning
result (derived using the knowledge base) of the predicted pseudo-labels against Y.
The result of this comparison (1 for correct reasoning, 0 for incorrect) is appended


+ 2
- 2
abl/data/structures/list_data.py View File

@@ -53,7 +53,7 @@ class ListData(BaseDataElement):
``torch.Tensor``, ``numpy.ndarray``, ``list``, ``str`` and ``tuple``.

This design is inspired by and extends the functionalities of the ``BaseDataElement``
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_.
class implemented in `MMEngine <https://github.com/open-mmlab/mmengine/blob/main/mmengine/structures/base_data_element.py>`_. # noqa: E501

Examples:
>>> from abl.data.structures import ListData
@@ -71,7 +71,7 @@ class ListData(BaseDataElement):
DATA FIELDS
Y: [1, 2, 3]
gt_pseudo_label: [[1, 2], [3, 4], [5, 6]]
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]]
X: [[tensor(1.1949), tensor(-0.9378)], [tensor(0.7414), tensor(0.7603)], [tensor(1.0587), tensor(1.9697)]] # noqa: E501
) at 0x7f3bbf1991c0>
>>> print(data_examples[:1])
<ListData(


+ 14
- 13
abl/learning/basic_nn.py View File

@@ -81,7 +81,8 @@ class BasicNN:
if not isinstance(device, torch.device):
if not isinstance(device, str):
raise TypeError(
"device must be an instance of torch.device or a str indicates the target device"
"device must be an instance of torch.device or a str indicating "
+ "the target device"
)
else:
device = torch.device(device)
@@ -163,9 +164,9 @@ class BasicNN:
return self

def fit(
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
y: Optional[List[int]] = None,
) -> BasicNN:
"""
@@ -271,8 +272,8 @@ class BasicNN:
return torch.cat(results, axis=0)

def predict(
self,
data_loader: Optional[DataLoader] = None,
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
) -> numpy.ndarray:
"""
@@ -312,8 +313,8 @@ class BasicNN:
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

def predict_proba(
self,
data_loader: Optional[DataLoader] = None,
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
) -> numpy.ndarray:
"""
@@ -403,9 +404,9 @@ class BasicNN:
return mean_loss, accuracy

def score(
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
self,
data_loader: Optional[DataLoader] = None,
X: Optional[List[Any]] = None,
y: Optional[List[int]] = None,
) -> float:
"""
@@ -447,8 +448,8 @@ class BasicNN:

def _data_loader(
self,
X: Optional[List[Any]],
y: Optional[List[int]] = None,
X: Optional[List[Any]],
y: Optional[List[int]] = None,
shuffle: Optional[bool] = True,
) -> DataLoader:
"""


+ 64
- 22
abl/learning/model_converter.py View File

@@ -6,16 +6,18 @@ from .abl_model import ABLModel
from .basic_nn import BasicNN
from lambdaLearn.Base.DeepModelMixin import DeepModelMixin


class ModelConverter:
'''
"""
This class provides functionality to convert LambdaLearn models to ABL-Package models.
'''
"""

def __init__(self) -> None:
pass
def convert_lambdalearn_to_ablmodel(
self,
lambdalearn_model,
lambdalearn_model,
loss_fn: torch.nn.Module,
optimizer_dict: dict,
scheduler_dict: Optional[dict] = None,
@@ -28,11 +30,13 @@ class ModelConverter:
save_dir: Optional[str] = None,
train_transform: Callable[..., Any] = None,
test_transform: Callable[..., Any] = None,
collate_fn: Callable[[List[Any]], Any] = None
collate_fn: Callable[[List[Any]], Any] = None,
):
'''
Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn model should implement fit and predict methods.
"""
Convert a lambdalearn model to an ABLModel. If the lambdalearn model is an instance of
DeepModelMixin, its network will be used as the model of BasicNN. Otherwise, the lambdalearn
model should implement ``fit`` and ``predict`` methods.

Parameters
----------
lambdalearn_model : Union[DeepModelMixin, Any]
@@ -75,17 +79,34 @@ class ModelConverter:
-------
ABLModel
The converted ABLModel instance.
'''
"""
if isinstance(lambdalearn_model, DeepModelMixin):
base_model = self.convert_lambdalearn_to_basicnn(lambdalearn_model, loss_fn, optimizer_dict, scheduler_dict, device, batch_size, num_epochs, stop_loss, num_workers, save_interval, save_dir, train_transform, test_transform, collate_fn)
base_model = self.convert_lambdalearn_to_basicnn(
lambdalearn_model,
loss_fn,
optimizer_dict,
scheduler_dict,
device,
batch_size,
num_epochs,
stop_loss,
num_workers,
save_interval,
save_dir,
train_transform,
test_transform,
collate_fn,
)
return ABLModel(base_model)
if not (hasattr(lambdalearn_model, "fit") and hasattr(lambdalearn_model, "predict")):
raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin, or implement fit and predict methods.")
raise NotImplementedError(
"The lambdalearn_model should be an instance of DeepModelMixin, or implement "
+ "fit and predict methods."
)

return ABLModel(lambdalearn_model)

def convert_lambdalearn_to_basicnn(
self,
lambdalearn_model: DeepModelMixin,
@@ -103,9 +124,10 @@ class ModelConverter:
test_transform: Callable[..., Any] = None,
collate_fn: Callable[[List[Any]], Any] = None,
):
'''
Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of DeepModelMixin, its network will be used as the model of BasicNN.
"""
Convert a lambdalearn model to a BasicNN. If the lambdalearn model is an instance of
DeepModelMixin, its network will be used as the model of BasicNN.

Parameters
----------
lambdalearn_model : Union[DeepModelMixin, Any]
@@ -147,10 +169,13 @@ class ModelConverter:
-------
BasicNN
The converted BasicNN instance.
'''
"""
if isinstance(lambdalearn_model, DeepModelMixin):
if not isinstance(lambdalearn_model.network, torch.nn.Module):
raise NotImplementedError(f"Expected lambdalearn_model.network to be a torch.nn.Module, but got {type(lambdalearn_model.network)}")
raise NotImplementedError(
"Expected lambdalearn_model.network to be a torch.nn.Module, "
+ f"but got {type(lambdalearn_model.network)}"
)
# Only use the network part and device of the lambdalearn model
network = copy.deepcopy(lambdalearn_model.network)
optimizer_class = optimizer_dict["optimizer"]
@@ -163,7 +188,24 @@ class ModelConverter:
else:
scheduler = None
device = lambdalearn_model.device if device is None else device
base_model = BasicNN(model=network, loss_fn=loss_fn, optimizer=optimizer, scheduler=scheduler, device=device, batch_size=batch_size, num_epochs=num_epochs, stop_loss=stop_loss, num_workers=num_workers, save_interval=save_interval, save_dir=save_dir, train_transform=train_transform, test_transform=test_transform, collate_fn=collate_fn)
base_model = BasicNN(
model=network,
loss_fn=loss_fn,
optimizer=optimizer,
scheduler=scheduler,
device=device,
batch_size=batch_size,
num_epochs=num_epochs,
stop_loss=stop_loss,
num_workers=num_workers,
save_interval=save_interval,
save_dir=save_dir,
train_transform=train_transform,
test_transform=test_transform,
collate_fn=collate_fn,
)
return base_model
else:
raise NotImplementedError("The lambdalearn_model should be an instance of DeepModelMixin.")
raise NotImplementedError(
"The lambdalearn_model should be an instance of DeepModelMixin."
)

+ 88
- 76
abl/reasoning/kb.py View File

@@ -26,7 +26,7 @@ class KBBase(ABC):
list so that each aligns with its corresponding index in the base model: the first with
the 0th index, the second with the 1st, and so forth.
max_err : float, optional
The upper tolerance limit when comparing the similarity between the reasoning result of
The upper tolerance limit when comparing the similarity between the reasoning result of
pseudo-labels and the ground truth. This is only applicable when the reasoning
result is of a numerical type. This is particularly relevant for regression problems where
exact matches might not be feasible. Defaults to 1e-10.
@@ -65,10 +65,12 @@ class KBBase(ABC):
self.use_cache = use_cache
self.key_func = key_func
self.cache_size = cache_size
argspec = inspect.getfullargspec(self.logic_forward)
self._num_args = len(argspec.args) - 1
if self._num_args==2 and self.use_cache: # If the logic_forward function has 2 arguments, then disable cache
if (
self._num_args == 2 and self.use_cache
): # If the logic_forward function has 2 arguments, then disable cache
self.use_cache = False
print_log(
"The logic_forward function has 2 arguments, so the cache is disabled. ",
@@ -89,10 +91,10 @@ class KBBase(ABC):
pseudo_label : List[Any]
Pseudo-labels of an example.
x : List[Any], optional
The example. If deductive logical reasoning does not require any
information from the example, the overridden function provided by the user can omit
The example. If deductive logical reasoning does not require any
information from the example, the overridden function provided by the user can omit
this parameter.
Returns
-------
Any
@@ -100,11 +102,11 @@ class KBBase(ABC):
"""

def abduce_candidates(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
@@ -118,7 +120,7 @@ class KBBase(ABC):
Ground truth of the reasoning result for the example.
x : List[Any]
The example. If the information from the example
is not required in the reasoning process, then this parameter will not have
is not required in the reasoning process, then this parameter will not have
any effect.
max_revision_num : int
The upper limit on the number of revised labels for each example.
@@ -129,9 +131,9 @@ class KBBase(ABC):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo-labels of the example. that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
pseudo-labels of the example that are compatible with the knowledge base. The second
element is a list of reasoning results corresponding to each candidate, i.e., the
outcome of the ``logic_forward`` function.
"""
return self._abduce_by_search(pseudo_label, y, x, max_revision_num, require_more_revision)

@@ -154,10 +156,10 @@ class KBBase(ABC):
return reasoning_result == y

def revise_at_idx(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> List[List[Any]]:
"""
@@ -171,7 +173,7 @@ class KBBase(ABC):
Ground truth of the reasoning result for the example.
x : List[Any]
The example. If the information from the example
is not required in the reasoning process, then this parameter will not have
is not required in the reasoning process, then this parameter will not have
any effect.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo-labels.
@@ -180,9 +182,9 @@ class KBBase(ABC):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo-labels of the example that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
pseudo-labels of the example that are compatible with the knowledge base. The second
element is a list of reasoning results corresponding to each candidate, i.e., the
outcome of the ``logic_forward`` function.
"""
candidates, reasoning_results = [], []
abduce_c = product(self.pseudo_label_list, repeat=len(revision_idx))
@@ -192,14 +194,15 @@ class KBBase(ABC):
candidate[idx] = c[i]
reasoning_result = self.logic_forward(candidate, *(x,) if self._num_args == 2 else ())
if self._check_equal(reasoning_result, y):
candidates.append(candidate); reasoning_results.append(reasoning_result)
candidates.append(candidate)
reasoning_results.append(reasoning_result)
return candidates, reasoning_results

def _revision(
self,
revision_num: int,
pseudo_label: List[Any],
y: Any,
self,
revision_num: int,
pseudo_label: List[Any],
y: Any,
x: List[Any],
) -> List[List[Any]]:
"""
@@ -210,16 +213,17 @@ class KBBase(ABC):
revision_idx_list = combinations(range(len(pseudo_label)), revision_num)
for revision_idx in revision_idx_list:
candidates, reasoning_results = self.revise_at_idx(pseudo_label, y, x, revision_idx)
new_candidates.extend(candidates); new_reasoning_results.extend(reasoning_results)
new_candidates.extend(candidates)
new_reasoning_results.extend(reasoning_results)
return new_candidates, new_reasoning_results

@abl_cache()
def _abduce_by_search(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
@@ -235,7 +239,7 @@ class KBBase(ABC):
Ground truth of the reasoning result for the example.
x : List[Any]
The example. If the information from the example
is not required in the reasoning process, then this parameter will not have
is not required in the reasoning process, then this parameter will not have
any effect.
max_revision_num : int
The upper limit on the number of revisions.
@@ -248,14 +252,15 @@ class KBBase(ABC):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo-labels of the example that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
pseudo-labels of the example that are compatible with the knowledge base. The second
element is a list of reasoning results corresponding to each candidate, i.e., the
outcome of the ``logic_forward`` function.
"""
candidates, reasoning_results = [], []
for revision_num in range(len(pseudo_label) + 1):
new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results)
candidates.extend(new_candidates)
reasoning_results.extend(new_reasoning_results)
if len(candidates) > 0:
min_revision_num = revision_num
break
@@ -268,7 +273,8 @@ class KBBase(ABC):
if revision_num > max_revision_num:
return candidates, reasoning_results
new_candidates, new_reasoning_results = self._revision(revision_num, pseudo_label, y, x)
candidates.extend(new_candidates); reasoning_results.extend(new_reasoning_results)
candidates.extend(new_candidates)
reasoning_results.extend(new_reasoning_results)
return candidates, reasoning_results

def __repr__(self):
@@ -305,16 +311,19 @@ class GroundKB(KBBase):
"""

def __init__(
self,
pseudo_label_list: List[Any],
GKB_len_list: List[int],
self,
pseudo_label_list: List[Any],
GKB_len_list: List[int],
max_err: float = 1e-10,
):
super().__init__(pseudo_label_list, max_err)
if not isinstance(GKB_len_list, list):
raise TypeError("GKB_len_list should be list, but got {type(GKB_len_list)}")
if self._num_args==2:
raise NotImplementedError(f"GroundKB only supports 1-argument logic_forward, but got {self._num_args}-argument logic_forward")
if self._num_args == 2:
raise NotImplementedError(
"GroundKB only supports 1-argument logic_forward, but got "
+ f"{self._num_args}-argument logic_forward"
)
self.GKB_len_list = GKB_len_list
self.GKB = {}
X, Y = self._get_GKB()
@@ -354,11 +363,11 @@ class GroundKB(KBBase):
return X, Y

def abduce_candidates(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
max_revision_num: int,
require_more_revision: int,
) -> List[List[Any]]:
"""
@@ -383,9 +392,9 @@ class GroundKB(KBBase):
-------
Tuple[List[List[Any]], List[Any]]
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo-labels of THE example that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
pseudo-labels of the example that are compatible with the knowledge base. The second
element is a list of reasoning results corresponding to each candidate, i.e., the
outcome of the ``logic_forward`` function.
"""
if self.GKB == {} or len(pseudo_label) not in self.GKB_len_list:
return [], []
@@ -418,7 +427,8 @@ class GroundKB(KBBase):
all_candidates, all_reasoning_results = [], []
for key in key_list[low_key:high_key]:
for candidate in potential_candidates[key]:
all_candidates.append(candidate); all_reasoning_results.append(key)
all_candidates.append(candidate)
all_reasoning_results.append(key)
else:
all_candidates = self.GKB[len(pseudo_label)][y]
all_reasoning_results = [y] * len(all_candidates)
@@ -468,14 +478,17 @@ class PrologKB(KBBase):

def __init__(self, pseudo_label_list: List[Any], pl_file: str):
super().__init__(pseudo_label_list)
try:
import pyswip
except (IndexError, ImportError):
print("A Prolog-based knowledge base is in use. Please install Swi-Prolog \
using the command 'sudo apt-get install swi-prolog' for Linux users, \
or download it following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md for Windows and Mac users.")
print(
"A Prolog-based knowledge base is in use. Please install Swi-Prolog using the"
+ "command 'sudo apt-get install swi-prolog' for Linux users, or download it "
+ "following the guide in https://github.com/yuce/pyswip/blob/master/INSTALL.md "
+ "for Windows and Mac users."
)

self.prolog = pyswip.Prolog()
self.pl_file = pl_file
if not os.path.exists(self.pl_file):
@@ -519,9 +532,9 @@ class PrologKB(KBBase):
return re.sub(regex, lambda x: x.group().replace("'", ""), str(revision_pseudo_label))

def get_query_string(
self,
pseudo_label: List[Any],
y: Any,
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> str:
@@ -538,8 +551,8 @@ class PrologKB(KBBase):
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
any effect.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo-labels.
@@ -556,10 +569,10 @@ class PrologKB(KBBase):
return query_string

def revise_at_idx(
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
self,
pseudo_label: List[Any],
y: Any,
x: List[Any],
revision_idx: List[int],
) -> List[List[Any]]:
"""
@@ -572,8 +585,8 @@ class PrologKB(KBBase):
y : Any
Ground truth of the reasoning result for the example.
x : List[Any]
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
The corresponding input example. If the information from the input
is not required in the reasoning process, then this parameter will not have
any effect.
revision_idx : List[int]
A list specifying indices of where revisions should be made to the pseudo-labels.
@@ -581,12 +594,10 @@ class PrologKB(KBBase):
Returns
-------
Tuple[List[List[Any]], List[Any]]
A list of candidates, i.e. revised pseudo-labels of the example that are compatible with the
knowledge base.
A tuple of two elements. The first element is a list of candidate revisions, i.e. revised
pseudo-labels of the example that are compatible with the knowledge base. The second element is
a list of reasoning results corresponding to each candidate, i.e., the outcome of the
logic_forward function.
pseudo-labels of the example that are compatible with the knowledge base. The second
element is a list of reasoning results corresponding to each candidate, i.e., the
outcome of the ``logic_forward`` function.
"""
candidates, reasoning_results = [], []
query_string = self.get_query_string(pseudo_label, y, x, revision_idx)
@@ -598,7 +609,8 @@ class PrologKB(KBBase):
for i, idx in enumerate(revision_idx):
candidate[idx] = c[i]
candidate = reform_list(candidate, save_pseudo_label)
candidates.append(candidate); reasoning_results.append(y)
candidates.append(candidate)
reasoning_results.append(y)
return candidates, reasoning_results

def __repr__(self):


+ 26
- 21
abl/reasoning/reasoner.py View File

@@ -28,10 +28,10 @@ class Reasoner:
candidate, 'confidence': calculates the distance between the prediction
and each candidate based on confidence derived from the predicted probability
in the data example. The callable function should have the signature
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must return a cost list. Each element
in this cost list should be a numerical value representing the cost for each
candidate, and the list should have the same length as candidates.
Defaults to 'confidence'.
dist_func(data_example, candidates, candidate_idxs, reasoning_results) and must
return a cost list. Each element in this cost list should be a numerical value
representing the cost for each candidate, and the list should have the same length
as candidates. Defaults to 'confidence'.
idx_to_label : dict, optional
A mapping from index in the base model to label. If not provided, a default
order-based index to label mapping is created. Defaults to None.
@@ -76,14 +76,16 @@ class Reasoner:
if isinstance(dist_func, str):
if dist_func not in ["hamming", "confidence"]:
raise NotImplementedError(
f'Valid options for predefined dist_func include "hamming" and "confidence", but got {dist_func}.'
'Valid options for predefined dist_func include "hamming" '
+ f'and "confidence", but got {dist_func}.'
)
return
elif callable(dist_func):
params = inspect.signature(dist_func).parameters.values()
if len(params) != 4:
raise ValueError(
f"User-defined dist_func must have exactly four parameters, but got {len(params)}."
"User-defined dist_func must have exactly four parameters, "
+ f"but got {len(params)}."
)
return
else:
@@ -99,7 +101,8 @@ class Reasoner:
raise ValueError(f"All keys in the idx_to_label must be integers, but got {key}.")
if value not in self.kb.pseudo_label_list:
raise ValueError(
f"All values in the idx_to_label must be in the pseudo_label_list, but got {value}."
"All values in the idx_to_label must be in the pseudo_label_list, "
+ f"but got {value}."
)

def _get_one_candidate(
@@ -169,8 +172,8 @@ class Reasoner:
cost_list = self.dist_func(data_example, candidates, candidate_idxs, reasoning_results)
if len(cost_list) != len(candidates):
raise ValueError(
f"The length of the array returned by dist_func must be equal to the number of candidates. "
f"Expected length {len(candidates)}, but got {len(cost_list)}."
"The length of the array returned by dist_func must be equal to the number "
+ f"of candidates. Expected length {len(candidates)}, but got {len(cost_list)}."
)
return cost_list

@@ -204,7 +207,9 @@ class Reasoner:
dim=dimension,
constraint=lambda sol: self._constrain_revision_num(sol, max_revision_num),
)
parameter = Parameter(budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True)
parameter = Parameter(
budget=self.zoopt_budget(symbol_num), intermediate_result=False, autoset=True
)
solution = Opt.min(objective, parameter)
return solution

@@ -240,29 +245,28 @@ class Reasoner:
return np.min(self._get_cost_list(data_example, candidates, reasoning_results))
else:
return symbol_num
def zoopt_budget(self, symbol_num: int) -> int:
"""
Set the budget for ZOOpt optimization. The function, in its default implementation,
returns a fixed budget value of 100. However, it can be adjusted to return other fixed
values, or a dynamic budget based on the number of symbols, if desired. For example, one might choose to
set the budget as 100 times symbol_num.
Set the budget for ZOOpt optimization. The function, in its default implementation,
returns a fixed budget value of 100. However, it can be adjusted to return other fixed
values, or a dynamic budget based on the number of symbols, if desired. For example,
one might choose to set the budget as 100 times ``symbol_num``.

Parameters
----------
symbol_num : int
The number of symbols to be considered in the ZOOpt optimization process. Although this parameter
can be used to compute a dynamic optimization budget, by default it is not utilized in the
calculation.
The number of symbols to be considered in the ZOOpt optimization process. Although this
parameter can be used to compute a dynamic optimization budget, by default it is not
utilized in the calculation.

Returns
-------
int
The budget for ZOOpt optimization. By default, this is a fixed value of 100,
The budget for ZOOpt optimization. By default, this is a fixed value of 100,
irrespective of the symbol_num value.
"""
return 100

def _constrain_revision_num(self, solution: Solution, max_revision_num: int) -> int:
"""
@@ -284,7 +288,8 @@ class Reasoner:
elif isinstance(max_revision, float):
if not (0 <= max_revision <= 1):
raise ValueError(
f"If max_revision is a float, it must be between 0 and 1, but got {max_revision}"
"If max_revision is a float, it must be between 0 and 1, "
+ f"but got {max_revision}"
)
return round(symbol_num * max_revision)
else:


+ 8
- 1
abl/utils/__init__.py View File

@@ -1,6 +1,13 @@
from .cache import Cache, abl_cache
from .logger import ABLLogger, print_log
from .utils import confidence_dist, flatten, hamming_dist, reform_list, to_hashable, tab_data_to_tuple
from .utils import (
confidence_dist,
flatten,
hamming_dist,
reform_list,
to_hashable,
tab_data_to_tuple,
)

__all__ = [
"Cache",


+ 1
- 1
abl/utils/cache.py View File

@@ -43,7 +43,7 @@ class Cache(Generic[K, T]):
def get_from_dict(self, obj, *args) -> T:
"""Implements dict based cache."""
# x is not used in cache key
pred_pseudo_label, y, x, *res_args = args
pred_pseudo_label, y, x, *res_args = args
cache_key = (self.key_func(pred_pseudo_label), self.key_func(y), *res_args)
link = self.cache_dict.get(cache_key)
if link is not None:


+ 11
- 8
abl/utils/logger.py View File

@@ -168,9 +168,11 @@ class ABLLogger(Logger, ManagerMixin):
Notes
-----
- The ``name`` of the logger and the ``instance_name`` of ``ABLLogger`` could be different.
``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not ``logging.getLogger``.
This ensures ``ABLLogger`` is not influenced by third-party logging configurations.
- Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without ``Handler``.
``ABLLogger`` instances are retrieved using ``ABLLogger.get_instance``, not
``logging.getLogger``. This ensures ``ABLLogger`` is not influenced by third-party logging
configurations.
- Unlike ``logging.Logger``, ``ABLLogger`` will not log warning or error messages without
``Handler``.

Examples
--------
@@ -288,15 +290,16 @@ class ABLLogger(Logger, ManagerMixin):


def print_log(
msg,
logger: Optional[Union[Logger, str]] = None,
msg,
logger: Optional[Union[Logger, str]] = None,
level: Optional[int] = logging.INFO,
) -> None:
"""
Print a log message using the specified logger or a default method.

This function logs a message with a given logger, if provided, or prints it using
the standard ``print`` function. It supports special logger types such as 'silent' and 'current'.
the standard ``print`` function. It supports special logger types such as 'silent'
and 'current'.

Parameters
----------
@@ -308,8 +311,8 @@ def print_log(
method is used.
- 'silent': No message will be printed.
- 'current': Use the latest created logger to log the message.
- other str: The instance name of the logger. A ``ValueError`` is raised if the logger has not
been created.
- other str: The instance name of the logger. A ``ValueError`` is raised if the logger has
not been created.
- None: The ``print()`` method is used for logging.
level : int, optional
The logging level. This is only applicable when ``logger`` is a Logger object, 'current',


+ 15
- 14
abl/utils/utils.py View File

@@ -15,7 +15,7 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A
Returns
-------
List[Any]
A flattened version of the input list, where only the first
A flattened version of the input list, where only the first
level of sublists and tuples are reduced.
"""
if not isinstance(nested_list, list):
@@ -24,15 +24,15 @@ def flatten(nested_list: List[Union[Any, List[Any], Tuple[Any, ...]]]) -> List[A
flattened_list = []
for item in nested_list:
if isinstance(item, (list, tuple)):
flattened_list.extend(item)
flattened_list.extend(item)
else:
flattened_list.append(item)

return flattened_list


def reform_list(
flattened_list: List[Any],
structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]]
flattened_list: List[Any], structured_list: List[Union[Any, List[Any], Tuple[Any, ...]]]
) -> List[List[Any]]:
"""
Reform the list based on the structure of ``structured_list``.
@@ -148,16 +148,15 @@ def restore_from_hashable(x):
return [restore_from_hashable(item) for item in x]
return x


def tab_data_to_tuple(
X: Union[List[Any], Any],
y: Union[List[Any], Any],
reasoning_result: Optional[Any] = 0
X: Union[List[Any], Any], y: Union[List[Any], Any], reasoning_result: Optional[Any] = 0
) -> Tuple[List[List[Any]], List[List[Any]], List[Any]]:
'''
Convert a tabular data to a tuple by adding a dimension to each element of
"""
Convert tabular data to a tuple by adding a dimension to each element of
X and y. The tuple contains three elements: data, label, and reasoning result.
If X is None, return None.
Parameters
----------
X : Union[List[Any], Any]
@@ -166,14 +165,16 @@ def tab_data_to_tuple(
The label.
reasoning_result : Any, optional
The reasoning result, by default 0.
Returns
-------
Tuple[List[List[Any]], List[List[Any]], List[Any]]
A tuple of (data, label, reasoning_result).
'''
"""
if X is None:
return None
if len(X) != len(y):
raise ValueError("The length of X and y should be the same, but got {} and {}.".format(len(X), len(y)))
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))
raise ValueError(
"The length of X and y should be the same, but got {} and {}.".format(len(X), len(y))
)
return ([[x] for x in X], [[y_item] for y_item in y], [reasoning_result] * len(y))

+ 12
- 10
docs/conf.py View File

@@ -5,27 +5,28 @@ import sys
from docutils import nodes
from docutils.parsers.rst import roles

import re
from sphinx.application import Sphinx


def remove_noqa(app: Sphinx, what: str, name: str, obj, options, lines):
new_lines = []
for line in lines:
new_line = re.sub(r'\s*#\s*noqa.*$', '', line)
new_line = re.sub(r"\s*#\s*noqa.*$", "", line)
new_lines.append(new_line)
lines[:] = new_lines


def colored_text_role(role, rawtext, text, lineno, inliner, options={}, content=[]):
node = nodes.inline(rawtext, text, classes=[role])
return [node], []

roles.register_local_role('green-bold', colored_text_role)
roles.register_local_role('blue-bold', colored_text_role)
roles.register_local_role('yellow-bold', colored_text_role)
roles.register_local_role('green', colored_text_role)
roles.register_local_role('blue', colored_text_role)
roles.register_local_role('yellow', colored_text_role)

roles.register_local_role("green-bold", colored_text_role)
roles.register_local_role("blue-bold", colored_text_role)
roles.register_local_role("yellow-bold", colored_text_role)
roles.register_local_role("green", colored_text_role)
roles.register_local_role("blue", colored_text_role)
roles.register_local_role("yellow", colored_text_role)


if "READTHEDOCS" not in os.environ:
@@ -45,7 +46,7 @@ author = "Author"
extensions = [
"sphinx.ext.intersphinx",
"sphinx.ext.autodoc",
'sphinx.ext.autosummary',
"sphinx.ext.autosummary",
"sphinx.ext.mathjax",
"sphinx.ext.viewcode",
"sphinx_rtd_theme",
@@ -95,7 +96,8 @@ texinfo_documents = [
def setup(app):
from sphinx.domains.python import PyField
from sphinx.util.docfields import Field
app.connect('autodoc-process-docstring', remove_noqa)

app.connect("autodoc-process-docstring", remove_noqa)
app.add_object_type(
"confval",
"confval",


+ 1
- 1
examples/hed/bridge.py View File

@@ -247,7 +247,7 @@ class HedBridge(SimpleBridge):
logger="current",
)
self.model.load(
load_path=os.path.join(save_dir, f"pretrain_weights.pth")
load_path=os.path.join(save_dir, "pretrain_weights.pth")
)
else:
self.model.load(


+ 13
- 10
examples/hed/datasets/get_dataset.py View File

@@ -12,16 +12,20 @@ from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall(CURRENT_DIR)
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
os.remove(zip_file_name)
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hed/datasets' folder")
raise Exception(
f"An error occurred during download or unzip: {e}. Instead, you can download "
+ f"the dataset from {url} and unzip it in 'examples/hed/datasets' folder"
)


def get_pretrain_data(labels, image_size=(28, 28, 1)):
@@ -33,7 +37,7 @@ def get_pretrain_data(labels, image_size=(28, 28, 1)):
img_path_list = os.listdir(label_path)
for img_path in img_path_list:
with Image.open(osp.join(label_path, img_path)) as img:
img = img.convert('L')
img = img.convert("L")
img = img.resize((image_size[1], image_size[0]))
img_array = np.array(img, dtype=np.float32)
normalized_img = (img_array - 127) / 128.0
@@ -72,19 +76,19 @@ def split_equation(equations_by_len, prop_train, prop_val):


def get_dataset(dataset="mnist", train=True):
data_dir = CURRENT_DIR + '/mnist_images'
data_dir = CURRENT_DIR + "/mnist_images"
if not os.path.exists(data_dir):
print("Dataset not exist, downloading it...")
url = 'https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download'
url = "https://drive.google.com/u/0/uc?id=1XoJDjO3cNUdytqVgXUKOBe9dOcUBobom&export=download"
download_and_unzip(url, os.path.join(CURRENT_DIR, "HED.zip"))
print("Download and extraction complete.")
if train:
file = os.path.join(data_dir, "expr_train.json")
else:
file = os.path.join(data_dir, "expr_test.json")
if dataset == "mnist":
file = osp.join(CURRENT_DIR, "mnist_equation_data_train_len_26_test_len_26_sys_2_.pk")
elif dataset == "random":
@@ -94,7 +98,7 @@ def get_dataset(dataset="mnist", train=True):

with open(file, "rb") as f:
img_dataset = pickle.load(f)
X, Y = [], []
if train:
positive = img_dataset["train:positive"]
@@ -117,4 +121,3 @@ def get_dataset(dataset="mnist", train=True):

equations_by_len = divide_equations_by_len(X, Y)
return equations_by_len


+ 31
- 23
examples/hed/hed.ipynb View File

@@ -86,31 +86,39 @@
"source": [
"true_train_equation = train_data[1]\n",
"false_train_equation = train_data[0]\n",
"print(f\"Equations in the dataset is organized by equation length, \" +\n",
" f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\")\n",
"print(\n",
" f\"Equations in the dataset is organized by equation length, \"\n",
" + f\"from {min(train_data[0].keys())} to {max(train_data[0].keys())}\"\n",
")\n",
"print()\n",
"\n",
"true_train_equation_with_length_5 = true_train_equation[5]\n",
"false_train_equation_with_length_5 = false_train_equation[5]\n",
"print(f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \" +\n",
" f\"true equation and {len(false_train_equation_with_length_5)} false equation \" +\n",
" f\"in the training set\")\n",
"print(\n",
" f\"For each euqation length, there are {len(true_train_equation_with_length_5)} \"\n",
" + f\"true equation and {len(false_train_equation_with_length_5)} false equation \"\n",
" + f\"in the training set\"\n",
")\n",
"\n",
"true_val_equation = val_data[1]\n",
"false_val_equation = val_data[0]\n",
"true_val_equation_with_length_5 = true_val_equation[5]\n",
"false_val_equation_with_length_5 = false_val_equation[5]\n",
"print(f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \" +\n",
" f\"true equation and {len(false_val_equation_with_length_5)} false equation \" +\n",
" f\"in the validation set\")\n",
"print(\n",
" f\"For each euqation length, there are {len(true_val_equation_with_length_5)} \"\n",
" + f\"true equation and {len(false_val_equation_with_length_5)} false equation \"\n",
" + f\"in the validation set\"\n",
")\n",
"\n",
"true_test_equation = test_data[1]\n",
"false_test_equation = test_data[0]\n",
"true_test_equation_with_length_5 = true_test_equation[5]\n",
"false_test_equation_with_length_5 = false_test_equation[5]\n",
"print(f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \" +\n",
" f\"true equation and {len(false_test_equation_with_length_5)} false equation \" +\n",
" f\"in the test set\")"
"print(\n",
" f\"For each euqation length, there are {len(true_test_equation_with_length_5)} \"\n",
" + f\"true equation and {len(false_test_equation_with_length_5)} false equation \"\n",
" + f\"in the test set\"\n",
")"
]
},
{
@@ -199,30 +207,30 @@
"true_train_equation_with_length_8 = true_train_equation[8]\n",
"print(f\"First true equation with length 5 in the training dataset:\")\n",
"for i, x in enumerate(true_train_equation_with_length_5[0]):\n",
" plt.subplot(1, 5, i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, 5, i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"print(f\"First true equation with length 8 in the training dataset:\")\n",
"for i, x in enumerate(true_train_equation_with_length_8[0]):\n",
" plt.subplot(1, 8, i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, 8, i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"\n",
"false_train_equation_with_length_5 = false_train_equation[5]\n",
"false_train_equation_with_length_8 = false_train_equation[8]\n",
"print(f\"First false equation with length 5 in the training dataset:\")\n",
"for i, x in enumerate(false_train_equation_with_length_5[0]):\n",
" plt.subplot(1, 5, i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, 5, i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"print(f\"First false equation with length 8 in the training dataset:\")\n",
"for i, x in enumerate(false_train_equation_with_length_8[0]):\n",
" plt.subplot(1, 8, i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, 8, i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()"
]
},


+ 4
- 4
examples/hed/main.py View File

@@ -46,20 +46,20 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HED example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
total_train_data = get_dataset(train=True)
train_data, val_data = split_equation(total_train_data, 3, 1)
test_data = get_dataset(train=False)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=4)
loss_fn = nn.CrossEntropyLoss()
@@ -83,7 +83,7 @@ def main():

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = HedKB()



+ 1
- 1
examples/hed/reasoning/__init__.py View File

@@ -1,3 +1,3 @@
from .reasoning import HedKB, HedReasoner

__all__ = ["HedKB", "HedReasoner"]
__all__ = ["HedKB", "HedReasoner"]

+ 7
- 4
examples/hed/reasoning/reasoning.py View File

@@ -8,8 +8,11 @@ from abl.utils import reform_list

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


class HedKB(PrologKB):
def __init__(self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")):
def __init__(
self, pseudo_label_list=[1, 0, "+", "="], pl_file=os.path.join(CURRENT_DIR, "learn_add.pl")
):
pl_file = pl_file.replace("\\", "/")
super().__init__(pseudo_label_list, pl_file)
self.learned_rules = {}
@@ -34,7 +37,7 @@ class HedReasoner(Reasoner):
data_example.pred_pseudo_label, data_example.Y, data_example.X, revision_idx
)
return candidate
def zoopt_budget(self, symbol_num):
return 200

@@ -53,7 +56,7 @@ class HedReasoner(Reasoner):
max_candidate_idxs = []
found = False
for idx in range(-1, len(data_example.pred_idx)):
if (not idx in idxs) and (idx >= 0):
if (idx not in idxs) and (idx >= 0):
idxs.append(idx)
candidates, _ = self.revise_at_idx(data_example[idxs])
if len(candidates) == 0:
@@ -96,4 +99,4 @@ class HedReasoner(Reasoner):
return abduced_pseudo_label

def abduce_rules(self, pred_res):
return self.kb.abduce_rules(pred_res)
return self.kb.abduce_rules(pred_res)

+ 1
- 1
examples/hwf/datasets/__init__.py View File

@@ -1,3 +1,3 @@
from .get_dataset import get_dataset

__all__ = ["get_dataset"]
__all__ = ["get_dataset"]

+ 12
- 7
examples/hwf/datasets/get_dataset.py View File

@@ -10,26 +10,31 @@ CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))

img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1,))])


def download_and_unzip(url, zip_file_name):
try:
gdown.download(url, zip_file_name)
with zipfile.ZipFile(zip_file_name, 'r') as zip_ref:
with zipfile.ZipFile(zip_file_name, "r") as zip_ref:
zip_ref.extractall(CURRENT_DIR)
os.remove(zip_file_name)
except Exception as e:
if os.path.exists(zip_file_name):
os.remove(zip_file_name)
raise Exception(f"An error occurred during download or unzip: {e}. Instead, you can download the dataset from {url} and unzip it in 'examples/hwf/datasets' folder")
raise Exception(
f"An error occurred during download or unzip: {e}. Instead, you can download "
+ f"the dataset from {url} and unzip it in 'examples/hwf/datasets' folder"
)


def get_dataset(train=True, get_pseudo_label=False):
data_dir = CURRENT_DIR + '/data'
data_dir = CURRENT_DIR + "/data"
if not os.path.exists(data_dir):
print("Dataset not exist, downloading it...")
url = 'https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download'
url = "https://drive.google.com/u/0/uc?id=1G07kw-wK-rqbg_85tuB7FNfA49q8lvoy&export=download"
download_and_unzip(url, os.path.join(CURRENT_DIR, "HWF.zip"))
print("Download and extraction complete.")
if train:
file = os.path.join(data_dir, "expr_train.json")
else:
@@ -59,4 +64,4 @@ def get_dataset(train=True, get_pseudo_label=False):
pseudo_label.append(imgs_pseudo_label)
Y.append(data[idx]["res"])

return X, pseudo_label, Y
return X, pseudo_label, Y

+ 56
- 33
examples/hwf/hwf.ipynb View File

@@ -72,21 +72,28 @@
"print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n",
"print()\n",
"train_X, train_gt_pseudo_label, train_Y = train_data\n",
"print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n",
" f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n",
"print(\n",
" f\"Length of X, gt_pseudo_label, Y in train_data: \"\n",
" + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n",
")\n",
"test_X, test_gt_pseudo_label, test_Y = test_data\n",
"print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n",
" f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n",
"print(\n",
" f\"Length of X, gt_pseudo_label, Y in test_data: \"\n",
" + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n",
")\n",
"print()\n",
"\n",
"X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n",
"print(f\"X is a {type(train_X).__name__}, \" +\n",
" f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\")\n",
"print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n",
" f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n",
" f\"of {type(gt_pseudo_label_0[0]).__name__}.\")\n",
"print(f\"Y is a {type(train_Y).__name__}, \" +\n",
" f\"with each element being a {type(Y_0).__name__}.\")"
"print(\n",
" f\"X is a {type(train_X).__name__}, \"\n",
" + f\"with each element being a {type(X_0).__name__} of {type(X_0[0]).__name__}.\"\n",
")\n",
"print(\n",
" f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n",
" + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n",
" + f\"of {type(gt_pseudo_label_0[0]).__name__}.\"\n",
")\n",
"print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")"
]
},
{
@@ -105,21 +112,25 @@
"X_1000, gt_pseudo_label_1000, Y_1000 = train_X[1000], train_gt_pseudo_label[1000], train_Y[1000]\n",
"print(f\"X in the 1001st data example (a list of images):\")\n",
"for i, x in enumerate(X_1000):\n",
" plt.subplot(1, len(X_1000), i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, len(X_1000), i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"print(f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\")\n",
"print(\n",
" f\"gt_pseudo_label in the 1001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_1000}\"\n",
")\n",
"print(f\"Y in the 1001st data example (the computed result): {Y_1000}\")\n",
"print()\n",
"X_3000, gt_pseudo_label_3000, Y_3000 = train_X[3000], train_gt_pseudo_label[3000], train_Y[3000]\n",
"print(f\"X in the 3001st data example (a list of images):\")\n",
"for i, x in enumerate(X_3000):\n",
" plt.subplot(1, len(X_3000), i+1)\n",
" plt.axis('off') \n",
" plt.imshow(x.squeeze(), cmap='gray')\n",
" plt.subplot(1, len(X_3000), i + 1)\n",
" plt.axis(\"off\")\n",
" plt.imshow(x.squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"print(f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\")\n",
"print(\n",
" f\"gt_pseudo_label in the 3001st data example (a list of ground truth pseudo-labels): {gt_pseudo_label_3000}\"\n",
")\n",
"print(f\"Y in the 3001st data example (the computed result): {Y_3000}\")"
]
},
@@ -184,11 +195,15 @@
"source": [
"data_instances = [torch.randn(1, 45, 45).to(device) for _ in range(32)]\n",
"pred_idx = base_model.predict(X=data_instances)\n",
"print(f\"Predicted class index for a batch of 32 instances: \" +\n",
" f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n",
"print(\n",
" f\"Predicted class index for a batch of 32 instances: \"\n",
" + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n",
")\n",
"pred_prob = base_model.predict_proba(X=data_instances)\n",
"print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n",
" f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")"
"print(\n",
" f\"Predicted class probabilities for a batch of 32 instances: \"\n",
" + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n",
")"
]
},
{
@@ -221,6 +236,7 @@
"outputs": [],
"source": [
"from abl.data.structures import ListData\n",
"\n",
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n",
"data_examples = ListData()\n",
"# We use the first 1001st and 3001st data examples in the training set as an illustration\n",
@@ -229,15 +245,19 @@
"data_examples.Y = [Y_1000, Y_3000]\n",
"\n",
"# Perform prediction on the two data examples\n",
"# Remind that, in the 1001st data example, the length of the formula is 3, \n",
"# Remind that, in the 1001st data example, the length of the formula is 3,\n",
"# while in the 3001st data example, the length of the formula is 5.\n",
"pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n",
"print(f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\" +\n",
" f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"+\n",
" f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\")\n",
"print(f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n",
" f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \" +\n",
" f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\")"
"pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n",
"print(\n",
" f\"Predicted class labels for the 100 data examples: a list of length {len(pred_label)}, \\n\"\n",
" + f\"the first element is a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}, \"\n",
" + f\"and the second element is a {type(pred_label[1]).__name__} of shape {pred_label[1].shape}.\\n\"\n",
")\n",
"print(\n",
" f\"Predicted class probabilities for the 100 data examples: a list of length {len(pred_prob)}, \\n\"\n",
" f\"the first element is a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}, \"\n",
" + f\"and the second element is a {type(pred_prob[1]).__name__} of shape {pred_prob[1].shape}.\"\n",
")"
]
},
{
@@ -261,7 +281,9 @@
"outputs": [],
"source": [
"class HwfKB(KBBase):\n",
" def __init__(self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]):\n",
" def __init__(\n",
" self, pseudo_label_list=[\"1\", \"2\", \"3\", \"4\", \"5\", \"6\", \"7\", \"8\", \"9\", \"+\", \"-\", \"*\", \"/\"]\n",
" ):\n",
" super().__init__(pseudo_label_list)\n",
"\n",
" def _valid_candidate(self, formula):\n",
@@ -273,13 +295,14 @@
" if i % 2 != 0 and formula[i] not in [\"+\", \"-\", \"*\", \"/\"]:\n",
" return False\n",
" return True\n",
" \n",
"\n",
" # Implement the deduction function\n",
" def logic_forward(self, formula):\n",
" if not self._valid_candidate(formula):\n",
" return np.inf\n",
" return eval(\"\".join(formula))\n",
"\n",
"\n",
"kb = HwfKB()"
]
},


+ 4
- 4
examples/hwf/main.py View File

@@ -113,19 +113,19 @@ def main():
)

args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the HWF example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
train_data = get_dataset(train=True, get_pseudo_label=True)
test_data = get_dataset(train=False, get_pseudo_label=True)

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = SymbolNet(num_classes=13, image_size=(45, 45, 1))
loss_fn = nn.CrossEntropyLoss()
@@ -148,7 +148,7 @@ def main():

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.ground:
kb = HwfGroundKB()


+ 1
- 1
examples/mnist_add/datasets/__init__.py View File

@@ -1,3 +1,3 @@
from .get_dataset import get_dataset

__all__ = ["get_dataset"]
__all__ = ["get_dataset"]

+ 1
- 0
examples/mnist_add/datasets/get_dataset.py View File

@@ -5,6 +5,7 @@ from torchvision.transforms import transforms

CURRENT_DIR = os.path.abspath(os.path.dirname(__file__))


def get_dataset(train=True, get_pseudo_label=True):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]


+ 2
- 2
examples/mnist_add/main.py View File

@@ -88,7 +88,7 @@ def main():

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
cls = LeNet5(num_classes=10)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
@@ -119,7 +119,7 @@ def main():

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.prolog:
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")


+ 19
- 5
examples/mnist_add/main_with_model_converter.py View File

@@ -89,23 +89,37 @@ def main():

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build necessary components for BasicNN
model=FixMatch(network=LeNet5(), threshold=0.95,lambda_u=1.0,mu=7,T=0.5,epoch=1,num_it_epoch=2**20,num_it_total=2**20,device='cuda')
model = FixMatch(
network=LeNet5(),
threshold=0.95,
lambda_u=1.0,
mu=7,
T=0.5,
epoch=1,
num_it_epoch=2**20,
num_it_total=2**20,
device="cuda",
)

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2)
optimizer_dict = dict(optimizer=RMSprop, lr=0.0003, alpha=0.9)
scheduler_dict = dict(scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200)
scheduler_dict = dict(
scheduler=lr_scheduler.OneCycleLR, max_lr=0.0003, pct_start=0.15, total_steps=200
)

converter = ModelConverter()
base_model = converter.convert_lambdalearn_to_basicnn(model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict)
base_model = converter.convert_lambdalearn_to_basicnn(
model, loss_fn=loss_fn, optimizer_dict=optimizer_dict, scheduler_dict=scheduler_dict
)

# Build ABLModel
model = ABLModel(base_model)

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
if args.prolog:
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="add.pl")


+ 49
- 30
examples/mnist_add/mnist_add.ipynb View File

@@ -87,22 +87,29 @@
"print(f\"Both train_data and test_data consist of 3 components: X, gt_pseudo_label, Y\")\n",
"print()\n",
"train_X, train_gt_pseudo_label, train_Y = train_data\n",
"print(f\"Length of X, gt_pseudo_label, Y in train_data: \" +\n",
" f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\")\n",
"print(\n",
" f\"Length of X, gt_pseudo_label, Y in train_data: \"\n",
" + f\"{len(train_X)}, {len(train_gt_pseudo_label)}, {len(train_Y)}\"\n",
")\n",
"test_X, test_gt_pseudo_label, test_Y = test_data\n",
"print(f\"Length of X, gt_pseudo_label, Y in test_data: \" +\n",
" f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\")\n",
"print(\n",
" f\"Length of X, gt_pseudo_label, Y in test_data: \"\n",
" + f\"{len(test_X)}, {len(test_gt_pseudo_label)}, {len(test_Y)}\"\n",
")\n",
"print()\n",
"\n",
"X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n",
"print(f\"X is a {type(train_X).__name__}, \" +\n",
" f\"with each element being a {type(X_0).__name__} \" +\n",
" f\"of {len(X_0)} {type(X_0[0]).__name__}.\")\n",
"print(f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \" +\n",
" f\"with each element being a {type(gt_pseudo_label_0).__name__} \" +\n",
" f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\")\n",
"print(f\"Y is a {type(train_Y).__name__}, \" +\n",
" f\"with each element being a {type(Y_0).__name__}.\")"
"print(\n",
" f\"X is a {type(train_X).__name__}, \"\n",
" + f\"with each element being a {type(X_0).__name__} \"\n",
" + f\"of {len(X_0)} {type(X_0[0]).__name__}.\"\n",
")\n",
"print(\n",
" f\"gt_pseudo_label is a {type(train_gt_pseudo_label).__name__}, \"\n",
" + f\"with each element being a {type(gt_pseudo_label_0).__name__} \"\n",
" + f\"of {len(gt_pseudo_label_0)} {type(gt_pseudo_label_0[0]).__name__}.\"\n",
")\n",
"print(f\"Y is a {type(train_Y).__name__}, \" + f\"with each element being a {type(Y_0).__name__}.\")"
]
},
{
@@ -146,14 +153,16 @@
"source": [
"X_0, gt_pseudo_label_0, Y_0 = train_X[0], train_gt_pseudo_label[0], train_Y[0]\n",
"print(f\"X in the first data example (a list of two images):\")\n",
"plt.subplot(1,2,1)\n",
"plt.axis('off') \n",
"plt.imshow(X_0[0].squeeze(), cmap='gray')\n",
"plt.subplot(1,2,2)\n",
"plt.axis('off') \n",
"plt.imshow(X_0[1].squeeze(), cmap='gray')\n",
"plt.subplot(1, 2, 1)\n",
"plt.axis(\"off\")\n",
"plt.imshow(X_0[0].squeeze(), cmap=\"gray\")\n",
"plt.subplot(1, 2, 2)\n",
"plt.axis(\"off\")\n",
"plt.imshow(X_0[1].squeeze(), cmap=\"gray\")\n",
"plt.show()\n",
"print(f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\")\n",
"print(\n",
" f\"gt_pseudo_label in the first data example (a list of two ground truth pseudo-labels): {gt_pseudo_label_0}\"\n",
")\n",
"print(f\"Y in the first data example (their sum result): {Y_0}\")"
]
},
@@ -219,11 +228,15 @@
"source": [
"data_instances = [torch.randn(1, 28, 28).to(device) for _ in range(32)]\n",
"pred_idx = base_model.predict(X=data_instances)\n",
"print(f\"Predicted class index for a batch of 32 instances: \" +\n",
" f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\")\n",
"print(\n",
" f\"Predicted class index for a batch of 32 instances: \"\n",
" + f\"{type(pred_idx).__name__} with shape {pred_idx.shape}\"\n",
")\n",
"pred_prob = base_model.predict_proba(X=data_instances)\n",
"print(f\"Predicted class probabilities for a batch of 32 instances: \" +\n",
" f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\")"
"print(\n",
" f\"Predicted class probabilities for a batch of 32 instances: \"\n",
" + f\"{type(pred_prob).__name__} with shape {pred_prob.shape}\"\n",
")"
]
},
{
@@ -268,6 +281,7 @@
],
"source": [
"from abl.data.structures import ListData\n",
"\n",
"# ListData is a data structure provided by ABL-Package that can be used to organize data examples\n",
"data_examples = ListData()\n",
"# We use the first 100 data examples in the training set as an illustration\n",
@@ -276,13 +290,17 @@
"data_examples.Y = train_Y[:100]\n",
"\n",
"# Perform prediction on the 100 data examples\n",
"pred_label, pred_prob = model.predict(data_examples)['label'], model.predict(data_examples)['prob']\n",
"print(f\"Predicted class labels for the 100 data examples: \\n\" +\n",
" f\"a list of length {len(pred_label)}, and each element is \" +\n",
" f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\")\n",
"print(f\"Predicted class probabilities for the 100 data examples: \\n\" +\n",
" f\"a list of length {len(pred_prob)}, and each element is \" +\n",
" f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\")"
"pred_label, pred_prob = model.predict(data_examples)[\"label\"], model.predict(data_examples)[\"prob\"]\n",
"print(\n",
" f\"Predicted class labels for the 100 data examples: \\n\"\n",
" + f\"a list of length {len(pred_label)}, and each element is \"\n",
" + f\"a {type(pred_label[0]).__name__} of shape {pred_label[0].shape}.\\n\"\n",
")\n",
"print(\n",
" f\"Predicted class probabilities for the 100 data examples: \\n\"\n",
" + f\"a list of length {len(pred_prob)}, and each element is \"\n",
" + f\"a {type(pred_prob[0]).__name__} of shape {pred_prob[0].shape}.\"\n",
")"
]
},
{
@@ -313,6 +331,7 @@
" def logic_forward(self, nums):\n",
" return sum(nums)\n",
"\n",
"\n",
"kb = AddKB()"
]
},


+ 10
- 7
examples/zoo/get_dataset.py View File

@@ -4,27 +4,30 @@ import openml

# Function to load and preprocess the dataset
def load_and_preprocess_dataset(dataset_id):
dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False)
dataset = openml.datasets.get_dataset(
dataset_id, download_data=True, download_qualities=False, download_features_meta_data=False
)
X, y, _, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
# Convert data types
for col in X.select_dtypes(include='bool').columns:
for col in X.select_dtypes(include="bool").columns:
X[col] = X[col].astype(int)
y = y.cat.codes.astype(int)
X, y = X.to_numpy(), y.to_numpy()
return X, y


# Function to split data (one shot)
def split_dataset(X, y, test_size = 0.3):
def split_dataset(X, y, test_size=0.3):
# For every class: 1 : (1-test_size)*(len-1) : test_size*(len-1)
label_indices, unlabel_indices, test_indices = [], [], []
for class_label in np.unique(y):
idxs = np.where(y == class_label)[0]
np.random.shuffle(idxs)
n_train_unlabel = int((1-test_size)*(len(idxs)-1))
n_train_unlabel = int((1 - test_size) * (len(idxs) - 1))
label_indices.append(idxs[0])
unlabel_indices.extend(idxs[1:1+n_train_unlabel])
test_indices.extend(idxs[1+n_train_unlabel:])
unlabel_indices.extend(idxs[1 : 1 + n_train_unlabel])
test_indices.extend(idxs[1 + n_train_unlabel :])
X_label, y_label = X[label_indices], y[label_indices]
X_unlabel, y_unlabel = X[unlabel_indices], y[unlabel_indices]
X_test, y_test = X[test_indices], y[test_indices]
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test
return X_label, y_label, X_unlabel, y_unlabel, X_test, y_test

+ 20
- 9
examples/zoo/kb.py View File

@@ -11,18 +11,27 @@ class ZooKB(KBBase):
self.solver = Solver()

# Load information of Zoo dataset
dataset = openml.datasets.get_dataset(dataset_id = 62, download_data=False, download_qualities=False, download_features_meta_data=False)
X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
dataset = openml.datasets.get_dataset(
dataset_id=62,
download_data=False,
download_qualities=False,
download_features_meta_data=False,
)
X, y, categorical_indicator, attribute_names = dataset.get_data(
target=dataset.default_target_attribute
)
self.attribute_names = attribute_names
self.target_names = y.cat.categories.tolist()
# print("Attribute names are: ", self.attribute_names)
# print("Target names are: ", self.target_names)
# self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"]
# self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"]
# self.attribute_names = ["hair", "feathers", "eggs", "milk", "airborne", "aquatic", "predator", "toothed", "backbone", "breathes", "venomous", "fins", "legs", "tail", "domestic", "catsize"] # noqa: E501
# self.target_names = ["mammal", "bird", "reptile", "fish", "amphibian", "insect", "invertebrate"] # noqa: E501

# Define variables
for name in self.attribute_names+self.target_names:
exec(f"globals()['{name}'] = Int('{name}')") ## or use dict to create var and modify rules
for name in self.attribute_names + self.target_names:
exec(
f"globals()['{name}'] = Int('{name}')"
) # or use dict to create var and modify rules
# Define rules
rules = [
Implies(milk == 1, mammal == 1),
@@ -54,11 +63,13 @@ class ZooKB(KBBase):
Implies(insect == 1, eggs == 1),
Implies(insect == 1, Not(backbone == 1)),
Implies(insect == 1, legs == 6),
Implies(invertebrate == 1, Not(backbone == 1))
Implies(invertebrate == 1, Not(backbone == 1)),
]
# Define weights and sum of violated weights
self.weights = {rule: 1 for rule in rules}
self.total_violation_weight = Sum([If(Not(rule), self.weights[rule], 0) for rule in self.weights])
self.total_violation_weight = Sum(
[If(Not(rule), self.weights[rule], 0) for rule in self.weights]
)

def logic_forward(self, pseudo_label, data_point):
attribute_names, target_names = self.attribute_names, self.target_names
@@ -69,7 +80,7 @@ class ZooKB(KBBase):
self.solver.reset()
for name, value in zip(attribute_names, data_point):
solver.add(eval(f"{name} == {value}"))
for cate, name in zip(self.pseudo_label_list,target_names):
for cate, name in zip(self.pseudo_label_list, target_names):
value = 1 if (cate == pseudo_label) else 0
solver.add(eval(f"{name} == {value}"))



+ 16
- 10
examples/zoo/main.py View File

@@ -14,7 +14,6 @@ from get_dataset import load_and_preprocess_dataset, split_dataset
from kb import ZooKB



def consitency(data_example, candidates, candidate_idxs, reasoning_results):
pred_prob = data_example.pred_prob
model_scores = confidence_dist(pred_prob, candidate_idxs)
@@ -22,19 +21,20 @@ def consitency(data_example, candidates, candidate_idxs, reasoning_results):
scores = model_scores + rule_scores
return scores


def main():
parser = argparse.ArgumentParser(description="Zoo example")
parser.add_argument(
"--loops", type=int, default=3, help="number of loop iterations (default : 3)"
)
args = parser.parse_args()
# Build logger
print_log("Abductive Learning on the ZOO example.", logger="current")

### Working with Data
print_log("Working with Data.", logger="current")
X, y = load_and_preprocess_dataset(dataset_id=62)
X_label, y_label, X_unlabel, y_unlabel, X_test, y_test = split_dataset(X, y, test_size=0.3)
label_data = tab_data_to_tuple(X_label, y_label)
@@ -43,7 +43,7 @@ def main():

### Building the Learning Part
print_log("Building the Learning Part.", logger="current")
# Build base model
base_model = RandomForestClassifier()

@@ -52,32 +52,38 @@ def main():

### Building the Reasoning Part
print_log("Building the Reasoning Part.", logger="current")
# Build knowledge base
kb = ZooKB()
# Create reasoner
reasoner = Reasoner(kb, dist_func=consitency)

### Building Evaluation Metrics
print_log("Building Evaluation Metrics.", logger="current")
metric_list = [SymbolAccuracy(prefix="zoo"), ReasoningMetric(kb=kb, prefix="zoo")]
### Bridging learning and reasoning
print_log("Bridge Learning and Reasoning.", logger="current")
bridge = SimpleBridge(model, reasoner, metric_list)
# Retrieve the directory of the Log file and define the directory for saving the model weights.
log_dir = ABLLogger.get_current_instance().log_dir
weights_dir = osp.join(log_dir, "weights")
# Performing training and testing
print_log("------- Use labeled data to pretrain the model -----------", logger="current")
base_model.fit(X_label, y_label)
print_log("------- Test the initial model -----------", logger="current")
bridge.test(test_data)
print_log("------- Use ABL to train the model -----------", logger="current")
bridge.train(train_data=train_data, label_data=label_data, loops=args.loops, segment_size=len(X_unlabel), save_dir=weights_dir)
bridge.train(
train_data=train_data,
label_data=label_data,
loops=args.loops,
segment_size=len(X_unlabel),
save_dir=weights_dir,
)
print_log("------- Test the final model -----------", logger="current")
bridge.test(test_data)



+ 11
- 4
examples/zoo/zoo.ipynb View File

@@ -106,9 +106,9 @@
"metadata": {},
"outputs": [],
"source": [
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result = 0)\n",
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result = 0)\n",
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result = 0)"
"label_data = tab_data_to_tuple(X_label, y_label, reasoning_result=0)\n",
"test_data = tab_data_to_tuple(X_test, y_test, reasoning_result=0)\n",
"train_data = tab_data_to_tuple(X_unlabel, y_unlabel, reasoning_result=0)"
]
},
{
@@ -240,6 +240,7 @@
" scores = model_scores + rule_scores\n",
" return scores\n",
"\n",
"\n",
"reasoner = Reasoner(kb, dist_func=consitency)"
]
},
@@ -338,7 +339,13 @@
"print_log(\"------- Test the initial model -----------\", logger=\"current\")\n",
"bridge.test(test_data)\n",
"print_log(\"------- Use ABL to train the model -----------\", logger=\"current\")\n",
"bridge.train(train_data=train_data, label_data=label_data, loops=3, segment_size=len(X_unlabel), save_dir=weights_dir)\n",
"bridge.train(\n",
" train_data=train_data,\n",
" label_data=label_data,\n",
" loops=3,\n",
" segment_size=len(X_unlabel),\n",
" save_dir=weights_dir,\n",
")\n",
"print_log(\"------- Test the final model -----------\", logger=\"current\")\n",
"bridge.test(test_data)"
]


+ 2
- 2
tests/conftest.py View File

@@ -200,7 +200,7 @@ def kb_add_ground():

@pytest.fixture
def kb_add_prolog():
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
kb = PrologKB(pseudo_label_list=list(range(10)), pl_file="examples/mnist_add/add.pl")
return kb
@@ -218,7 +218,7 @@ def kb_hwf2():

@pytest.fixture
def kb_hed():
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
kb = HedKB(
pseudo_label_list=[1, 0, "+", "="],


+ 24
- 20
tests/test_reasoning.py View File

@@ -28,9 +28,13 @@ class TestKBBase(object):
assert result == ([[0, 2], [1, 1], [2, 0]], [2, 2, 2])

def test_abduce_candidates(self, kb_add):
result = kb_add.abduce_candidates([0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
result = kb_add.abduce_candidates(
[0, 1], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
)
assert result == ([[0, 1]], [1])
result = kb_add.abduce_candidates([1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0)
result = kb_add.abduce_candidates(
[1, 2], 1, [0.1, -0.2, 0.2, -0.3], max_revision_num=2, require_more_revision=0
)
assert result == ([[1, 0]], [1])


@@ -53,19 +57,19 @@ class TestGroundKB(object):

class TestPrologKB(object):
def test_init_pl1(self, kb_add_prolog):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
assert kb_add_prolog.pseudo_label_list == list(range(10))
assert kb_add_prolog.pl_file == "examples/mnist_add/add.pl"

def test_init_pl2(self, kb_hed):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
assert kb_hed.pseudo_label_list == [1, 0, "+", "="]
assert kb_hed.pl_file == "examples/hed/reasoning/learn_add.pl"

def test_prolog_file_not_exist(self):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
pseudo_label_list = [1, 2]
non_existing_file = "path/to/non_existing_file.pl"
@@ -74,13 +78,13 @@ class TestPrologKB(object):
assert non_existing_file in str(excinfo.value)

def test_logic_forward_pl1(self, kb_add_prolog):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
result = kb_add_prolog.logic_forward([1, 2])
assert result == 3

def test_logic_forward_pl2(self, kb_hed):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
consist_exs = [
[1, 1, "+", 0, "=", 1, 1],
@@ -97,7 +101,7 @@ class TestPrologKB(object):
assert kb_hed.logic_forward(inconsist_exs) is False

def test_revise_at_idx(self, kb_add_prolog):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
result = kb_add_prolog.revise_at_idx([1, 2], 2, [0.1, -0.2, 0.2, -0.3], [0])
assert result == ([[0, 2]], [2])
@@ -113,34 +117,34 @@ class TestReaonser(object):
assert 'Valid options for predefined dist_func include "hamming" and "confidence"' in str(
excinfo.value
)
def random_dist(self, data_example, candidates, candidate_idxs, reasoning_results):
cost_list = [np.random.rand() for _ in candidates]
return cost_list
def test_user_defined_dist_func(self, kb_add):
reasoner = Reasoner(kb_add, self.random_dist)
assert reasoner.dist_func == self.random_dist
def invalid_dist1(self, candidates):
cost_list = np.array([np.random.rand() for _ in candidates])
return cost_list
def invalid_dist2(self, data_example, candidates, candidate_idxs, reasoning_results):
cost_list = np.array([np.random.rand() for _ in candidates])
return np.append(cost_list, np.random.rand())
def test_invalid_user_defined_dist_func(self, kb_add, data_examples_add):
with pytest.raises(ValueError) as excinfo:
Reasoner(kb_add, self.invalid_dist1)
assert 'User-defined dist_func must have exactly four parameters' in str(
excinfo.value
)
assert "User-defined dist_func must have exactly four parameters" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
reasoner = Reasoner(kb_add, self.invalid_dist2)
reasoner.batch_abduce(data_examples_add)
assert 'The length of the array returned by dist_func must be equal to the number of candidates' in str(
excinfo.value
assert (
"The length of the array returned by dist_func must be "
+ "equal to the number of candidates"
in str(excinfo.value)
)


@@ -186,7 +190,7 @@ class TestBatchAbduce(object):
]

def test_batch_abduce_prolog(self, kb_add_prolog, data_examples_add):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
reasoner1 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=0)
reasoner2 = Reasoner(kb_add_prolog, "confidence", max_revision=1, require_more_revision=1)
@@ -208,7 +212,7 @@ class TestBatchAbduce(object):
]

def test_batch_abduce_zoopt(self, kb_add_prolog, data_examples_add):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
return
reasoner1 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=1)
reasoner2 = Reasoner(kb_add_prolog, "confidence", use_zoopt=True, max_revision=2)


Loading…
Cancel
Save