Browse Source

Merge branch 'main' into feature/hetero

tags/v0.3.2
Peng Tan 2 years ago
parent
commit
fdff52c425
35 changed files with 404 additions and 166 deletions
  1. +12
    -0
      .pre-commit-config.yaml
  2. +0
    -1
      examples/dataset_text_workflow/get_data.py
  3. +1
    -0
      examples/dataset_text_workflow/requirements.txt
  4. +8
    -3
      learnware/__init__.py
  5. +1
    -1
      learnware/client/learnware_client.py
  6. +1
    -1
      learnware/market/__init__.py
  7. +12
    -1
      learnware/market/anchor/__init__.py
  8. +1
    -39
      learnware/market/anchor/searcher.py
  9. +41
    -0
      learnware/market/anchor/user_info.py
  10. +1
    -1
      learnware/market/base.py
  11. +14
    -2
      learnware/market/easy/__init__.py
  12. +1
    -1
      learnware/market/easy/database_ops.py
  13. +2
    -1
      learnware/market/easy/organizer.py
  14. +15
    -11
      learnware/market/module.py
  15. +3
    -4
      learnware/model/base.py
  16. +41
    -4
      learnware/reuse/__init__.py
  17. +0
    -1
      learnware/reuse/ensemble_pruning.py
  18. +25
    -0
      learnware/reuse/utils.py
  19. +11
    -1
      learnware/specification/__init__.py
  20. +3
    -1
      learnware/specification/regular/__init__.py
  21. +29
    -1
      learnware/specification/regular/image/__init__.py
  22. +23
    -0
      learnware/specification/regular/image/utils.py
  23. +11
    -1
      learnware/specification/regular/table/__init__.py
  24. +23
    -1
      learnware/specification/regular/text/__init__.py
  25. +15
    -0
      learnware/specification/regular/text/utils.py
  26. +0
    -0
      learnware/tests/__init__.py
  27. +0
    -0
      learnware/tests/data.py
  28. +0
    -0
      learnware/tests/module.py
  29. +0
    -58
      learnware/utils.py
  30. +16
    -0
      learnware/utils/__init__.py
  31. +14
    -0
      learnware/utils/file.py
  32. +13
    -0
      learnware/utils/import_utils.py
  33. +24
    -0
      learnware/utils/module.py
  34. +37
    -26
      setup.py
  35. +6
    -6
      tests/test_workflow/test_workflow.py

+ 12
- 0
.pre-commit-config.yaml View File

@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
args: ["-l 120"]

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"]

+ 0
- 1
examples/dataset_text_workflow/get_data.py View File

@@ -1,4 +1,3 @@
import torch
from torchtext.datasets import SST2 from torchtext.datasets import SST2






+ 1
- 0
examples/dataset_text_workflow/requirements.txt View File

@@ -0,0 +1 @@
torchtext>=0.14.1

+ 8
- 3
learnware/__init__.py View File

@@ -1,7 +1,10 @@
__version__ = "0.1.1.99"
__version__ = "0.1.2.99"


import os import os
from .logger import get_module_logger from .logger import get_module_logger
from .utils import is_torch_avaliable

logger = get_module_logger("Initialization")




def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs): def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
@@ -10,9 +13,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
C.reset() C.reset()
C.update(**kwargs) C.update(**kwargs)


logger = get_module_logger("Initialization")
logger.info(f"init learnware market with {kwargs}") logger.info(f"init learnware market with {kwargs}")

## make dirs ## make dirs
if make_dir: if make_dir:
os.makedirs(C.root_path, exist_ok=True) os.makedirs(C.root_path, exist_ok=True)
@@ -25,3 +26,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
## ignore tensorflow warning ## ignore tensorflow warning
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel # os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel
# logger.info(f"The tensorflow log level is setted to {tf_loglevel}") # logger.info(f"The tensorflow log level is setted to {tf_loglevel}")


if not is_torch_avaliable(verbose=False):
logger.warning("The functionality of learnware is limited due to 'torch' is not installed!")

+ 1
- 1
learnware/client/learnware_client.py View File

@@ -18,7 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker
from ..logger import get_module_logger from ..logger import get_module_logger
from ..specification import Specification from ..specification import Specification
from ..learnware import get_learnware_from_dirpath from ..learnware import get_learnware_from_dirpath
from ..test import get_semantic_specification
from ..tests import get_semantic_specification


