Browse Source

[MNT] import extended package locally

tags/v0.3.2
bxdd 2 years ago
parent
commit
068f18fc5b
13 changed files with 46 additions and 96 deletions
  1. +2
    -26
      learnware/reuse/__init__.py
  2. +18
    -1
      learnware/reuse/ensemble_pruning.py
  3. +6
    -2
      learnware/reuse/job_selector.py
  4. +0
    -23
      learnware/reuse/utils.py
  5. +0
    -2
      learnware/specification/regular/__init__.py
  6. +0
    -1
      learnware/specification/regular/image/__init__.py
  7. +7
    -3
      learnware/specification/regular/image/rkme.py
  8. +0
    -2
      learnware/specification/regular/table/__init__.py
  9. +6
    -1
      learnware/specification/regular/table/rkme.py
  10. +0
    -15
      learnware/specification/regular/table/utils.py
  11. +1
    -4
      learnware/specification/regular/text/__init__.py
  12. +6
    -1
      learnware/specification/regular/text/rkme.py
  13. +0
    -15
      learnware/specification/regular/text/utils.py

+ 2
- 26
learnware/reuse/__init__.py View File

@@ -3,16 +3,9 @@ from .align import AlignLearnware

from ..logger import get_module_logger
from ..utils import is_torch_available, get_platform, SystemType
from .utils import is_geatpy_available, is_lightgbm_available

logger = get_module_logger("reuse")

if not is_geatpy_available(verbose=False):
EnsemblePruningReuser = None
logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!")
else:
from .ensemble_pruning import EnsemblePruningReuser

if not is_torch_available(verbose=False):
AveragingReuser = None
FeatureAugmentReuser = None
@@ -20,27 +13,10 @@ if not is_torch_available(verbose=False):
FeatureAlignLearnware = None
JobSelectorReuser = None
logger.error(
"[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware, JobSelectorReuser] is skipped due to 'torch' is not installed!"
"[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware, JobSelectorReuser] are skipped due to 'torch' is not installed!"
)
else:
from .averaging import AveragingReuser
from .feature_augment import FeatureAugmentReuser
from .hetero import HeteroMapAlignLearnware, FeatureAlignLearnware
from .job_selector import JobSelectorReuser

if not is_lightgbm_available(verbose=False) or not is_torch_available(verbose=False):
JobSelectorReuser = None
uninstall_packages = [
value
for flag, value in zip(
[
is_lightgbm_available(verbose=False),
is_torch_available(verbose=False),
],
["lightgbm", "torch"],
)
if flag is False
]
logger.error(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!")
else:
from .job_selector import JobSelectorReuser

+ 18
- 1
learnware/reuse/ensemble_pruning.py View File

@@ -1,7 +1,6 @@
import torch
import random
import numpy as np
import geatpy as ea
from typing import List

from ..learnware import Learnware
@@ -54,6 +53,13 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
model_num = v_predict.shape[1]

@ea.Problem.single
@@ -138,6 +144,12 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")

model_num = v_predict.shape[1]

def find_top_two_freq(row):
@@ -252,6 +264,11 @@ class EnsemblePruningReuser(BaseReuser):
np.ndarray
Binary one-dimensional vector, 1 indicates that the corresponding model is selected.
"""
try:
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
model_num = v_predict.shape[1]
v_predict[v_predict == 0.0] = -1
v_true[v_true == 0.0] = -1


+ 6
- 2
learnware/reuse/job_selector.py View File

@@ -2,7 +2,6 @@ import torch
import numpy as np

from typing import List, Union
from lightgbm import LGBMClassifier, early_stopping
from sklearn.metrics import accuracy_score

from .base import BaseReuser
@@ -196,7 +195,7 @@ class JobSelectorReuser(BaseReuser):
val_x: np.ndarray,
val_y: np.ndarray,
num_class: int,
) -> LGBMClassifier:
):
"""Train a LGBMClassifier as job selector using the herding data as training instances.

Parameters
@@ -221,6 +220,11 @@ class JobSelectorReuser(BaseReuser):
LGBMClassifier
The job selector model.
"""
try:
from lightgbm import LGBMClassifier, early_stopping
except ModuleNotFoundError:
raise ModuleNotFoundError(f"JobSelectorReuser is not available because 'lightgbm' is not installed! Please install it manually.")
score_best = -1
learning_rate = [0.01]
max_depth = [66]


+ 0
- 23
learnware/reuse/utils.py View File

@@ -3,29 +3,6 @@ from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")


def is_geatpy_available(verbose=False):
try:
import geatpy
except ModuleNotFoundError as err:
if verbose is True:
logger.warning(
"ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!"
)
return False
return True


def is_lightgbm_available(verbose=False):
try:
import lightgbm
except ModuleNotFoundError as err:
if verbose is True:
logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!")
return False
return True


def fill_data_with_mean(X: np.ndarray) -> np.ndarray:
"""
Fill missing data (NaN, Inf) in the input array with the mean of the column.


+ 0
- 2
learnware/specification/regular/__init__.py View File

@@ -1,6 +1,4 @@
from .base import RegularStatSpecification
from ...utils import is_torch_available

from .text import RKMETextSpecification
from .table import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp
from .image import RKMEImageSpecification

+ 0
- 1
learnware/specification/regular/image/__init__.py View File

@@ -1,4 +1,3 @@
from .utils import is_torch_optimizer_available, is_torchvision_available
from ....utils import is_torch_available
from ....logger import get_module_logger



+ 7
- 3
learnware/specification/regular/image/rkme.py View File

@@ -10,7 +10,6 @@ from typing import Any

import numpy as np
import torch
import torch_optimizer
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
@@ -128,7 +127,7 @@ class RKMEImageSpecification(RegularStatSpecification):
try:
from torchvision.transforms import Resize
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed!")
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." )
if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X)
@@ -155,7 +154,12 @@ class RKMEImageSpecification(RegularStatSpecification):
with torch.no_grad():
x_features = self._generate_random_feature(X_train, random_models=random_models)
self._update_beta(x_features, nonnegative_beta, random_models=random_models)

