Browse Source

Merge branch 'main' of https://github.com/Learnware-LAMDA/Learnware into search_result

tags/v0.3.2
bxdd 2 years ago
parent
commit
ca7feb85dd
18 changed files with 65 additions and 175 deletions
  1. +2
    -2
      learnware/__init__.py
  2. +1
    -1
      learnware/market/anchor/__init__.py
  3. +1
    -1
      learnware/market/easy/__init__.py
  4. +6
    -26
      learnware/reuse/__init__.py
  5. +18
    -1
      learnware/reuse/ensemble_pruning.py
  6. +6
    -2
      learnware/reuse/job_selector.py
  7. +0
    -23
      learnware/reuse/utils.py
  8. +0
    -2
      learnware/specification/regular/__init__.py
  9. +2
    -20
      learnware/specification/regular/image/__init__.py
  10. +11
    -4
      learnware/specification/regular/image/rkme.py
  11. +0
    -23
      learnware/specification/regular/image/utils.py
  12. +3
    -16
      learnware/specification/regular/table/__init__.py
  13. +6
    -1
      learnware/specification/regular/table/rkme.py
  14. +0
    -15
      learnware/specification/regular/table/utils.py
  15. +2
    -21
      learnware/specification/regular/text/__init__.py
  16. +6
    -1
      learnware/specification/regular/text/rkme.py
  17. +0
    -15
      learnware/specification/regular/text/utils.py
  18. +1
    -1
      learnware/specification/system/__init__.py

+ 2
- 2
learnware/__init__.py View File

@@ -1,4 +1,4 @@
__version__ = "0.2.0.3"
__version__ = "0.2.0.4"

import os
import json
@@ -55,7 +55,7 @@ def init(verbose=True, **kwargs):

if not is_torch_available(verbose=False):
logger.warning(
"The ability of learnware is limited due to 'torch' is not installed! Only the core framework is available now."
"The learnware package's capabilities are restricted because 'torch' is not installed. Only the core framework is available now."
)

# default init package


+ 1
- 1
learnware/market/anchor/__init__.py View File

@@ -8,6 +8,6 @@ logger = get_module_logger("market_anchor")

if not is_torch_available(verbose=False):
AnchoredSearcher = None
logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!")
logger.error("AnchoredSearcher is not available because 'torch' is not installed!")
else:
from .searcher import AnchoredSearcher

+ 1
- 1
learnware/market/easy/__init__.py View File

@@ -9,7 +9,7 @@ if not is_torch_available(verbose=False):
EasySearcher = None
EasySemanticChecker = None
EasyStatChecker = None
logger.warning("EasySeacher and EasyChecker are skipped because 'torch' is not installed!")
logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!")
else:
from .searcher import EasySearcher, EasyStatSearcher, EasyFuzzSemanticSearcher, EasyExactSemanticSearcher
from .checker import EasySemanticChecker, EasyStatChecker

+ 6
- 26
learnware/reuse/__init__.py View File

@@ -2,43 +2,23 @@ from .base import BaseReuser
from .align import AlignLearnware

from ..logger import get_module_logger
from ..utils import is_torch_available, get_platform, SystemType
from .utils import is_geatpy_available, is_lightgbm_available
from ..utils import is_torch_available

logger = get_module_logger("reuse")

if not is_geatpy_available(verbose=False):
EnsemblePruningReuser = None
logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!")
else:
from .ensemble_pruning import EnsemblePruningReuser

