Browse Source

[ENH] sort import

ab_data
Gao Enhao 2 years ago
parent
commit
95fa385ce6
20 changed files with 46 additions and 50 deletions
  1. +1
    -1
      abl/__init__.py
  2. +2
    -2
      abl/bridge/base_bridge.py
  3. +2
    -1
      abl/dataset/bridge_dataset.py
  4. +2
    -1
      abl/dataset/classification_dataset.py
  5. +2
    -1
      abl/dataset/regression_dataset.py
  6. +1
    -1
      abl/evaluation/__init__.py
  7. +2
    -2
      abl/evaluation/base_metric.py
  8. +1
    -0
      abl/evaluation/semantics_metric.py
  9. +2
    -1
      abl/evaluation/symbol_metric.py
  10. +6
    -5
      abl/learning/basic_nn.py
  11. +1
    -1
      abl/reasoning/__init__.py
  12. +7
    -15
      abl/reasoning/kb.py
  13. +4
    -8
      abl/reasoning/reasoner.py
  14. +2
    -1
      abl/utils/utils.py
  15. +1
    -2
      docs/conf.py
  16. +6
    -6
      examples/hed/hed_bridge.py
  17. +1
    -1
      examples/hed/utils.py
  18. +1
    -0
      examples/mnist_add/datasets/get_mnist_add.py
  19. +1
    -1
      examples/models/nn.py
  20. +1
    -0
      setup.py

+ 1
- 1
abl/__init__.py View File

@@ -1,2 +1,2 @@
from .learning import abl_model, basic_nn from .learning import abl_model, basic_nn
from .reasoning import reasoner, kb
from .reasoning import kb, reasoner

+ 2
- 2
abl/bridge/base_bridge.py View File

@@ -1,11 +1,10 @@
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Any, List, Tuple, Optional, Union
from typing import Any, List, Optional, Tuple, Union


from ..learning import ABLModel from ..learning import ABLModel
from ..reasoning import ReasonerBase from ..reasoning import ReasonerBase
from ..structures import ListData from ..structures import ListData



DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]] DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]