CHUNK_SIZE = 1024 * 1024 CHUNK_SIZE = 1024 * 1024
logger = get_module_logger(module_name="LearnwareClient") logger = get_module_logger(module_name="LearnwareClient")


+ 1
- 1
learnware/market/__init__.py View File

@@ -1,4 +1,4 @@
from .anchor import AnchoredUserInfo, AnchoredOrganizer
from .anchor import AnchoredUserInfo, AnchoredSearcher, AnchoredOrganizer
from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher
from .evolve_anchor import EvolvedAnchoredOrganizer from .evolve_anchor import EvolvedAnchoredOrganizer
from .evolve import EvolvedOrganizer from .evolve import EvolvedOrganizer


+ 12
- 1
learnware/market/anchor/__init__.py View File

@@ -1,2 +1,13 @@
from .organizer import AnchoredOrganizer from .organizer import AnchoredOrganizer
from .searcher import AnchoredUserInfo
from .user_info import AnchoredUserInfo

from ...utils import is_torch_avaliable
from ...logger import get_module_logger

logger = get_module_logger("market_anchor")

if not is_torch_avaliable(verbose=False):
AnchoredSearcher = None
logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!")
else:
from .searcher import AnchoredSearcher

+ 1
- 39
learnware/market/anchor/searcher.py View File

@@ -1,5 +1,6 @@
from typing import List, Dict, Tuple, Any, Union from typing import List, Dict, Tuple, Any, Union


from .user_info import AnchoredUserInfo
from ..base import BaseUserInfo from ..base import BaseUserInfo
from ..easy.searcher import EasySearcher from ..easy.searcher import EasySearcher
from ...logger import get_module_logger from ...logger import get_module_logger
@@ -8,45 +9,6 @@ from ...learnware import Learnware
logger = get_module_logger("anchor_searcher") logger = get_module_logger("anchor_searcher")




class AnchoredUserInfo(BaseUserInfo):
"""
User Information for searching learnware (add the anchor design)

- UserInfo contains the anchor id list acquired from the market
- UserInfo can update stat_info based on anchors
"""

def __init__(
self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
):
super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids

def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
"""Add the anchor learnware ids acquired from the market

Parameters
----------
learnware_ids : Union[str, List[str]]
Anchor learnware ids
"""
if isinstance(learnware_ids, str):
learnware_ids = [learnware_ids]
self.anchor_learnware_ids += learnware_ids

def update_stat_info(self, name: str, item: Any):
"""Update stat_info based on anchor learnwares

Parameters
----------
name : str
Name of stat_info
item : Any
Statistical information calculated on anchor learnwares
"""
self.stat_info[name] = item


class AnchoredSearcher(EasySearcher): class AnchoredSearcher(EasySearcher):
def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]: def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
"""Search anchor Learnwares from anchor_learnware_list based on user_info """Search anchor Learnwares from anchor_learnware_list based on user_info


+ 41
- 0
learnware/market/anchor/user_info.py View File

@@ -0,0 +1,41 @@
from typing import List, Any, Union
from ..base import BaseUserInfo


class AnchoredUserInfo(BaseUserInfo):
    """
    User Information for searching learnware (add the anchor design)

    - UserInfo contains the anchor id list acquired from the market
    - UserInfo can update stat_info based on anchors
    """

    def __init__(
        self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
    ):
        """Initialize the user info together with its anchor learnware ids

        Parameters
        ----------
        id : str
            Forwarded unchanged to ``BaseUserInfo.__init__``
        semantic_spec : dict, optional
            Forwarded unchanged to ``BaseUserInfo.__init__``
        stat_info : dict, optional
            Forwarded unchanged to ``BaseUserInfo.__init__``
        anchor_learnware_ids : List[str], optional
            Initial anchor learnware ids; ``None`` means an empty list
        """
        super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
        # Copy the incoming list: add_anchor_learnware_ids extends it in place, and
        # that must never leak back into the list object owned by the caller.
        self.anchor_learnware_ids = [] if anchor_learnware_ids is None else list(anchor_learnware_ids)

    def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
        """Add the anchor learnware ids acquired from the market

        Parameters
        ----------
        learnware_ids : Union[str, List[str]]
            Anchor learnware ids
        """
        # A bare string is wrapped first; extending with it directly would add
        # its characters one by one.
        if isinstance(learnware_ids, str):
            learnware_ids = [learnware_ids]
        self.anchor_learnware_ids.extend(learnware_ids)

    def update_stat_info(self, name: str, item: Any):
        """Update stat_info based on anchor learnwares

        Parameters
        ----------
        name : str
            Name of stat_info
        item : Any
            Statistical information calculated on anchor learnwares
        """
        # NOTE(review): assumes BaseUserInfo leaves self.stat_info as a mutable
        # mapping even when stat_info=None was passed - confirm in the base class.
        self.stat_info[name] = item