if not is_torch_available(verbose=False):
AveragingReuser = None
FeatureAugmentReuser = None
HeteroMapAlignLearnware = None
FeatureAlignLearnware = None
logger.warning(
"[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware] is skipped due to 'torch' is not installed!"
JobSelectorReuser = None
EnsemblePruningReuser = None
logger.error(
"[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware, JobSelectorReuser, EnsemblePruningReuser] are not available due to 'torch' is not installed!"
)
else:
from .averaging import AveragingReuser
from .feature_augment import FeatureAugmentReuser
from .hetero import HeteroMapAlignLearnware, FeatureAlignLearnware

if not is_lightgbm_available(verbose=False) or not is_torch_available(verbose=False):
JobSelectorReuser = None
uninstall_packages = [
value
for flag, value in zip(
[
is_lightgbm_available(verbose=False),
is_torch_available(verbose=False),
],
["lightgbm", "torch"],
)
if flag is False
]
logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!")
else:
from .job_selector import JobSelectorReuser
from .ensemble_pruning import EnsemblePruningReuser

+ 18
- 1
learnware/reuse/ensemble_pruning.py View File

@@ -1,7 +1,6 @@
import torch
import random
import numpy as np
import geatpy as ea
from typing import List

from ..learnware import Learnware
@@ -54,6 +53,13 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
model_num = v_predict.shape[1]

@ea.Problem.single
@@ -138,6 +144,12 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")

model_num = v_predict.shape[1]

def find_top_two_freq(row):
@@ -252,6 +264,11 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
model_num = v_predict.shape[1]
v_predict[v_predict == 0.0] = -1
v_true[v_true == 0.0] = -1


+ 6
- 2
learnware/reuse/job_selector.py View File

@@ -2,7 +2,6 @@ import torch
import numpy as np

from typing import List, Union
from lightgbm import LGBMClassifier, early_stopping
from sklearn.metrics import accuracy_score

from .base import BaseReuser
@@ -196,7 +195,7 @@ class JobSelectorReuser(BaseReuser):
val_x: np.ndarray,
val_y: np.ndarray,
num_class: int,
) -> LGBMClassifier:
):
"""Train a LGBMClassifier as job selector using the herding data as training instances.

Parameters
@@ -221,6 +220,11 @@ class JobSelectorReuser(BaseReuser):
LGBMClassifier
The job selector model.
"""
try:
from lightgbm import LGBMClassifier, early_stopping
except ModuleNotFoundError:
raise ModuleNotFoundError(f"JobSelectorReuser is not available because 'lightgbm' is not installed! Please install it manually.")
score_best = -1
learning_rate = [0.01]
max_depth = [66]


+ 0
- 23
learnware/reuse/utils.py View File

@@ -3,29 +3,6 @@ from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")


def is_geatpy_available(verbose=False):
    """Report whether the optional ``geatpy`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``geatpy`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import geatpy`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import geatpy  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning(
                "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!"
            )
        return False
    return True


def is_lightgbm_available(verbose=False):
    """Report whether the optional ``lightgbm`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``lightgbm`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import lightgbm`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import lightgbm  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!")
        return False
    return True


def fill_data_with_mean(X: np.ndarray) -> np.ndarray:
"""
Fill missing data (NaN, Inf) in the input array with the mean of the column.


+ 0
- 2
learnware/specification/regular/__init__.py View File

@@ -1,6 +1,4 @@
from .base import RegularStatSpecification
from ...utils import is_torch_available

from .text import RKMETextSpecification
from .table import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp
from .image import RKMEImageSpecification

+ 2
- 20
learnware/specification/regular/image/__init__.py View File

@@ -1,29 +1,11 @@
from .utils import is_torch_optimizer_available, is_torchvision_available
from ....utils import is_torch_available
from ....logger import get_module_logger


logger = get_module_logger("regular_image_spec")

if (
not is_torchvision_available(verbose=False)
or not is_torch_optimizer_available(verbose=False)
or not is_torch_available(verbose=False)
):
if not is_torch_available(verbose=False):
RKMEImageSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_torchvision_available(verbose=False),
is_torch_optimizer_available(verbose=False),
is_torch_available(verbose=False),
],
["torchvision", "torch-optimizer", "torch"],
)
if flag is False
]

logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!")
logger.error(f"RKMEImageSpecification is not available because 'torch' is not installed!")
else:
from .rkme import RKMEImageSpecification

+ 11
- 4
learnware/specification/regular/image/rkme.py View File

@@ -10,10 +10,8 @@ from typing import Any

import numpy as np
import torch
import torch_optimizer
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torchvision.transforms import Resize
from tqdm import tqdm

from . import cnn_gp
@@ -126,7 +124,11 @@ class RKMEImageSpecification(RegularStatSpecification):
raise ValueError(f"All values in image {i} are exceptional, e.g., NaN and Inf.")
img_mean = torch.nanmean(img)
X[i] = torch.where(is_nan, img_mean, img)

try:
from torchvision.transforms import Resize
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." )
if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X)

@@ -152,7 +154,12 @@ class RKMEImageSpecification(RegularStatSpecification):
with torch.no_grad():
x_features = self._generate_random_feature(X_train, random_models=random_models)
self._update_beta(x_features, nonnegative_beta, random_models=random_models)

try:
import torch_optimizer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.")
optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)

for _ in tqdm(range(steps)) if verbose else range(steps):


+ 0
- 23
learnware/specification/regular/image/utils.py View File

@@ -1,23 +0,0 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_image_spec_utils")


def is_torch_optimizer_available(verbose=False):
    """Report whether the optional ``torch_optimizer`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``torch_optimizer`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import torch_optimizer`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import torch_optimizer  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!")
        return False
    return True