try:
import torch_optimizer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.")
optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16)

for _ in tqdm(range(steps)) if verbose else range(steps):


+ 0
- 2
learnware/specification/regular/table/__init__.py View File

@@ -1,5 +1,3 @@
from .utils import is_fast_pytorch_kmeans_available

from ....utils import is_torch_available
from ....logger import get_module_logger



+ 6
- 1
learnware/specification/regular/table/rkme.py View File

@@ -9,7 +9,6 @@ import numpy as np
from qpsolvers import solve_qp, Problem, solve_problem
from collections import Counter
from typing import Any, Union
from fast_pytorch_kmeans import KMeans

from ..base import RegularStatSpecification
from ....logger import get_module_logger
@@ -143,6 +142,12 @@ class RKMETableSpecification(RegularStatSpecification):
X = torch.from_numpy(X)
X = X.to(self._device)
try:
from fast_pytorch_kmeans import KMeans
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." )

kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0)
kmeans.fit(X)
self.z = kmeans.centroids.double()


+ 0
- 15
learnware/specification/regular/table/utils.py View File

@@ -1,15 +0,0 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_table_spec_utils")


def is_fast_pytorch_kmeans_available(verbose=False):
try:
import fast_pytorch_kmeans
except ModuleNotFoundError as err:
if verbose is True:
logger.warning(
"ModuleNotFoundError: fast_pytorch_kmeans is not installed, please install fast_pytorch_kmeans!"
)
return False
return True

+ 1
- 4
learnware/specification/regular/text/__init__.py View File

@@ -1,6 +1,3 @@
from .utils import is_sentence_transformers_available
from ..table.utils import is_fast_pytorch_kmeans_available

from ....utils import is_torch_available
from ....logger import get_module_logger

@@ -8,6 +5,6 @@ logger = get_module_logger("regular_text_spec")

if not is_torch_available(verbose=False):
RKMETextSpecification = None
logger.warning(f"RKMETextSpecification is skipped because 'torch' is not installed!")
logger.error(f"RKMETextSpecification is skipped because 'torch' is not installed!")
else:
from .rkme import RKMETextSpecification

+ 6
- 1
learnware/specification/regular/text/rkme.py View File

@@ -1,7 +1,6 @@
import os
import langdetect
import numpy as np
from sentence_transformers import SentenceTransformer

from ..table import RKMETableSpecification
from ....logger import get_module_logger
@@ -87,6 +86,12 @@ class RKMETextSpecification(RKMETableSpecification):
return np.array(miniLM_learnware.predict(X))

logger.info("Load the necessary feature extractor for RKMETextSpecification.")
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.")
if os.path.exists(zip_path):
X = _get_from_client(zip_path, X)
else:


+ 0
- 15
learnware/specification/regular/text/utils.py View File

@@ -1,15 +0,0 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec_utils")


def is_sentence_transformers_available(verbose=False):
try:
import sentence_transformers
except ModuleNotFoundError as err:
if verbose is True:
logger.warning(
"ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!"
)
return False
return True

Loading…
Cancel
Save