@@ -37,6 +36,7 @@ class BaseBridge(metaclass=ABCMeta):
@abstractmethod @abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels.""" """Placeholder for abduce pseudo labels."""
pass


@abstractmethod @abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]: def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:


+ 2
- 1
abl/dataset/bridge_dataset.py View File

@@ -1,5 +1,6 @@
from typing import Any, List, Tuple

from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import List, Any, Tuple




class BridgeDataset(Dataset): class BridgeDataset(Dataset):


+ 2
- 1
abl/dataset/classification_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, Callable, List, Tuple

import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import List, Any, Tuple, Callable




class ClassificationDataset(Dataset): class ClassificationDataset(Dataset):


+ 2
- 1
abl/dataset/regression_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, List, Tuple

import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from typing import List, Any, Tuple




class RegressionDataset(Dataset): class RegressionDataset(Dataset):


+ 1
- 1
abl/evaluation/__init__.py View File

@@ -1,3 +1,3 @@
from .base_metric import BaseMetric from .base_metric import BaseMetric
from .symbol_metric import SymbolMetric
from .semantics_metric import SemanticsMetric from .semantics_metric import SemanticsMetric
from .symbol_metric import SymbolMetric

+ 2
- 2
abl/evaluation/base_metric.py View File

@@ -1,8 +1,8 @@
import logging
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence from typing import Any, List, Optional, Sequence
from ..utils import print_log


import logging
from ..utils import print_log




class BaseMetric(metaclass=ABCMeta): class BaseMetric(metaclass=ABCMeta):


+ 1
- 0
abl/evaluation/semantics_metric.py View File

@@ -1,4 +1,5 @@
from typing import Optional, Sequence from typing import Optional, Sequence

from .base_metric import BaseMetric from .base_metric import BaseMetric






+ 2
- 1
abl/evaluation/symbol_metric.py View File

@@ -1,4 +1,5 @@
from typing import Optional, Sequence, Callable
from typing import Callable, Optional, Sequence

from .base_metric import BaseMetric from .base_metric import BaseMetric






+ 6
- 5
abl/learning/basic_nn.py View File

@@ -10,14 +10,15 @@
# #
# ================================================================# # ================================================================#


import torch
import os
from typing import Any, Callable, List, Optional, T, Tuple

import numpy import numpy
import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from ..utils.logger import print_log
from ..dataset import ClassificationDataset


import os
from typing import List, Any, T, Optional, Callable, Tuple
from ..dataset import ClassificationDataset
from ..utils.logger import print_log




class BasicNN: class BasicNN:


+ 1
- 1
abl/reasoning/__init__.py View File

@@ -1,2 +1,2 @@
from .kb import KBBase, ground_KB, prolog_KB
from .reasoner import ReasonerBase from .reasoner import ReasonerBase
from .kb import KBBase, prolog_KB

+ 7
- 15
abl/reasoning/kb.py View File

@@ -1,24 +1,16 @@
from abc import ABC, abstractmethod
import bisect import bisect
import numpy as np

from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import (
flatten,
reform_idx,
hamming_dist,
check_equal,
to_hashable,
hashable_to_list,
)

from functools import lru_cache
from itertools import combinations, product
from multiprocessing import Pool from multiprocessing import Pool


from functools import lru_cache
import numpy as np
import pyswip import pyswip


from ..utils.utils import (check_equal, flatten, hamming_dist,
hashable_to_list, reform_idx, to_hashable)



class KBBase(ABC): class KBBase(ABC):
def __init__(self, pseudo_label_list, max_err=0, use_cache=True): def __init__(self, pseudo_label_list, max_err=0, use_cache=True):


+ 4
- 8
abl/reasoning/reasoner.py View File

@@ -1,12 +1,8 @@
import numpy as np import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
hamming_dist,
calculate_revision_num,
)
from zoopt import Dimension, Objective, Opt, Parameter

from ..utils.utils import (calculate_revision_num, confidence_dist, flatten,
hamming_dist, reform_idx)




class ReasonerBase: class ReasonerBase:


+ 2
- 1
abl/utils/utils.py View File

@@ -1,6 +1,7 @@
import numpy as np
from itertools import chain from itertools import chain


import numpy as np



def flatten(nested_list): def flatten(nested_list):
""" """


+ 1
- 2
docs/conf.py View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-


import sys
import os import os
import re import re
import sys


if not 'READTHEDOCS' in os.environ: if not 'READTHEDOCS' in os.environ:
sys.path.insert(0, os.path.abspath('..')) sys.path.insert(0, os.path.abspath('..'))
@@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/'))
# from sphinx.locale import _ # from sphinx.locale import _
from sphinx_rtd_theme import __version__ from sphinx_rtd_theme import __version__



project = u'ABL' project = u'ABL'
slug = re.sub(r'\W+', '-', project.lower()) slug = re.sub(r'\W+', '-', project.lower())
author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao' author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao'


+ 6
- 6
examples/hed/hed_bridge.py View File

@@ -1,18 +1,18 @@
import os import os
from collections import defaultdict from collections import defaultdict

import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader


from abl.reasoning import ReasonerBase
from abl.learning import ABLModel, BasicNN
from abl.bridge import SimpleBridge from abl.bridge import SimpleBridge
from abl.evaluation import BaseMetric
from abl.dataset import BridgeDataset, RegressionDataset from abl.dataset import BridgeDataset, RegressionDataset
from abl.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.utils import print_log from abl.utils import print_log

from examples.hed.utils import gen_mappings, InfiniteSampler
from examples.models.nn import SymbolNetAutoencoder
from examples.hed.datasets.get_hed import get_pretrain_data from examples.hed.datasets.get_hed import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder




class HEDBridge(SimpleBridge): class HEDBridge(SimpleBridge):


+ 1
- 1
examples/hed/utils.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np
import torch.utils.data.sampler as sampler import torch.utils.data.sampler as sampler






+ 1
- 0
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -1,6 +1,7 @@
import torchvision import torchvision
from torchvision.transforms import transforms from torchvision.transforms import transforms



def get_data(file, img_dataset, get_pseudo_label): def get_data(file, img_dataset, get_pseudo_label):
X = [] X = []
if get_pseudo_label: if get_pseudo_label:


+ 1
- 1
examples/models/nn.py View File

@@ -11,8 +11,8 @@
# ================================================================# # ================================================================#




import torch
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn






+ 1
- 0
setup.py View File

@@ -1,4 +1,5 @@
import os import os

from setuptools import find_packages, setup from setuptools import find_packages, setup






Loading…
Cancel
Save