def is_torchvision_available(verbose=False):
    """Report whether the optional ``torchvision`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``torchvision`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import torchvision`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import torchvision  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!")
        return False
    return True

+ 3
- 16
learnware/specification/regular/table/__init__.py View File

@@ -1,27 +1,14 @@
from .utils import is_fast_pytorch_kmeans_available

from ....utils import is_torch_available
from ....logger import get_module_logger

logger = get_module_logger("regular_table_spec")

if not is_torch_available(verbose=False) or not is_fast_pytorch_kmeans_available(verbose=False):
if not is_torch_available(verbose=False):
RKMETableSpecification = None
RKMEStatSpecification = None
rkme_solve_qp = None
uninstall_packages = [
value
for flag, value in zip(
[
is_torch_available(verbose=False),
is_fast_pytorch_kmeans_available(verbose=False),
],
["torch", "fast_pytorch_kmeans"],
)
if flag is False
]
logger.warning(
f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are skipped because {uninstall_packages} is not installed!"
logger.error(
f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!"
)
else:
from .rkme import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp

+ 6
- 1
learnware/specification/regular/table/rkme.py View File

@@ -9,7 +9,6 @@ import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem
from collections import Counter
from typing import Any, Union
from fast_pytorch_kmeans import KMeans

from ..base import RegularStatSpecification
from ....logger import get_module_logger
@@ -143,6 +142,12 @@ class RKMETableSpecification(RegularStatSpecification):
X = torch.from_numpy(X)
X = X.to(self._device)
try:
from fast_pytorch_kmeans import KMeans
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." )

kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans.fit(X)
self.z = kmeans.centroids.double()


+ 0
- 15
learnware/specification/regular/table/utils.py View File

@@ -1,15 +0,0 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_table_spec_utils")


def is_fast_pytorch_kmeans_available(verbose=False):
    """Report whether the optional ``fast_pytorch_kmeans`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``fast_pytorch_kmeans`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import fast_pytorch_kmeans`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import fast_pytorch_kmeans  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning(
                "ModuleNotFoundError: fast_pytorch_kmeans is not installed, please install fast_pytorch_kmeans!"
            )
        return False
    return True

+ 2
- 21
learnware/specification/regular/text/__init__.py View File

@@ -1,29 +1,10 @@
from .utils import is_sentence_transformers_available
from ..table.utils import is_fast_pytorch_kmeans_available

from ....utils import is_torch_available
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec")

if (
not is_sentence_transformers_available(verbose=False)
or not is_torch_available(verbose=False)
or not is_fast_pytorch_kmeans_available(verbose=False)
):
if not is_torch_available(verbose=False):
RKMETextSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_sentence_transformers_available(verbose=False),
is_torch_available(verbose=False),
is_fast_pytorch_kmeans_available(verbose=False),
],
["sentence_transformers", "torch", "fast_pytorch_kmeans"],
)
if flag is False
]
logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!")
logger.error(f"RKMETextSpecification is not available because 'torch' is not installed!")
else:
from .rkme import RKMETextSpecification

+ 6
- 1
learnware/specification/regular/text/rkme.py View File

@@ -1,7 +1,6 @@
import os
import langdetect
import numpy as np
from sentence_transformers import SentenceTransformer

from ..table import RKMETableSpecification
from ....logger import get_module_logger
@@ -87,6 +86,12 @@ class RKMETextSpecification(RKMETableSpecification):
return np.array(miniLM_learnware.predict(X))

logger.info("Load the necessary feature extractor for RKMETextSpecification.")
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.")
if os.path.exists(zip_path):
X = _get_from_client(zip_path, X)
else:


+ 0
- 15
learnware/specification/regular/text/utils.py View File

@@ -1,15 +0,0 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec_utils")


def is_sentence_transformers_available(verbose=False):
    """Report whether the optional ``sentence_transformers`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, log a warning if ``sentence_transformers`` is missing. Defaults to False.

    Returns
    -------
    bool
        True if ``import sentence_transformers`` succeeds, False if it raises ModuleNotFoundError.
    """
    try:
        import sentence_transformers  # noqa: F401 -- imported only to probe availability
    except ModuleNotFoundError:
        # Accept any truthy ``verbose`` (e.g. 1), not only the literal ``True``.
        if verbose:
            logger.warning(
                "ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!"
            )
        return False
    return True

+ 1
- 1
learnware/specification/system/__init__.py View File

@@ -6,6 +6,6 @@ logger = get_module_logger("system_spec")

if not is_torch_available(verbose=False):
HeteroMapTableSpecification = None
logger.warning("HeteroMapTableSpecification is skipped because torch is not installed!")
logger.error("HeteroMapTableSpecification is not available because 'torch' is not installed!")
else:
from .hetero_table import HeteroMapTableSpecification

Loading…
Cancel
Save