+ 1
- 1
learnware/market/base.py View File

@@ -227,7 +227,7 @@ class LearnwareMarket:


def reload_learnware(self, learnware_id: str): def reload_learnware(self, learnware_id: str):
self.learnware_organizer.reload_learnware(learnware_id) self.learnware_organizer.reload_learnware(learnware_id)
def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]: def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs) return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs)




+ 14
- 2
learnware/market/easy/__init__.py View File

@@ -1,3 +1,15 @@
from .organizer import EasyOrganizer from .organizer import EasyOrganizer
from .searcher import EasySearcher
from .checker import EasySemanticChecker, EasyStatChecker

from ...utils import is_torch_avaliable
from ...logger import get_module_logger

logger = get_module_logger("market_easy")

# EasySearcher and the two checkers are only imported when torch is importable;
# otherwise the names are still exported, bound to None, and a warning is logged.
if not is_torch_avaliable(verbose=False):
    EasySearcher = None
    EasySemanticChecker = None
    EasyStatChecker = None
    # The class name was misspelled ("EasySeacher") in this message.
    logger.warning("EasySearcher and EasyChecker are skipped because 'torch' is not installed!")
else:
    from .searcher import EasySearcher
    from .checker import EasySemanticChecker, EasyStatChecker

+ 1
- 1
learnware/market/easy/database_ops.py View File

@@ -166,7 +166,7 @@ class DatabaseOperations(object):
return int(row[0]) return int(row[0])
pass pass
pass pass
def load_market(self): def load_market(self):
with self.engine.connect() as conn: with self.engine.connect() as conn:
cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;")) cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;"))


+ 2
- 1
learnware/market/easy/organizer.py View File

@@ -387,7 +387,8 @@ class EasyOrganizer(BaseOrganizer):
self.learnware_folder_list[learnware_id] = target_folder_dir self.learnware_folder_list[learnware_id] = target_folder_dir
semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id) semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id)
self.learnware_list[learnware_id] = get_learnware_from_dirpath( self.learnware_list[learnware_id] = get_learnware_from_dirpath(
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir)
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id) self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id)
pass pass




+ 15
- 11
learnware/market/module.py View File

@@ -2,25 +2,29 @@ from .base import LearnwareMarket
from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher


MARKET_CONFIG = {
"easy": {
"organizer": EasyOrganizer(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatChecker()],
},
"hetero": {

def get_market_config():
market_config = {
"easy": {
"organizer": EasyOrganizer(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatChecker()],
},
"hetero": {
"organizer": HeteroMapTableOrganizer(), "organizer": HeteroMapTableOrganizer(),
"searcher": HeteroSearcher(), "searcher": HeteroSearcher(),
"checker_list": [] "checker_list": []
}
} }
}
return market_config




def instantiate_learnware_market(market_id="default", name="easy", **kwargs): def instantiate_learnware_market(market_id="default", name="easy", **kwargs):
market_config = get_market_config()
return LearnwareMarket( return LearnwareMarket(
market_id=market_id, market_id=market_id,
organizer=MARKET_CONFIG[name]["organizer"],
searcher=MARKET_CONFIG[name]["searcher"],
checker_list=MARKET_CONFIG[name]["checker_list"],
organizer=market_config[name]["organizer"],
searcher=market_config[name]["searcher"],
checker_list=market_config[name]["checker_list"],
**kwargs **kwargs
) )

+ 3
- 4
learnware/model/base.py View File

@@ -1,5 +1,4 @@
import numpy as np import numpy as np
import torch
from typing import Union from typing import Union




@@ -19,7 +18,7 @@ class BaseModel:
self.input_shape = input_shape self.input_shape = input_shape
self.output_shape = output_shape self.output_shape = output_shape


