Browse Source

[ENH] sort import

ab_data
Gao Enhao 2 years ago
parent
commit
95fa385ce6
20 changed files with 46 additions and 50 deletions
  1. +1
    -1
      abl/__init__.py
  2. +2
    -2
      abl/bridge/base_bridge.py
  3. +2
    -1
      abl/dataset/bridge_dataset.py
  4. +2
    -1
      abl/dataset/classification_dataset.py
  5. +2
    -1
      abl/dataset/regression_dataset.py
  6. +1
    -1
      abl/evaluation/__init__.py
  7. +2
    -2
      abl/evaluation/base_metric.py
  8. +1
    -0
      abl/evaluation/semantics_metric.py
  9. +2
    -1
      abl/evaluation/symbol_metric.py
  10. +6
    -5
      abl/learning/basic_nn.py
  11. +1
    -1
      abl/reasoning/__init__.py
  12. +7
    -15
      abl/reasoning/kb.py
  13. +4
    -8
      abl/reasoning/reasoner.py
  14. +2
    -1
      abl/utils/utils.py
  15. +1
    -2
      docs/conf.py
  16. +6
    -6
      examples/hed/hed_bridge.py
  17. +1
    -1
      examples/hed/utils.py
  18. +1
    -0
      examples/mnist_add/datasets/get_mnist_add.py
  19. +1
    -1
      examples/models/nn.py
  20. +1
    -0
      setup.py

+ 1
- 1
abl/__init__.py View File

@@ -1,2 +1,2 @@
from .learning import abl_model, basic_nn
from .reasoning import reasoner, kb
from .reasoning import kb, reasoner

+ 2
- 2
abl/bridge/base_bridge.py View File

@@ -1,11 +1,10 @@
from abc import ABCMeta, abstractmethod
from typing import Any, List, Tuple, Optional, Union
from typing import Any, List, Optional, Tuple, Union

from ..learning import ABLModel
from ..reasoning import ReasonerBase
from ..structures import ListData


DataSet = Tuple[List[List[Any]], Optional[List[List[Any]]], List[List[Any]]]


@@ -37,6 +36,7 @@ class BaseBridge(metaclass=ABCMeta):
@abstractmethod
def abduce_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:
"""Placeholder for abduce pseudo labels."""
pass

@abstractmethod
def idx_to_pseudo_label(self, data_samples: ListData) -> List[List[Any]]:


+ 2
- 1
abl/dataset/bridge_dataset.py View File

@@ -1,5 +1,6 @@
from typing import Any, List, Tuple

from torch.utils.data import Dataset
from typing import List, Any, Tuple


class BridgeDataset(Dataset):


+ 2
- 1
abl/dataset/classification_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, Callable, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple, Callable


class ClassificationDataset(Dataset):


+ 2
- 1
abl/dataset/regression_dataset.py View File

@@ -1,6 +1,7 @@
from typing import Any, List, Tuple

import torch
from torch.utils.data import Dataset
from typing import List, Any, Tuple


class RegressionDataset(Dataset):


+ 1
- 1
abl/evaluation/__init__.py View File

@@ -1,3 +1,3 @@
from .base_metric import BaseMetric
from .symbol_metric import SymbolMetric
from .semantics_metric import SemanticsMetric
from .symbol_metric import SymbolMetric

+ 2
- 2
abl/evaluation/base_metric.py View File

@@ -1,8 +1,8 @@
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence
from ..utils import print_log

import logging
from ..utils import print_log


class BaseMetric(metaclass=ABCMeta):


+ 1
- 0
abl/evaluation/semantics_metric.py View File

@@ -1,4 +1,5 @@
from typing import Optional, Sequence

from .base_metric import BaseMetric




+ 2
- 1
abl/evaluation/symbol_metric.py View File

@@ -1,4 +1,5 @@
from typing import Optional, Sequence, Callable
from typing import Callable, Optional, Sequence

from .base_metric import BaseMetric




+ 6
- 5
abl/learning/basic_nn.py View File

@@ -10,14 +10,15 @@
#
# ================================================================#

