| @@ -1,2 +1,2 @@ | |||||
| from .learning import abl_model, basic_nn | from .learning import abl_model, basic_nn | ||||
| from .reasoning import reasoner, kb | |||||
| from .reasoning import kb, reasoner | |||||
| @@ -1,11 +1,10 @@ | |||||
| from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
| from typing import Any, List, Tuple, Optional, Union | |||||
| from typing import Any, List, Optional, Tuple, Union | |||||
| from ..learning import ABLModel | from ..learning import ABLModel | ||||
| from ..reasoning import ReasonerBase | from ..reasoning import ReasonerBase | ||||
| from ..structures import ListData | from ..structures import ListData | ||||
| DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] | ||||
| @@ -37,6 +36,7 @@ class BaseBridge(metaclass=ABCMeta): | |||||
| @abstractmethod | @abstractmethod | ||||
| def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | ||||
| """Placeholder for abduce pseudo labels.""" | """Placeholder for abduce pseudo labels.""" | ||||
| pass | |||||
| @abstractmethod | @abstractmethod | ||||
| def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: | ||||
| @@ -1,5 +1,6 @@ | |||||
| from typing import Any, List, Tuple | |||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from typing import List, Any, Tuple | |||||
| class BridgeDataset(Dataset): | class BridgeDataset(Dataset): | ||||
| @@ -1,6 +1,7 @@ | |||||
| from typing import Any, Callable, List, Tuple | |||||
| import torch | import torch | ||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from typing import List, Any, Tuple, Callable | |||||
| class ClassificationDataset(Dataset): | class ClassificationDataset(Dataset): | ||||
| @@ -1,6 +1,7 @@ | |||||
| from typing import Any, List, Tuple | |||||
| import torch | import torch | ||||
| from torch.utils.data import Dataset | from torch.utils.data import Dataset | ||||
| from typing import List, Any, Tuple | |||||
| class RegressionDataset(Dataset): | class RegressionDataset(Dataset): | ||||
| @@ -1,3 +1,3 @@ | |||||
| from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
| from .symbol_metric import SymbolMetric | |||||
| from .semantics_metric import SemanticsMetric | from .semantics_metric import SemanticsMetric | ||||
| from .symbol_metric import SymbolMetric | |||||
| @@ -1,8 +1,8 @@ | |||||
| import logging | |||||
| from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
| from typing import Any, List, Optional, Sequence | from typing import Any, List, Optional, Sequence | ||||
| from ..utils import print_log | |||||
| import logging | |||||
| from ..utils import print_log | |||||
| class BaseMetric(metaclass=ABCMeta): | class BaseMetric(metaclass=ABCMeta): | ||||
| @@ -1,4 +1,5 @@ | |||||
| from typing import Optional, Sequence | from typing import Optional, Sequence | ||||
| from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
| @@ -1,4 +1,5 @@ | |||||
| from typing import Optional, Sequence, Callable | |||||
| from typing import Callable, Optional, Sequence | |||||
| from .base_metric import BaseMetric | from .base_metric import BaseMetric | ||||
| @@ -10,14 +10,15 @@ | |||||
| # | # | ||||
| # ================================================================# | # ================================================================# | ||||
| import torch | |||||
| import os | |||||
| from typing import Any, Callable, List, Optional, T, Tuple | |||||
| import numpy | import numpy | ||||
| import torch | |||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| from ..utils.logger import print_log | |||||
| from ..dataset import ClassificationDataset | |||||
| import os | |||||
| from typing import List, Any, T, Optional, Callable, Tuple | |||||
| from ..dataset import ClassificationDataset | |||||
| from ..utils.logger import print_log | |||||
| class BasicNN: | class BasicNN: | ||||
| @@ -1,2 +1,2 @@ | |||||
| from .kb import KBBase, ground_KB, prolog_KB | |||||
| from .reasoner import ReasonerBase | from .reasoner import ReasonerBase | ||||
| from .kb import KBBase, prolog_KB | |||||
| @@ -1,24 +1,16 @@ | |||||
| from abc import ABC, abstractmethod | |||||
| import bisect | import bisect | ||||
| import numpy as np | |||||
| from abc import ABC, abstractmethod | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from itertools import product, combinations | |||||
| from ..utils.utils import ( | |||||
| flatten, | |||||
| reform_idx, | |||||
| hamming_dist, | |||||
| check_equal, | |||||
| to_hashable, | |||||
| hashable_to_list, | |||||
| ) | |||||
| from functools import lru_cache | |||||
| from itertools import combinations, product | |||||
| from multiprocessing import Pool | from multiprocessing import Pool | ||||
| from functools import lru_cache | |||||
| import numpy as np | |||||
| import pyswip | import pyswip | ||||
| from ..utils.utils import (check_equal, flatten, hamming_dist, | |||||
| hashable_to_list, reform_idx, to_hashable) | |||||
| class KBBase(ABC): | class KBBase(ABC): | ||||
| def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | def __init__(self, pseudo_label_list, max_err=0, use_cache=True): | ||||
| @@ -1,12 +1,8 @@ | |||||
| import numpy as np | import numpy as np | ||||
| from zoopt import Dimension, Objective, Parameter, Opt | |||||
| from ..utils.utils import ( | |||||
| confidence_dist, | |||||
| flatten, | |||||
| reform_idx, | |||||
| hamming_dist, | |||||
| calculate_revision_num, | |||||
| ) | |||||
| from zoopt import Dimension, Objective, Opt, Parameter | |||||
| from ..utils.utils import (calculate_revision_num, confidence_dist, flatten, | |||||
| hamming_dist, reform_idx) | |||||
| class ReasonerBase: | class ReasonerBase: | ||||
| @@ -1,6 +1,7 @@ | |||||
| import numpy as np | |||||
| from itertools import chain | from itertools import chain | ||||
| import numpy as np | |||||
| def flatten(nested_list): | def flatten(nested_list): | ||||
| """ | """ | ||||
| @@ -1,8 +1,8 @@ | |||||
| # -*- coding: utf-8 -*- | # -*- coding: utf-8 -*- | ||||
| import sys | |||||
| import os | import os | ||||
| import re | import re | ||||
| import sys | |||||
| if not 'READTHEDOCS' in os.environ: | if not 'READTHEDOCS' in os.environ: | ||||
| sys.path.insert(0, os.path.abspath('..')) | sys.path.insert(0, os.path.abspath('..')) | ||||
| @@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/')) | |||||
| # from sphinx.locale import _ | # from sphinx.locale import _ | ||||
| from sphinx_rtd_theme import __version__ | from sphinx_rtd_theme import __version__ | ||||
| project = u'ABL' | project = u'ABL' | ||||
| slug = re.sub(r'\W+', '-', project.lower()) | slug = re.sub(r'\W+', '-', project.lower()) | ||||
| author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' | author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' | ||||
| @@ -1,18 +1,18 @@ | |||||
| import os | import os | ||||
| from collections import defaultdict | from collections import defaultdict | ||||
| import torch | import torch | ||||
| from torch.utils.data import DataLoader | from torch.utils.data import DataLoader | ||||
| from abl.reasoning import ReasonerBase | |||||
| from abl.learning import ABLModel, BasicNN | |||||
| from abl.bridge import SimpleBridge | from abl.bridge import SimpleBridge | ||||
| from abl.evaluation import BaseMetric | |||||
| from abl.dataset import BridgeDataset, RegressionDataset | from abl.dataset import BridgeDataset, RegressionDataset | ||||
| from abl.evaluation import BaseMetric | |||||
| from abl.learning import ABLModel, BasicNN | |||||
| from abl.reasoning import ReasonerBase | |||||
| from abl.utils import print_log | from abl.utils import print_log | ||||
| from examples.hed.utils import gen_mappings, InfiniteSampler | |||||
| from examples.models.nn import SymbolNetAutoencoder | |||||
| from examples.hed.datasets.get_hed import get_pretrain_data | from examples.hed.datasets.get_hed import get_pretrain_data | ||||
| from examples.hed.utils import InfiniteSampler, gen_mappings | |||||
| from examples.models.nn import SymbolNetAutoencoder | |||||
| class HEDBridge(SimpleBridge): | class HEDBridge(SimpleBridge): | ||||
| @@ -1,6 +1,6 @@ | |||||
| import numpy as np | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import numpy as np | |||||
| import torch.utils.data.sampler as sampler | import torch.utils.data.sampler as sampler | ||||
| @@ -1,6 +1,7 @@ | |||||
| import torchvision | import torchvision | ||||
| from torchvision.transforms import transforms | from torchvision.transforms import transforms | ||||
| def get_data(file, img_dataset, get_pseudo_label): | def get_data(file, img_dataset, get_pseudo_label): | ||||
| X = [] | X = [] | ||||
| if get_pseudo_label: | if get_pseudo_label: | ||||
| @@ -11,8 +11,8 @@ | |||||
| # ================================================================# | # ================================================================# | ||||
| import torch | |||||
| import numpy as np | import numpy as np | ||||
| import torch | |||||
| from torch import nn | from torch import nn | ||||
| @@ -1,4 +1,5 @@ | |||||
| import os | import os | ||||
| from setuptools import find_packages, setup | from setuptools import find_packages, setup | ||||