def predict(self, X: Union[np.ndarray, torch.tensor]) -> Union[np.ndarray, torch.tensor]:
def predict(self, X: np.ndarray) -> np.ndarray:
"""The prediction method for model in learnware, which will be checked when learnware is submitted into the market. """The prediction method for model in learnware, which will be checked when learnware is submitted into the market.


Parameters Parameters
@@ -33,10 +32,10 @@ class BaseModel:
""" """
pass pass


def fit(self, X: Union[np.ndarray, torch.tensor], y: Union[np.ndarray, torch.tensor]):
def fit(self, X: np.ndarray, y: np.ndarray):
pass pass


def finetune(self, X: Union[np.ndarray, torch.tensor], y: np.ndarray):
def finetune(self, X: np.ndarray, y: np.ndarray):
"""The finetune method for continuing train the model searched by market """The finetune method for continuing train the model searched by market


Parameters Parameters


+ 41
- 4
learnware/reuse/__init__.py View File

@@ -1,5 +1,42 @@
from .ensemble_pruning import EnsemblePruningReuser
from .averaging import AveragingReuser
from .job_selector import JobSelectorReuser
from ..logger import get_module_logger
from ..utils import is_torch_avaliable
from .utils import is_geatpy_avaliable, is_lightgbm_avaliable

logger = get_module_logger("reuse")

if not is_geatpy_avaliable(verbose=False):
EnsemblePruningReuser = None
logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!")
else:
from .ensemble_pruning import EnsemblePruningReuser

if not is_torch_avaliable(verbose=False):
AveragingReuser = None
logger.warning("AveragingReuser is skipped due to 'torch' is not installed!")
else:
from .averaging import AveragingReuser

if not is_lightgbm_avaliable(verbose=False) or not is_torch_avaliable(verbose=False):
JobSelectorReuser = None
uninstall_packages = [
value
for flag, value in zip(
[
is_lightgbm_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["lightgbm", "torch"],
)
if flag is False
]
logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!")
else:
from .job_selector import JobSelectorReuser

# HeteroMapTableReuser is only imported when torch is importable; otherwise the
# name is still exported, bound to None, and a warning is logged.
if not is_torch_avaliable(verbose=False):
    HeteroMapTableReuser = None
    # The message used to name FeatureAugmentReuser, which is imported
    # unconditionally; the object skipped here is HeteroMapTableReuser.
    logger.warning("HeteroMapTableReuser is skipped due to 'torch' is not installed!")
else:
    from .hetero_reuser import HeteroMapTableReuser

from .feature_augment_reuser import FeatureAugmentReuser from .feature_augment_reuser import FeatureAugmentReuser
from .hetero_reuser import HeteroMapTableReuser

+ 0
- 1
learnware/reuse/ensemble_pruning.py View File

@@ -2,7 +2,6 @@ import torch
import random import random
import numpy as np import numpy as np
import geatpy as ea import geatpy as ea

from typing import List from typing import List


from learnware.learnware import Learnware from learnware.learnware import Learnware


+ 25
- 0
learnware/reuse/utils.py View File

@@ -0,0 +1,25 @@
from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")


def is_geatpy_avaliable(verbose=False):
    """Report whether the optional ``geatpy`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import geatpy`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import geatpy  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning(
                "ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!"
            )
        return False
    return True


def is_lightgbm_avaliable(verbose=False):
    """Report whether the optional ``lightgbm`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import lightgbm`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import lightgbm  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!")
        return False
    return True

+ 11
- 1
learnware/specification/__init__.py View File

@@ -1,4 +1,3 @@
from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec
from .base import Specification, BaseStatSpecification from .base import Specification, BaseStatSpecification
from .regular import ( from .regular import (
RegularStatsSpecification, RegularStatsSpecification,
@@ -7,4 +6,15 @@ from .regular import (
RKMEImageSpecification, RKMEImageSpecification,
RKMETextSpecification, RKMETextSpecification,
) )

from .system import HeteroSpecification from .system import HeteroSpecification

from ..utils import is_torch_avaliable

if not is_torch_avaliable(verbose=False):
generate_stat_spec = None
generate_rkme_spec = None
generate_rkme_image_spec = None
generate_rkme_text_spec = None
else:
from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec

+ 3
- 1
learnware/specification/regular/__init__.py View File

@@ -1,4 +1,6 @@
from .base import RegularStatsSpecification
from ...utils import is_torch_avaliable

from .text import RKMETextSpecification from .text import RKMETextSpecification
from .table import RKMETableSpecification, RKMEStatSpecification from .table import RKMETableSpecification, RKMEStatSpecification
from .image import RKMEImageSpecification from .image import RKMEImageSpecification
from .base import RegularStatsSpecification

+ 29
- 1
learnware/specification/regular/image/__init__.py View File

@@ -1 +1,29 @@
from .rkme import RKMEImageSpecification
from .utils import is_torch_optimizer_avaliable, is_torch_vision_avaliable
from ....utils import is_torch_avaliable
from ....logger import get_module_logger


logger = get_module_logger("regular_image_spec")

if (
not is_torch_vision_avaliable(verbose=False)
or not is_torch_optimizer_avaliable(verbose=False)
or not is_torch_avaliable(verbose=False)
):
RKMEImageSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_torch_vision_avaliable(verbose=False),
is_torch_optimizer_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["torchvision", "torch-optimizer", "torch"],
)
if flag is False
]

logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!")
else:
from .rkme import RKMEImageSpecification

+ 23
- 0
learnware/specification/regular/image/utils.py View File

@@ -0,0 +1,23 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_image_spec_utils")


def is_torch_optimizer_avaliable(verbose=False):
    """Report whether the optional ``torch_optimizer`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import torch_optimizer`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import torch_optimizer  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!")
        return False
    return True


def is_torch_vision_avaliable(verbose=False):
    """Report whether the optional ``torchvision`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import torchvision`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import torchvision  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!")
        return False
    return True

+ 11
- 1
learnware/specification/regular/table/__init__.py View File

@@ -1 +1,11 @@
from .rkme import RKMETableSpecification, RKMEStatSpecification
from ....utils import is_torch_avaliable
from ....logger import get_module_logger

logger = get_module_logger("regular_table_spec")

if not is_torch_avaliable(verbose=False):
RKMETableSpecification = None
RKMEStatSpecification = None
logger.warning("RKMETableSpecification is skipped because torch is not installed!")
else:
from .rkme import RKMETableSpecification, RKMEStatSpecification

+ 23
- 1
learnware/specification/regular/text/__init__.py View File

@@ -1 +1,23 @@
from .rkme import RKMETextSpecification
from .utils import is_sentence_transformers_avaliable

from ....utils import is_torch_avaliable
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec")

if not is_sentence_transformers_avaliable(verbose=False) or not is_torch_avaliable(verbose=False):
RKMETextSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_sentence_transformers_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["sentence_transformers", "torch"],
)
if flag is False
]
logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!")
else:
from .rkme import RKMETextSpecification

+ 15
- 0
learnware/specification/regular/text/utils.py View File

@@ -0,0 +1,15 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec_utils")


