diff --git a/learnware/__init__.py b/learnware/__init__.py index 322be04..c3f122e 100644 --- a/learnware/__init__.py +++ b/learnware/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.0.3" +__version__ = "0.2.0.4" import os import json @@ -55,7 +55,7 @@ def init(verbose=True, **kwargs): if not is_torch_available(verbose=False): logger.warning( - "The ability of learnware is limited due to 'torch' is not installed! Only the core framework is available now." + "The learnware package's capabilities are restricted because 'torch' is not installed. Only the core framework is available now." ) # default init package diff --git a/learnware/market/anchor/__init__.py b/learnware/market/anchor/__init__.py index d1b8392..2e2763c 100644 --- a/learnware/market/anchor/__init__.py +++ b/learnware/market/anchor/__init__.py @@ -8,6 +8,6 @@ logger = get_module_logger("market_anchor") if not is_torch_available(verbose=False): AnchoredSearcher = None - logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!") + logger.error("AnchoredSearcher is not available because 'torch' is not installed!") else: from .searcher import AnchoredSearcher diff --git a/learnware/market/easy/__init__.py b/learnware/market/easy/__init__.py index 7495ee5..2605999 100644 --- a/learnware/market/easy/__init__.py +++ b/learnware/market/easy/__init__.py @@ -9,7 +9,7 @@ if not is_torch_available(verbose=False): EasySearcher = None EasySemanticChecker = None EasyStatChecker = None - logger.warning("EasySeacher and EasyChecker are skipped because 'torch' is not installed!") + logger.error("EasySeacher and EasyChecker are not available because 'torch' is not installed!") else: from .searcher import EasySearcher, EasyStatSearcher, EasyFuzzSemanticSearcher, EasyExactSemanticSearcher from .checker import EasySemanticChecker, EasyStatChecker diff --git a/learnware/reuse/__init__.py b/learnware/reuse/__init__.py index f379429..1e2d289 100644 --- a/learnware/reuse/__init__.py +++ b/learnware/reuse/__init__.py @@ -2,43 +2,23 @@ from .base import BaseReuser from .align import AlignLearnware from ..logger import get_module_logger -from ..utils import is_torch_available, get_platform, SystemType -from .utils import is_geatpy_available, is_lightgbm_available +from ..utils import is_torch_available logger = get_module_logger("reuse") -if not is_geatpy_available(verbose=False): - EnsemblePruningReuser = None - logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!") -else: - from .ensemble_pruning import EnsemblePruningReuser - if not is_torch_available(verbose=False): AveragingReuser = None FeatureAugmentReuser = None HeteroMapAlignLearnware = None FeatureAlignLearnware = None - logger.warning( - "[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware] is skipped due to 'torch' is not installed!" + JobSelectorReuser = None + EnsemblePruningReuser = None + logger.error( + "[AveragingReuser, FeatureAugmentReuser, HeteroMapAlignLearnware, FeatureAlignLearnware, JobSelectorReuser, EnsemblePruningReuser] are not available due to 'torch' is not installed!" ) else: from .averaging import AveragingReuser from .feature_augment import FeatureAugmentReuser from .hetero import HeteroMapAlignLearnware, FeatureAlignLearnware - -if not is_lightgbm_available(verbose=False) or not is_torch_available(verbose=False): - JobSelectorReuser = None - uninstall_packages = [ - value - for flag, value in zip( - [ - is_lightgbm_available(verbose=False), - is_torch_available(verbose=False), - ], - ["lightgbm", "torch"], - ) - if flag is False - ] - logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!") -else: from .job_selector import JobSelectorReuser + from .ensemble_pruning import EnsemblePruningReuser \ No newline at end of file diff --git a/learnware/reuse/ensemble_pruning.py b/learnware/reuse/ensemble_pruning.py index f937dbd..49c65b5 100644 --- a/learnware/reuse/ensemble_pruning.py +++ b/learnware/reuse/ensemble_pruning.py @@ -1,7 +1,6 @@ import torch import random import numpy as np -import geatpy as ea from typing import List from ..learnware import Learnware @@ -54,6 +53,13 @@ class EnsemblePruningReuser(BaseReuser): np.ndarray Binary one-dimensional vector, 1 indicates that the corresponding model is selected. """ + + + try: + import geatpy as ea + except ModuleNotFoundError: + raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") + model_num = v_predict.shape[1] @ea.Problem.single @@ -138,6 +144,12 @@ class EnsemblePruningReuser(BaseReuser): np.ndarray Binary one-dimensional vector, 1 indicates that the corresponding model is selected. """ + try: + import geatpy as ea + except ModuleNotFoundError: + raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") + + model_num = v_predict.shape[1] def find_top_two_freq(row): @@ -252,6 +264,11 @@ class EnsemblePruningReuser(BaseReuser): np.ndarray Binary one-dimensional vector, 1 indicates that the corresponding model is selected. """ + try: + import geatpy as ea + except ModuleNotFoundError: + raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).") + model_num = v_predict.shape[1] v_predict[v_predict == 0.0] = -1 v_true[v_true == 0.0] = -1 diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py index 91ad512..467e063 100644 --- a/learnware/reuse/job_selector.py +++ b/learnware/reuse/job_selector.py @@ -2,7 +2,6 @@ import torch import numpy as np from typing import List, Union -from lightgbm import LGBMClassifier, early_stopping from sklearn.metrics import accuracy_score from .base import BaseReuser @@ -196,7 +195,7 @@ class JobSelectorReuser(BaseReuser): val_x: np.ndarray, val_y: np.ndarray, num_class: int, - ) -> LGBMClassifier: + ): """Train a LGBMClassifier as job selector using the herding data as training instances. Parameters @@ -221,6 +220,11 @@ class JobSelectorReuser(BaseReuser): LGBMClassifier The job selector model. """ + try: + from lightgbm import LGBMClassifier, early_stopping + except ModuleNotFoundError: + raise ModuleNotFoundError(f"JobSelectorReuser is not available because 'lightgbm' is not installed! Please install it manually.") + score_best = -1 learning_rate = [0.01] max_depth = [66] diff --git a/learnware/reuse/utils.py b/learnware/reuse/utils.py index a85135f..17430dc 100644 --- a/learnware/reuse/utils.py +++ b/learnware/reuse/utils.py @@ -3,29 +3,6 @@ from ..logger import get_module_logger logger = get_module_logger("reuse_utils") - -def is_geatpy_available(verbose=False): - try: - import geatpy - except ModuleNotFoundError as err: - if verbose is True: - logger.warning( - "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!" - ) - return False - return True - - -def is_lightgbm_available(verbose=False): - try: - import lightgbm - except ModuleNotFoundError as err: - if verbose is True: - logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!") - return False - return True - - def fill_data_with_mean(X: np.ndarray) -> np.ndarray: """ Fill missing data (NaN, Inf) in the input array with the mean of the column. diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 0c8b05d..fc95950 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,6 +1,4 @@ from .base import RegularStatSpecification -from ...utils import is_torch_available - from .text import RKMETextSpecification from .table import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp from .image import RKMEImageSpecification diff --git a/learnware/specification/regular/image/__init__.py b/learnware/specification/regular/image/__init__.py index 224d875..d76b56b 100644 --- a/learnware/specification/regular/image/__init__.py +++ b/learnware/specification/regular/image/__init__.py @@ -1,29 +1,11 @@ -from .utils import is_torch_optimizer_available, is_torchvision_available from ....utils import is_torch_available from ....logger import get_module_logger logger = get_module_logger("regular_image_spec") -if ( - not is_torchvision_available(verbose=False) - or not is_torch_optimizer_available(verbose=False) - or not is_torch_available(verbose=False) -): +if not is_torch_available(verbose=False): RKMEImageSpecification = None - uninstall_packages = [ - value - for flag, value in zip( - [ - is_torchvision_available(verbose=False), - is_torch_optimizer_available(verbose=False), - is_torch_available(verbose=False), - ], - ["torchvision", "torch-optimizer", "torch"], - ) - if flag is False - ] - - logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!") + logger.error(f"RKMEImageSpecification is not available because 'torch' is not installed!") else: from .rkme import RKMEImageSpecification diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index e65f23d..1f9a4a1 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -10,10 +10,8 @@ from typing import Any import numpy as np import torch -import torch_optimizer from torch import nn from torch.utils.data import TensorDataset, DataLoader -from torchvision.transforms import Resize from tqdm import tqdm from . import cnn_gp @@ -126,7 +124,11 @@ class RKMEImageSpecification(RegularStatSpecification): raise ValueError(f"All values in image {i} are exceptional, e.g., NaN and Inf.") img_mean = torch.nanmean(img) X[i] = torch.where(is_nan, img_mean, img) - + try: + from torchvision.transforms import Resize + except ModuleNotFoundError: + raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torchvision' is not installed! Please install it manually." ) + if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X) @@ -152,7 +154,12 @@ class RKMEImageSpecification(RegularStatSpecification): with torch.no_grad(): x_features = self._generate_random_feature(X_train, random_models=random_models) self._update_beta(x_features, nonnegative_beta, random_models=random_models) - + + try: + import torch_optimizer + except ModuleNotFoundError: + raise ModuleNotFoundError(f"RKMEImageSpecification is not available because 'torch-optimizer' is not installed! Please install it manually.") + optimizer = torch_optimizer.AdaBelief([{"params": [self.z]}], lr=step_size, eps=1e-16) for _ in tqdm(range(steps)) if verbose else range(steps): diff --git a/learnware/specification/regular/image/utils.py b/learnware/specification/regular/image/utils.py deleted file mode 100644 index cd60d96..0000000 --- a/learnware/specification/regular/image/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from ....logger import get_module_logger - -logger = get_module_logger("regular_image_spec_utils") - - -def is_torch_optimizer_available(verbose=False): - try: - import torch_optimizer - except ModuleNotFoundError as err: - if verbose is True: - logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!") - return False - return True - - -def is_torchvision_available(verbose=False): - try: - import torchvision - except ModuleNotFoundError as err: - if verbose is True: - logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!") - return False - return True diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index a380ea3..d816907 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -1,27 +1,14 @@ -from .utils import is_fast_pytorch_kmeans_available - from ....utils import is_torch_available from ....logger import get_module_logger logger = get_module_logger("regular_table_spec") -if not is_torch_available(verbose=False) or not is_fast_pytorch_kmeans_available(verbose=False): +if not is_torch_available(verbose=False): RKMETableSpecification = None RKMEStatSpecification = None rkme_solve_qp = None - uninstall_packages = [ - value - for flag, value in zip( - [ - is_torch_available(verbose=False), - is_fast_pytorch_kmeans_available(verbose=False), - ], - ["torch", "fast_pytorch_kmeans"], - ) - if flag is False - ] - logger.warning( - f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are skipped because {uninstall_packages} is not installed!" + logger.error( + f"RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are not available because 'torch' is not installed!" ) else: from .rkme import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 68cfc1f..f335f4d 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -9,7 +9,6 @@ import numpy as np from qpsolvers import solve_qp, Problem, solve_problem from collections import Counter from typing import Any, Union -from fast_pytorch_kmeans import KMeans from ..base import RegularStatSpecification from ....logger import get_module_logger @@ -143,6 +142,12 @@ class RKMETableSpecification(RegularStatSpecification): X = torch.from_numpy(X) X = X.to(self._device) + + try: + from fast_pytorch_kmeans import KMeans + except ModuleNotFoundError: + raise ModuleNotFoundError(f"RKMETableSpecification is not available because 'fast_pytorch_kmeans' is not installed! Please install it manually." ) + kmeans = KMeans(n_clusters=K, mode='euclidean', max_iter=100, verbose=0) kmeans.fit(X) self.z = kmeans.centroids.double() diff --git a/learnware/specification/regular/table/utils.py b/learnware/specification/regular/table/utils.py deleted file mode 100644 index 3243b72..0000000 --- a/learnware/specification/regular/table/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -from ....logger import get_module_logger - -logger = get_module_logger("regular_table_spec_utils") - - -def is_fast_pytorch_kmeans_available(verbose=False): - try: - import fast_pytorch_kmeans - except ModuleNotFoundError as err: - if verbose is True: - logger.warning( - "ModuleNotFoundError: fast_pytorch_kmeans is not installed, please install fast_pytorch_kmeans!" - ) - return False - return True diff --git a/learnware/specification/regular/text/__init__.py b/learnware/specification/regular/text/__init__.py index eda4208..23f1a91 100644 --- a/learnware/specification/regular/text/__init__.py +++ b/learnware/specification/regular/text/__init__.py @@ -1,29 +1,10 @@ -from .utils import is_sentence_transformers_available -from ..table.utils import is_fast_pytorch_kmeans_available - from ....utils import is_torch_available from ....logger import get_module_logger logger = get_module_logger("regular_text_spec") -if ( - not is_sentence_transformers_available(verbose=False) - or not is_torch_available(verbose=False) - or not is_fast_pytorch_kmeans_available(verbose=False) -): +if not is_torch_available(verbose=False): RKMETextSpecification = None - uninstall_packages = [ - value - for flag, value in zip( - [ - is_sentence_transformers_available(verbose=False), - is_torch_available(verbose=False), - is_fast_pytorch_kmeans_available(verbose=False), - ], - ["sentence_transformers", "torch", "fast_pytorch_kmeans"], - ) - if flag is False - ] - logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!") + logger.error(f"RKMETextSpecification is not available because 'torch' is not installed!") else: from .rkme import RKMETextSpecification diff --git a/learnware/specification/regular/text/rkme.py b/learnware/specification/regular/text/rkme.py index aa7d72e..0396d24 100644 --- a/learnware/specification/regular/text/rkme.py +++ b/learnware/specification/regular/text/rkme.py @@ -1,7 +1,6 @@ import os import langdetect import numpy as np -from sentence_transformers import SentenceTransformer from ..table import RKMETableSpecification from ....logger import get_module_logger @@ -87,6 +86,12 @@ class RKMETextSpecification(RKMETableSpecification): return np.array(miniLM_learnware.predict(X)) logger.info("Load the necessary feature extractor for RKMETextSpecification.") + + try: + from sentence_transformers import SentenceTransformer + except ModuleNotFoundError: + raise ModuleNotFoundError(f"RKMETextSpecification is not available because 'sentence_transformers' is not installed! Please install it manually.") + if os.path.exists(zip_path): X = _get_from_client(zip_path, X) else: diff --git a/learnware/specification/regular/text/utils.py b/learnware/specification/regular/text/utils.py deleted file mode 100644 index d4568a5..0000000 --- a/learnware/specification/regular/text/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -from ....logger import get_module_logger - -logger = get_module_logger("regular_text_spec_utils") - - -def is_sentence_transformers_available(verbose=False): - try: - import sentence_transformers - except ModuleNotFoundError as err: - if verbose is True: - logger.warning( - "ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!" - ) - return False - return True diff --git a/learnware/specification/system/__init__.py b/learnware/specification/system/__init__.py index 82fbe3f..45be6cc 100644 --- a/learnware/specification/system/__init__.py +++ b/learnware/specification/system/__init__.py @@ -6,6 +6,6 @@ logger = get_module_logger("system_spec") if not is_torch_available(verbose=False): HeteroMapTableSpecification = None - logger.warning("HeteroMapTableSpecification is skipped because torch is not installed!") + logger.error("HeteroMapTableSpecification is not available because 'torch' is not installed!") else: from .hetero_table import HeteroMapTableSpecification