diff --git a/abl/__init__.py b/abl/__init__.py index b30df59..742c9c5 100644 --- a/abl/__init__.py +++ b/abl/__init__.py @@ -1,2 +1,2 @@ from .learning import abl_model, basic_nn -from .reasoning import reasoner, kb \ No newline at end of file +from .reasoning import kb, reasoner diff --git a/abl/bridge/base_bridge.py b/abl/bridge/base_bridge.py index b535211..869ea39 100644 --- a/abl/bridge/base_bridge.py +++ b/abl/bridge/base_bridge.py @@ -1,11 +1,10 @@ from abc import ABCMeta, abstractmethod -from typing import Any, List, Tuple, Optional, Union +from typing import Any, List, Optional, Tuple, Union from ..learning import ABLModel from ..reasoning import ReasonerBase from ..structures import ListData - DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] @@ -37,6 +36,7 @@ class BaseBridge(metaclass=ABCMeta): @abstractmethod def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: """Placeholder for abduce pseudo labels.""" + pass @abstractmethod def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: diff --git a/abl/dataset/bridge_dataset.py b/abl/dataset/bridge_dataset.py index bb0ce98..a7d32c5 100644 --- a/abl/dataset/bridge_dataset.py +++ b/abl/dataset/bridge_dataset.py @@ -1,5 +1,6 @@ +from typing import Any, List, Tuple + from torch.utils.data import Dataset -from typing import List, Any, Tuple class BridgeDataset(Dataset): diff --git a/abl/dataset/classification_dataset.py b/abl/dataset/classification_dataset.py index 28f9299..1663642 100644 --- a/abl/dataset/classification_dataset.py +++ b/abl/dataset/classification_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, Callable, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple, Callable class ClassificationDataset(Dataset): diff --git a/abl/dataset/regression_dataset.py b/abl/dataset/regression_dataset.py index 8cf136c..118ac65 100644 --- a/abl/dataset/regression_dataset.py +++ b/abl/dataset/regression_dataset.py @@ -1,6 +1,7 @@ +from typing import Any, List, Tuple + import torch from torch.utils.data import Dataset -from typing import List, Any, Tuple class RegressionDataset(Dataset): diff --git a/abl/evaluation/__init__.py b/abl/evaluation/__init__.py index a849d68..3106412 100644 --- a/abl/evaluation/__init__.py +++ b/abl/evaluation/__init__.py @@ -1,3 +1,3 @@ from .base_metric import BaseMetric -from .symbol_metric import SymbolMetric from .semantics_metric import SemanticsMetric +from .symbol_metric import SymbolMetric diff --git a/abl/evaluation/base_metric.py b/abl/evaluation/base_metric.py index 44364f8..e18f452 100644 --- a/abl/evaluation/base_metric.py +++ b/abl/evaluation/base_metric.py @@ -1,8 +1,8 @@ +import logging from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence -from ..utils import print_log -import logging +from ..utils import print_log class BaseMetric(metaclass=ABCMeta): diff --git a/abl/evaluation/semantics_metric.py b/abl/evaluation/semantics_metric.py index 1bacca4..09eb238 100644 --- a/abl/evaluation/semantics_metric.py +++ b/abl/evaluation/semantics_metric.py @@ -1,4 +1,5 @@ from typing import Optional, Sequence + from .base_metric import BaseMetric diff --git a/abl/evaluation/symbol_metric.py b/abl/evaluation/symbol_metric.py index 3c0c216..e133381 100644 --- a/abl/evaluation/symbol_metric.py +++ b/abl/evaluation/symbol_metric.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Callable +from typing import Callable, Optional, Sequence + from .base_metric import BaseMetric diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index b9b0b36..25b16bd 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -10,14 +10,15 @@ # # ================================================================# -import torch +import os +from typing import Any, Callable, List, Optional, T, Tuple + import numpy +import torch from torch.utils.data import DataLoader -from ..utils.logger import print_log -from ..dataset import ClassificationDataset -import os -from typing import List, Any, T, Optional, Callable, Tuple +from ..dataset import ClassificationDataset +from ..utils.logger import print_log class BasicNN: diff --git a/abl/reasoning/__init__.py b/abl/reasoning/__init__.py index 8930758..9b3d244 100644 --- a/abl/reasoning/__init__.py +++ b/abl/reasoning/__init__.py @@ -1,2 +1,2 @@ +from .kb import KBBase, ground_KB, prolog_KB from .reasoner import ReasonerBase -from .kb import KBBase, prolog_KB \ No newline at end of file diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 614e454..a839ae8 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -1,24 +1,16 @@ -from abc import ABC, abstractmethod import bisect -import numpy as np - +from abc import ABC, abstractmethod from collections import defaultdict -from itertools import product, combinations - -from ..utils.utils import ( - flatten, - reform_idx, - hamming_dist, - check_equal, - to_hashable, - hashable_to_list, -) - +from functools import lru_cache +from itertools import combinations, product from multiprocessing import Pool -from functools import lru_cache +import numpy as np import pyswip +from ..utils.utils import (check_equal, flatten, hamming_dist, + hashable_to_list, reform_idx, to_hashable) + class KBBase(ABC): def __init__(self, pseudo_label_list, max_err=0, use_cache=True): diff --git a/abl/reasoning/reasoner.py b/abl/reasoning/reasoner.py index d1ca4bb..630f1af 100644 --- a/abl/reasoning/reasoner.py +++ b/abl/reasoning/reasoner.py @@ -1,12 +1,8 @@ import numpy as np -from zoopt import Dimension, Objective, Parameter, Opt -from ..utils.utils import ( - confidence_dist, - flatten, - reform_idx, - hamming_dist, - calculate_revision_num, -) +from zoopt import Dimension, Objective, Opt, Parameter + +from ..utils.utils import (calculate_revision_num, confidence_dist, flatten, + hamming_dist, reform_idx) class ReasonerBase: diff --git a/abl/utils/utils.py b/abl/utils/utils.py index 9d1dc7a..8192bf9 100644 --- a/abl/utils/utils.py +++ b/abl/utils/utils.py @@ -1,6 +1,7 @@ -import numpy as np from itertools import chain +import numpy as np + def flatten(nested_list): """ diff --git a/docs/conf.py b/docs/conf.py index 25bd78e..bc8476e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -import sys import os import re +import sys if not 'READTHEDOCS' in os.environ: sys.path.insert(0, os.path.abspath('..')) @@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/')) # from sphinx.locale import _ from sphinx_rtd_theme import __version__ - project = u'ABL' slug = re.sub(r'\W+', '-', project.lower()) author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' diff --git a/examples/hed/hed_bridge.py b/examples/hed/hed_bridge.py index e93d46c..d7845cc 100644 --- a/examples/hed/hed_bridge.py +++ b/examples/hed/hed_bridge.py @@ -1,18 +1,18 @@ import os from collections import defaultdict + import torch from torch.utils.data import DataLoader -from abl.reasoning import ReasonerBase -from abl.learning import ABLModel, BasicNN from abl.bridge import SimpleBridge -from abl.evaluation import BaseMetric from abl.dataset import BridgeDataset, RegressionDataset +from abl.evaluation import BaseMetric +from abl.learning import ABLModel, BasicNN +from abl.reasoning import ReasonerBase from abl.utils import print_log - -from examples.hed.utils import gen_mappings, InfiniteSampler -from examples.models.nn import SymbolNetAutoencoder from examples.hed.datasets.get_hed import get_pretrain_data +from examples.hed.utils import InfiniteSampler, gen_mappings +from examples.models.nn import SymbolNetAutoencoder class HEDBridge(SimpleBridge): diff --git a/examples/hed/utils.py b/examples/hed/utils.py index 42b7316..77c015e 100644 --- a/examples/hed/utils.py +++ b/examples/hed/utils.py @@ -1,6 +1,6 @@ +import numpy as np import torch import torch.nn as nn -import numpy as np import torch.utils.data.sampler as sampler diff --git a/examples/mnist_add/datasets/get_mnist_add.py b/examples/mnist_add/datasets/get_mnist_add.py index 46b5f12..21b9101 100644 --- a/examples/mnist_add/datasets/get_mnist_add.py +++ b/examples/mnist_add/datasets/get_mnist_add.py @@ -1,6 +1,7 @@ import torchvision from torchvision.transforms import transforms + def get_data(file, img_dataset, get_pseudo_label): X = [] if get_pseudo_label: diff --git a/examples/models/nn.py b/examples/models/nn.py index 64a3deb..c3ecc40 100644 --- a/examples/models/nn.py +++ b/examples/models/nn.py @@ -11,8 +11,8 @@ # ================================================================# -import torch import numpy as np +import torch from torch import nn diff --git a/setup.py b/setup.py index 33263a6..9e675d3 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ import os + from setuptools import find_packages, setup