def is_sentence_transformers_avaliable(verbose=False):
    """Report whether the optional ``sentence_transformers`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import sentence_transformers`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import sentence_transformers  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning(
                "ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!"
            )
        return False
    return True

learnware/test/__init__.py → learnware/tests/__init__.py View File


learnware/test/data.py → learnware/tests/data.py View File


learnware/test/module.py → learnware/tests/module.py View File


+ 0
- 58
learnware/utils.py View File

@@ -1,58 +0,0 @@
import os
import sys

import re
import yaml
import importlib
import importlib.util
from typing import Union
from types import ModuleType
import zipfile
from .logger import get_module_logger

logger = get_module_logger("utils")


def get_module_by_module_path(module_path: Union[str, ModuleType]):
if module_path is None:
raise ModuleNotFoundError("None is passed in as parameters as module_path")

if isinstance(module_path, ModuleType):
module = module_path
else:
if module_path.endswith(".py"):
module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_")))
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module


def save_dict_to_yaml(dict_value: dict, save_path: str):
"""save dict object into yaml file"""
with open(save_path, "w") as file:
file.write(yaml.dump(dict_value, allow_unicode=True))


def read_yaml_to_dict(yaml_path: str):
"""load yaml file into dict object"""
with open(yaml_path, "r") as file:
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value


def zip_learnware_folder(path: str, output_name: str):
with zipfile.ZipFile(output_name, "w") as zip_ref:
for root, dirs, files in os.walk(path):
for file in files:
full_path = os.path.join(root, file)
if file.endswith(".pyc") or os.path.islink(full_path):
continue
zip_ref.write(full_path, arcname=os.path.relpath(full_path, path))
pass
pass
pass
pass

+ 16
- 0
learnware/utils/__init__.py View File

@@ -0,0 +1,16 @@
import os
import zipfile

from .import_utils import is_torch_avaliable
from .module import get_module_by_module_path
from .file import read_yaml_to_dict, save_dict_to_yaml


def zip_learnware_folder(path: str, output_name: str):
    """Pack the files found under ``path`` into a zip archive at ``output_name``.

    Entries are stored under their path relative to ``path``. Files ending in
    ``.pyc`` and symbolic links are left out, and so is the archive itself when
    ``output_name`` is located inside ``path``.

    Parameters
    ----------
    path : str
        Root folder whose files are archived
    output_name : str
        Path of the zip file to create (opened in "w" mode, so an existing
        file at that location is overwritten)
    """
    # ZipFile creates the output file as soon as it is opened, so when the
    # archive lives inside `path` os.walk would list it and the half-written
    # archive would be added to itself. Remember its location to skip it.
    archive_path = os.path.normcase(os.path.abspath(output_name))
    with zipfile.ZipFile(output_name, "w") as zip_ref:
        for root, _, files in os.walk(path):
            for file in files:
                full_path = os.path.join(root, file)
                if file.endswith(".pyc") or os.path.islink(full_path):
                    continue
                if os.path.normcase(os.path.abspath(full_path)) == archive_path:
                    continue
                zip_ref.write(full_path, arcname=os.path.relpath(full_path, path))

+ 14
- 0
learnware/utils/file.py View File

@@ -0,0 +1,14 @@
import yaml


def save_dict_to_yaml(dict_value: dict, save_path: str):
    """save dict object into yaml file

    The file is written as UTF-8: ``allow_unicode=True`` makes yaml emit
    non-ASCII characters verbatim, which a locale-dependent default encoding
    may be unable to encode.
    """
    with open(save_path, "w", encoding="utf-8") as file:
        file.write(yaml.dump(dict_value, allow_unicode=True))


def read_yaml_to_dict(yaml_path: str):
    """load yaml file into dict object

    The file is read as UTF-8, matching what ``save_dict_to_yaml`` writes.
    """
    with open(yaml_path, "r", encoding="utf-8") as file:
        # NOTE(review): yaml.FullLoader is not meant for untrusted input; if
        # yaml_path can come from an external source, consider yaml.safe_load.
        dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
    return dict_value

+ 13
- 0
learnware/utils/import_utils.py View File

@@ -0,0 +1,13 @@
from ..logger import get_module_logger

logger = get_module_logger("import_utils")


def is_torch_avaliable(verbose=False):
    """Report whether the optional ``torch`` package can be imported.

    Parameters
    ----------
    verbose : bool, optional
        When truthy, a warning is logged if the package is missing.

    Returns
    -------
    bool
        True if ``import torch`` succeeds, False if it raises
        ``ModuleNotFoundError``. Any other import-time error propagates.
    """
    try:
        import torch  # noqa: F401  (imported only to probe availability)
    except ModuleNotFoundError:
        if verbose:
            logger.warning("ModuleNotFoundError: torch is not installed, please install pytorch!")
        return False
    return True

+ 24
- 0
learnware/utils/module.py View File

@@ -0,0 +1,24 @@
import sys
import re
import importlib
import importlib.util
from typing import Union
from types import ModuleType


def get_module_by_module_path(module_path: Union[str, ModuleType]):
    """Resolve ``module_path`` to a module object.

    Parameters
    ----------
    module_path : Union[str, ModuleType]
        - a module object, which is returned as is;
        - a string ending in ".py", which is loaded from that file and
          registered in ``sys.modules`` under a name derived from the path;
        - any other string, which is imported with ``importlib.import_module``.

    Returns
    -------
    ModuleType
        The resolved module

    Raises
    ------
    ModuleNotFoundError
        If ``module_path`` is None (or the dotted name cannot be imported)
    """
    if module_path is None:
        raise ModuleNotFoundError("None is passed in as parameters as module_path")

    if isinstance(module_path, ModuleType):
        return module_path

    if not module_path.endswith(".py"):
        return importlib.import_module(module_path)

    # Build an identifier from the file path: "/" becomes "_", every other
    # non-identifier character is dropped, then leading digits are stripped.
    module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_")))
    module_spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(module_spec)

    # The module must be registered before execution so that code inside it can
    # find itself in sys.modules. If execution fails, undo the registration so
    # no half-initialised module is left behind for later lookups.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        module_spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module

+ 37
- 26
setup.py View File

@@ -51,31 +51,23 @@ def get_platform():
# What packages are required for this module to be executed? # What packages are required for this module to be executed?
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. # `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
REQUIRED = [ REQUIRED = [
# "numpy>=1.20.0",
# "pandas>=0.25.1",
# "scipy>=1.0.0",
# "matplotlib>=3.1.3",
# "torch>=1.11.0",
# "cvxopt>=1.3.0",
# "tqdm>=4.65.0",
# "scikit-learn>=0.22",
# "joblib>=1.2.0",
# "pyyaml>=6.0",
# "fire>=0.3.1",
# "lightgbm>=3.3.0",
# "psutil>=5.9.4",
# "torchvision>=0.15.1",
# "sqlalchemy>=2.0.21",
# "shortuuid>=1.0.11",
# "geatpy>=2.7.0",
# "docker>=6.1.3",
# "rapidfuzz>=3.4.0",
# "torchtext>=0.16.0",
# "sentence_transformers>=2.2.2",
# "torch-optimizer>=0.3.0",
# "langdetect>=1.0.9",
# "huggingface-hub<0.18",
# "portalocker>=2.0.0",
"numpy>=1.20.0",
"pandas>=0.25.1",
"scipy>=1.0.0",
"cvxopt>=1.3.0",
"tqdm>=4.65.0",
"scikit-learn>=0.22",
"joblib>=1.2.0",
"pyyaml>=6.0",
"fire>=0.3.1",
"psutil>=5.9.4",
"sqlalchemy>=2.0.21",
"shortuuid>=1.0.11",
"docker>=6.1.3",
"rapidfuzz>=3.4.0",
"langdetect>=1.0.9",
"huggingface-hub<0.18",
"portalocker>=2.0.0",
] ]


if get_platform() != MACOS: if get_platform() != MACOS:
@@ -99,6 +91,23 @@ if __name__ == "__main__":
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
python_requires=REQUIRES_PYTHON, python_requires=REQUIRES_PYTHON,
install_requires=REQUIRED, install_requires=REQUIRED,
extras_require={
"dev": [
# For documentations
"sphinx",
"sphinx_rtd_theme",
# CI dependencies
"pytest>=3",
"wheel",
"setuptools",
"pylint",
# For static analysis
"mypy<0.981",
"flake8",
"black==23.1.0",
"pre-commit",
],
},
classifiers=[ classifiers=[
"Intended Audience :: Science/Research", "Intended Audience :: Science/Research",
"Intended Audience :: Developers", "Intended Audience :: Developers",
@@ -108,8 +117,10 @@ if __name__ == "__main__":
"Operating System :: POSIX :: Linux", "Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows", "Operating System :: Microsoft :: Windows",
"Operating System :: MacOS", "Operating System :: MacOS",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
], ],
) )

+ 6
- 6
tests/test_workflow/test_workflow.py View File

@@ -30,7 +30,7 @@ user_semantic = {
} }




class TestMarket(unittest.TestCase):
class TestWorkflow(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls) -> None: def setUpClass(cls) -> None:
np.random.seed(2023) np.random.seed(2023)
@@ -226,11 +226,11 @@ class TestMarket(unittest.TestCase):


def suite(): def suite():
_suite = unittest.TestSuite() _suite = unittest.TestSuite()
_suite.addTest(TestMarket("test_prepare_learnware_randomly"))
_suite.addTest(TestMarket("test_upload_delete_learnware"))
_suite.addTest(TestMarket("test_search_semantics"))
_suite.addTest(TestMarket("test_stat_search"))
_suite.addTest(TestMarket("test_learnware_reuse"))
_suite.addTest(TestWorkflow("test_prepare_learnware_randomly"))
_suite.addTest(TestWorkflow("test_upload_delete_learnware"))
_suite.addTest(TestWorkflow("test_search_semantics"))
_suite.addTest(TestWorkflow("test_stat_search"))
_suite.addTest(TestWorkflow("test_learnware_reuse"))
return _suite return _suite






Loading…
Cancel
Save