diff --git a/learnware/__init__.py b/learnware/__init__.py index f2414b5..4813623 100644 --- a/learnware/__init__.py +++ b/learnware/__init__.py @@ -52,7 +52,9 @@ def init(**kwargs): if not is_torch_available(verbose=False): - logger.warning("The functionality of learnware is limited due to 'torch' is not installed!") + logger.warning( + "The ability of learnware is limited due to 'torch' is not installed! Only the core framework is available now." + ) # default init package init() diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index 262e72c..443648a 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -6,6 +6,9 @@ logger = get_module_logger("regular_table_spec") if not is_torch_available(verbose=False): RKMETableSpecification = None RKMEStatSpecification = None - logger.warning("RKMETableSpecification is skipped because torch is not installed!") + rkme_solve_qp = None + logger.warning( + "RKMETableSpecification, RKMEStatSpecification and rkme_solve_qp are skipped because torch is not installed!" + ) else: from .rkme import RKMETableSpecification, RKMEStatSpecification, rkme_solve_qp diff --git a/learnware/specification/system/__init__.py b/learnware/specification/system/__init__.py index a09ecca..82fbe3f 100644 --- a/learnware/specification/system/__init__.py +++ b/learnware/specification/system/__init__.py @@ -1 +1,11 @@ -from .hetero_table import HeteroMapTableSpecification +from .base import SystemStatSpecification +from ...utils import is_torch_available +from ...logger import get_module_logger + +logger = get_module_logger("system_spec") + +if not is_torch_available(verbose=False): + HeteroMapTableSpecification = None + logger.warning("HeteroMapTableSpecification is skipped because torch is not installed!") +else: + from .hetero_table import HeteroMapTableSpecification diff --git a/setup.py b/setup.py index a3f25b5..2145a04 100644 --- a/setup.py +++ b/setup.py @@ -107,8 +107,9 @@ if __name__ == "__main__": "pre-commit", ], "full": [ - "torch>=2.0.0", - "torchvision>=0.15.1", + # The default full requirements for learnware package + "torch==2.1.0", + "torchvision==0.16.0", "torch-optimizer>=0.3.0", "lightgbm>=3.3.0", "sentence_transformers>=2.2.2",