import torch
import os
from typing import Any, Callable, List, Optional, T, Tuple

import numpy
import torch
from torch.utils.data import DataLoader
from ..utils.logger import print_log
from ..dataset import ClassificationDataset

import os
from typing import List, Any, T, Optional, Callable, Tuple
from ..dataset import ClassificationDataset
from ..utils.logger import print_log


class BasicNN:


+ 1
- 1
abl/reasoning/__init__.py View File

@@ -1,2 +1,2 @@
from .kb import KBBase, ground_KB, prolog_KB
from .reasoner import ReasonerBase
from .kb import KBBase, prolog_KB

+ 7
- 15
abl/reasoning/kb.py View File

@@ -1,24 +1,16 @@
from abc import ABC, abstractmethod
import bisect
import numpy as np

from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import product, combinations

from ..utils.utils import (
flatten,
reform_idx,
hamming_dist,
check_equal,
to_hashable,
hashable_to_list,
)

from functools import lru_cache
from itertools import combinations, product
from multiprocessing import Pool

from functools import lru_cache
import numpy as np
import pyswip

from ..utils.utils import (check_equal, flatten, hamming_dist,
hashable_to_list, reform_idx, to_hashable)


class KBBase(ABC):
def __init__(self, pseudo_label_list, max_err=0, use_cache=True):


+ 4
- 8
abl/reasoning/reasoner.py View File

@@ -1,12 +1,8 @@
import numpy as np
from zoopt import Dimension, Objective, Parameter, Opt
from ..utils.utils import (
confidence_dist,
flatten,
reform_idx,
hamming_dist,
calculate_revision_num,
)
from zoopt import Dimension, Objective, Opt, Parameter

from ..utils.utils import (calculate_revision_num, confidence_dist, flatten,
hamming_dist, reform_idx)


class ReasonerBase:


+ 2
- 1
abl/utils/utils.py View File

@@ -1,6 +1,7 @@
import numpy as np
from itertools import chain

import numpy as np


def flatten(nested_list):
"""


+ 1
- 2
docs/conf.py View File

@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-

import sys
import os
import re
import sys

if not 'READTHEDOCS' in os.environ:
sys.path.insert(0, os.path.abspath('..'))
@@ -11,7 +11,6 @@ sys.path.append(os.path.abspath('./ABL/'))
# from sphinx.locale import _
from sphinx_rtd_theme import __version__


project = u'ABL'
slug = re.sub(r'\W+', '-', project.lower())
author = u'Yu-Xuan Huang, Wen-Chao Hu, En-Hao Gao'


+ 6
- 6
examples/hed/hed_bridge.py View File

@@ -1,18 +1,18 @@
import os
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from abl.reasoning import ReasonerBase
from abl.learning import ABLModel, BasicNN
from abl.bridge import SimpleBridge
from abl.evaluation import BaseMetric
from abl.dataset import BridgeDataset, RegressionDataset
from abl.evaluation import BaseMetric
from abl.learning import ABLModel, BasicNN
from abl.reasoning import ReasonerBase
from abl.utils import print_log

from examples.hed.utils import gen_mappings, InfiniteSampler
from examples.models.nn import SymbolNetAutoencoder
from examples.hed.datasets.get_hed import get_pretrain_data
from examples.hed.utils import InfiniteSampler, gen_mappings
from examples.models.nn import SymbolNetAutoencoder


class HEDBridge(SimpleBridge):


+ 1
- 1
examples/hed/utils.py View File

@@ -1,6 +1,6 @@
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data.sampler as sampler




+ 1
- 0
examples/mnist_add/datasets/get_mnist_add.py View File

@@ -1,6 +1,7 @@
import torchvision
from torchvision.transforms import transforms


def get_data(file, img_dataset, get_pseudo_label):
X = []
if get_pseudo_label:


+ 1
- 1
examples/models/nn.py View File

@@ -11,8 +11,8 @@
# ================================================================#


import torch
import numpy as np
import torch
from torch import nn




+ 1
- 0
setup.py View File

@@ -1,4 +1,5 @@
import os

from setuptools import find_packages, setup




Loading…
Cancel
Save