Browse Source

Merge branch 'main' into feature/hetero

tags/v0.3.2
Peng Tan 2 years ago
parent
commit
fdff52c425
35 changed files with 404 additions and 166 deletions
  1. +12
    -0
      .pre-commit-config.yaml
  2. +0
    -1
      examples/dataset_text_workflow/get_data.py
  3. +1
    -0
      examples/dataset_text_workflow/requirements.txt
  4. +8
    -3
      learnware/__init__.py
  5. +1
    -1
      learnware/client/learnware_client.py
  6. +1
    -1
      learnware/market/__init__.py
  7. +12
    -1
      learnware/market/anchor/__init__.py
  8. +1
    -39
      learnware/market/anchor/searcher.py
  9. +41
    -0
      learnware/market/anchor/user_info.py
  10. +1
    -1
      learnware/market/base.py
  11. +14
    -2
      learnware/market/easy/__init__.py
  12. +1
    -1
      learnware/market/easy/database_ops.py
  13. +2
    -1
      learnware/market/easy/organizer.py
  14. +15
    -11
      learnware/market/module.py
  15. +3
    -4
      learnware/model/base.py
  16. +41
    -4
      learnware/reuse/__init__.py
  17. +0
    -1
      learnware/reuse/ensemble_pruning.py
  18. +25
    -0
      learnware/reuse/utils.py
  19. +11
    -1
      learnware/specification/__init__.py
  20. +3
    -1
      learnware/specification/regular/__init__.py
  21. +29
    -1
      learnware/specification/regular/image/__init__.py
  22. +23
    -0
      learnware/specification/regular/image/utils.py
  23. +11
    -1
      learnware/specification/regular/table/__init__.py
  24. +23
    -1
      learnware/specification/regular/text/__init__.py
  25. +15
    -0
      learnware/specification/regular/text/utils.py
  26. +0
    -0
      learnware/tests/__init__.py
  27. +0
    -0
      learnware/tests/data.py
  28. +0
    -0
      learnware/tests/module.py
  29. +0
    -58
      learnware/utils.py
  30. +16
    -0
      learnware/utils/__init__.py
  31. +14
    -0
      learnware/utils/file.py
  32. +13
    -0
      learnware/utils/import_utils.py
  33. +24
    -0
      learnware/utils/module.py
  34. +37
    -26
      setup.py
  35. +6
    -6
      tests/test_workflow/test_workflow.py

+ 12
- 0
.pre-commit-config.yaml View File

@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
args: ["-l 120"]

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
args: ["--count", "--select=E9,F63,F7,F82", "--show-source", "--statistics"]

+ 0
- 1
examples/dataset_text_workflow/get_data.py View File

@@ -1,4 +1,3 @@
import torch
from torchtext.datasets import SST2




+ 1
- 0
examples/dataset_text_workflow/requirements.txt View File

@@ -0,0 +1 @@
torchtext>=0.14.1

+ 8
- 3
learnware/__init__.py View File

@@ -1,7 +1,10 @@
__version__ = "0.1.1.99"
__version__ = "0.1.2.99"

import os
from .logger import get_module_logger
from .utils import is_torch_avaliable

logger = get_module_logger("Initialization")


def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
@@ -10,9 +13,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
C.reset()
C.update(**kwargs)

logger = get_module_logger("Initialization")
logger.info(f"init learnware market with {kwargs}")

## make dirs
if make_dir:
os.makedirs(C.root_path, exist_ok=True)
@@ -25,3 +26,7 @@ def init(make_dir: bool = False, tf_loglevel: str = "2", **kwargs):
## ignore tensorflow warning
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = tf_loglevel
# logger.info(f"The tensorflow log level is setted to {tf_loglevel}")


if not is_torch_avaliable(verbose=False):
logger.warning("The functionality of learnware is limited due to 'torch' is not installed!")

+ 1
- 1
learnware/client/learnware_client.py View File

@@ -18,7 +18,7 @@ from ..market import BaseChecker, EasySemanticChecker, EasyStatChecker
from ..logger import get_module_logger
from ..specification import Specification
from ..learnware import get_learnware_from_dirpath
from ..test import get_semantic_specification
from ..tests import get_semantic_specification

CHUNK_SIZE = 1024 * 1024
logger = get_module_logger(module_name="LearnwareClient")


+ 1
- 1
learnware/market/__init__.py View File

@@ -1,4 +1,4 @@
from .anchor import AnchoredUserInfo, AnchoredOrganizer
from .anchor import AnchoredUserInfo, AnchoredSearcher, AnchoredOrganizer
from .base import BaseUserInfo, LearnwareMarket, BaseChecker, BaseOrganizer, BaseSearcher
from .evolve_anchor import EvolvedAnchoredOrganizer
from .evolve import EvolvedOrganizer


+ 12
- 1
learnware/market/anchor/__init__.py View File

@@ -1,2 +1,13 @@
from .organizer import AnchoredOrganizer
from .searcher import AnchoredUserInfo
from .user_info import AnchoredUserInfo

from ...utils import is_torch_avaliable
from ...logger import get_module_logger

logger = get_module_logger("market_anchor")

if not is_torch_avaliable(verbose=False):
AnchoredSearcher = None
logger.warning("AnchoredSearcher is skipped because 'torch' is not installed!")
else:
from .searcher import AnchoredSearcher

+ 1
- 39
learnware/market/anchor/searcher.py View File

@@ -1,5 +1,6 @@
from typing import List, Dict, Tuple, Any, Union

from .user_info import AnchoredUserInfo
from ..base import BaseUserInfo
from ..easy.searcher import EasySearcher
from ...logger import get_module_logger
@@ -8,45 +9,6 @@ from ...learnware import Learnware
logger = get_module_logger("anchor_searcher")


class AnchoredUserInfo(BaseUserInfo):
"""
User Information for searching learnware (add the anchor design)

- UserInfo contains the anchor id list acquired from the market
- UserInfo can update stat_info based on anchors
"""

def __init__(
self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
):
super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids

def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
"""Add the anchor learnware ids acquired from the market

Parameters
----------
learnware_ids : Union[str, List[str]]
Anchor learnware ids
"""
if isinstance(learnware_ids, str):
learnware_ids = [learnware_ids]
self.anchor_learnware_ids += learnware_ids

def update_stat_info(self, name: str, item: Any):
"""Update stat_info based on anchor learnwares

Parameters
----------
name : str
Name of stat_info
item : Any
Statistical information calculated on anchor learnwares
"""
self.stat_info[name] = item


class AnchoredSearcher(EasySearcher):
def search_anchor_learnware(self, user_info: AnchoredUserInfo) -> Tuple[Any, List[Learnware]]:
"""Search anchor Learnwares from anchor_learnware_list based on user_info


+ 41
- 0
learnware/market/anchor/user_info.py View File

@@ -0,0 +1,41 @@
from typing import List, Any, Union
from ..base import BaseUserInfo


class AnchoredUserInfo(BaseUserInfo):
"""
User Information for searching learnware (add the anchor design)

- UserInfo contains the anchor id list acquired from the market
- UserInfo can update stat_info based on anchors
"""

def __init__(
self, id: str, semantic_spec: dict = None, stat_info: dict = None, anchor_learnware_ids: List[str] = None
):
super(AnchoredUserInfo, self).__init__(id, semantic_spec, stat_info)
self.anchor_learnware_ids = [] if anchor_learnware_ids is None else anchor_learnware_ids

def add_anchor_learnware_ids(self, learnware_ids: Union[str, List[str]]):
"""Add the anchor learnware ids acquired from the market

Parameters
----------
learnware_ids : Union[str, List[str]]
Anchor learnware ids
"""
if isinstance(learnware_ids, str):
learnware_ids = [learnware_ids]
self.anchor_learnware_ids += learnware_ids

def update_stat_info(self, name: str, item: Any):
"""Update stat_info based on anchor learnwares

Parameters
----------
name : str
Name of stat_info
item : Any
Statistical information calculated on anchor learnwares
"""
self.stat_info[name] = item

+ 1
- 1
learnware/market/base.py View File

@@ -227,7 +227,7 @@ class LearnwareMarket:

def reload_learnware(self, learnware_id: str):
self.learnware_organizer.reload_learnware(learnware_id)
def get_learnware_zip_path_by_ids(self, ids: Union[str, List[str]], **kwargs) -> Union[Learnware, List[Learnware]]:
return self.learnware_organizer.get_learnware_zip_path_by_ids(ids, **kwargs)



+ 14
- 2
learnware/market/easy/__init__.py View File

@@ -1,3 +1,15 @@
from .organizer import EasyOrganizer
from .searcher import EasySearcher
from .checker import EasySemanticChecker, EasyStatChecker

from ...utils import is_torch_avaliable
from ...logger import get_module_logger

logger = get_module_logger("market_easy")

if not is_torch_avaliable(verbose=False):
EasySearcher = None
EasySemanticChecker = None
EasyStatChecker = None
logger.warning("EasySeacher and EasyChecker are skipped because 'torch' is not installed!")
else:
from .searcher import EasySearcher
from .checker import EasySemanticChecker, EasyStatChecker

+ 1
- 1
learnware/market/easy/database_ops.py View File

@@ -166,7 +166,7 @@ class DatabaseOperations(object):
return int(row[0])
pass
pass
def load_market(self):
with self.engine.connect() as conn:
cursor = conn.execute(text("SELECT id, semantic_spec, zip_path, folder_path, use_flag FROM tb_learnware;"))


+ 2
- 1
learnware/market/easy/organizer.py View File

@@ -387,7 +387,8 @@ class EasyOrganizer(BaseOrganizer):
self.learnware_folder_list[learnware_id] = target_folder_dir
semantic_spec = self.dbops.get_learnware_semantic_specification(learnware_id)
self.learnware_list[learnware_id] = get_learnware_from_dirpath(
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir)
id=learnware_id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
self.use_flags[learnware_id] = self.dbops.get_learnware_use_flag(learnware_id)
pass



+ 15
- 11
learnware/market/module.py View File

@@ -2,25 +2,29 @@ from .base import LearnwareMarket
from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker
from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher

MARKET_CONFIG = {
"easy": {
"organizer": EasyOrganizer(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatChecker()],
},
"hetero": {

def get_market_config():
market_config = {
"easy": {
"organizer": EasyOrganizer(),
"searcher": EasySearcher(),
"checker_list": [EasySemanticChecker(), EasyStatChecker()],
},
"hetero": {
"organizer": HeteroMapTableOrganizer(),
"searcher": HeteroSearcher(),
"checker_list": []
}
}
}
return market_config


def instantiate_learnware_market(market_id="default", name="easy", **kwargs):
market_config = get_market_config()
return LearnwareMarket(
market_id=market_id,
organizer=MARKET_CONFIG[name]["organizer"],
searcher=MARKET_CONFIG[name]["searcher"],
checker_list=MARKET_CONFIG[name]["checker_list"],
organizer=market_config[name]["organizer"],
searcher=market_config[name]["searcher"],
checker_list=market_config[name]["checker_list"],
**kwargs
)

+ 3
- 4
learnware/model/base.py View File

@@ -1,5 +1,4 @@
import numpy as np
import torch
from typing import Union


@@ -19,7 +18,7 @@ class BaseModel:
self.input_shape = input_shape
self.output_shape = output_shape

def predict(self, X: Union[np.ndarray, torch.tensor]) -> Union[np.ndarray, torch.tensor]:
def predict(self, X: np.ndarray) -> np.ndarray:
"""The prediction method for model in learnware, which will be checked when learnware is submitted into the market.

Parameters
@@ -33,10 +32,10 @@ class BaseModel:
"""
pass

def fit(self, X: Union[np.ndarray, torch.tensor], y: Union[np.ndarray, torch.tensor]):
def fit(self, X: np.ndarray, y: np.ndarray):
pass

def finetune(self, X: Union[np.ndarray, torch.tensor], y: np.ndarray):
def finetune(self, X: np.ndarray, y: np.ndarray):
"""The finetune method for continuing train the model searched by market

Parameters


+ 41
- 4
learnware/reuse/__init__.py View File

@@ -1,5 +1,42 @@
from .ensemble_pruning import EnsemblePruningReuser
from .averaging import AveragingReuser
from .job_selector import JobSelectorReuser
from ..logger import get_module_logger
from ..utils import is_torch_avaliable
from .utils import is_geatpy_avaliable, is_lightgbm_avaliable

logger = get_module_logger("reuse")

if not is_geatpy_avaliable(verbose=False):
EnsemblePruningReuser = None
logger.warning("EnsemblePruningReuser is skipped due to 'geatpy' is not installed!")
else:
from .ensemble_pruning import EnsemblePruningReuser

if not is_torch_avaliable(verbose=False):
AveragingReuser = None
logger.warning("AveragingReuser is skipped due to 'torch' is not installed!")
else:
from .averaging import AveragingReuser

if not is_lightgbm_avaliable(verbose=False) or not is_torch_avaliable(verbose=False):
JobSelectorReuser = None
uninstall_packages = [
value
for flag, value in zip(
[
is_lightgbm_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["lightgbm", "torch"],
)
if flag is False
]
logger.warning(f"JobSelectorReuser is skipped due to {uninstall_packages} is not installed!")
else:
from .job_selector import JobSelectorReuser

if not is_torch_avaliable(verbose=False):
HeteroMapTableReuser = None
logger.warning("FeatureAugmentReuser is skipped due to 'torch' is not installed!")
else:
from .hetero_reuser import HeteroMapTableReuser

from .feature_augment_reuser import FeatureAugmentReuser
from .hetero_reuser import HeteroMapTableReuser

+ 0
- 1
learnware/reuse/ensemble_pruning.py View File

@@ -2,7 +2,6 @@ import torch
import random
import numpy as np
import geatpy as ea

from typing import List

from learnware.learnware import Learnware


+ 25
- 0
learnware/reuse/utils.py View File

@@ -0,0 +1,25 @@
from ..logger import get_module_logger

logger = get_module_logger("reuse_utils")


def is_geatpy_avaliable(verbose=False):
try:
import geatpy
except ModuleNotFoundError as err:
if verbose is True:
logger.warning(
"ModuleNotFoundError: geatpy is not installed, please install geatpy (only support python version<3.11)!"
)
return False
return True


def is_lightgbm_avaliable(verbose=False):
try:
import lightgbm
except ModuleNotFoundError as err:
if verbose is True:
logger.warning("ModuleNotFoundError: lightgbm is not installed, please install lightgbm!")
return False
return True

+ 11
- 1
learnware/specification/__init__.py View File

@@ -1,4 +1,3 @@
from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec
from .base import Specification, BaseStatSpecification
from .regular import (
RegularStatsSpecification,
@@ -7,4 +6,15 @@ from .regular import (
RKMEImageSpecification,
RKMETextSpecification,
)

from .system import HeteroSpecification

from ..utils import is_torch_avaliable

if not is_torch_avaliable(verbose=False):
generate_stat_spec = None
generate_rkme_spec = None
generate_rkme_image_spec = None
generate_rkme_text_spec = None
else:
from .module import generate_stat_spec, generate_rkme_spec, generate_rkme_image_spec, generate_rkme_text_spec

+ 3
- 1
learnware/specification/regular/__init__.py View File

@@ -1,4 +1,6 @@
from .base import RegularStatsSpecification
from ...utils import is_torch_avaliable

from .text import RKMETextSpecification
from .table import RKMETableSpecification, RKMEStatSpecification
from .image import RKMEImageSpecification
from .base import RegularStatsSpecification

+ 29
- 1
learnware/specification/regular/image/__init__.py View File

@@ -1 +1,29 @@
from .rkme import RKMEImageSpecification
from .utils import is_torch_optimizer_avaliable, is_torch_vision_avaliable
from ....utils import is_torch_avaliable
from ....logger import get_module_logger


logger = get_module_logger("regular_image_spec")

if (
not is_torch_vision_avaliable(verbose=False)
or not is_torch_optimizer_avaliable(verbose=False)
or not is_torch_avaliable(verbose=False)
):
RKMEImageSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_torch_vision_avaliable(verbose=False),
is_torch_optimizer_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["torchvision", "torch-optimizer", "torch"],
)
if flag is False
]

logger.warning(f"RKMEImageSpecification is skipped because {uninstall_packages} is not installed!")
else:
from .rkme import RKMEImageSpecification

+ 23
- 0
learnware/specification/regular/image/utils.py View File

@@ -0,0 +1,23 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_image_spec_utils")


def is_torch_optimizer_avaliable(verbose=False):
try:
import torch_optimizer
except ModuleNotFoundError as err:
if verbose is True:
logger.warning("ModuleNotFoundError: torch_optimizer is not installed, please install torch_optimizer!")
return False
return True


def is_torch_vision_avaliable(verbose=False):
try:
import torchvision
except ModuleNotFoundError as err:
if verbose is True:
logger.warning("ModuleNotFoundError: torchvision is not installed, please install torchvision!")
return False
return True

+ 11
- 1
learnware/specification/regular/table/__init__.py View File

@@ -1 +1,11 @@
from .rkme import RKMETableSpecification, RKMEStatSpecification
from ....utils import is_torch_avaliable
from ....logger import get_module_logger

logger = get_module_logger("regular_table_spec")

if not is_torch_avaliable(verbose=False):
RKMETableSpecification = None
RKMEStatSpecification = None
logger.warning("RKMETableSpecification is skipped because torch is not installed!")
else:
from .rkme import RKMETableSpecification, RKMEStatSpecification

+ 23
- 1
learnware/specification/regular/text/__init__.py View File

@@ -1 +1,23 @@
from .rkme import RKMETextSpecification
from .utils import is_sentence_transformers_avaliable

from ....utils import is_torch_avaliable
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec")

if not is_sentence_transformers_avaliable(verbose=False) or not is_torch_avaliable(verbose=False):
RKMETextSpecification = None
uninstall_packages = [
value
for flag, value in zip(
[
is_sentence_transformers_avaliable(verbose=False),
is_torch_avaliable(verbose=False),
],
["sentence_transformers", "torch"],
)
if flag is False
]
logger.warning(f"RKMETextSpecification is skipped because {uninstall_packages} is not installed!")
else:
from .rkme import RKMETextSpecification

+ 15
- 0
learnware/specification/regular/text/utils.py View File

@@ -0,0 +1,15 @@
from ....logger import get_module_logger

logger = get_module_logger("regular_text_spec_utils")


def is_sentence_transformers_avaliable(verbose=False):
try:
import sentence_transformers
except ModuleNotFoundError as err:
if verbose is True:
logger.warning(
"ModuleNotFoundError: sentence_transformers is not installed, please install sentence_transformers!"
)
return False
return True

learnware/test/__init__.py → learnware/tests/__init__.py View File


learnware/test/data.py → learnware/tests/data.py View File


learnware/test/module.py → learnware/tests/module.py View File


+ 0
- 58
learnware/utils.py View File

@@ -1,58 +0,0 @@
import os
import sys

import re
import yaml
import importlib
import importlib.util
from typing import Union
from types import ModuleType
import zipfile
from .logger import get_module_logger

logger = get_module_logger("utils")


def get_module_by_module_path(module_path: Union[str, ModuleType]):
if module_path is None:
raise ModuleNotFoundError("None is passed in as parameters as module_path")

if isinstance(module_path, ModuleType):
module = module_path
else:
if module_path.endswith(".py"):
module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_")))
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module


def save_dict_to_yaml(dict_value: dict, save_path: str):
"""save dict object into yaml file"""
with open(save_path, "w") as file:
file.write(yaml.dump(dict_value, allow_unicode=True))


def read_yaml_to_dict(yaml_path: str):
"""load yaml file into dict object"""
with open(yaml_path, "r") as file:
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value


def zip_learnware_folder(path: str, output_name: str):
with zipfile.ZipFile(output_name, "w") as zip_ref:
for root, dirs, files in os.walk(path):
for file in files:
full_path = os.path.join(root, file)
if file.endswith(".pyc") or os.path.islink(full_path):
continue
zip_ref.write(full_path, arcname=os.path.relpath(full_path, path))
pass
pass
pass
pass

+ 16
- 0
learnware/utils/__init__.py View File

@@ -0,0 +1,16 @@
import os
import zipfile

from .import_utils import is_torch_avaliable
from .module import get_module_by_module_path
from .file import read_yaml_to_dict, save_dict_to_yaml


def zip_learnware_folder(path: str, output_name: str):
with zipfile.ZipFile(output_name, "w") as zip_ref:
for root, dirs, files in os.walk(path):
for file in files:
full_path = os.path.join(root, file)
if file.endswith(".pyc") or os.path.islink(full_path):
continue
zip_ref.write(full_path, arcname=os.path.relpath(full_path, path))

+ 14
- 0
learnware/utils/file.py View File

@@ -0,0 +1,14 @@
import yaml


def save_dict_to_yaml(dict_value: dict, save_path: str):
"""save dict object into yaml file"""
with open(save_path, "w") as file:
file.write(yaml.dump(dict_value, allow_unicode=True))


def read_yaml_to_dict(yaml_path: str):
"""load yaml file into dict object"""
with open(yaml_path, "r") as file:
dict_value = yaml.load(file.read(), Loader=yaml.FullLoader)
return dict_value

+ 13
- 0
learnware/utils/import_utils.py View File

@@ -0,0 +1,13 @@
from ..logger import get_module_logger

logger = get_module_logger("import_utils")


def is_torch_avaliable(verbose=False):
try:
import torch
except ModuleNotFoundError as err:
if verbose is True:
logger.warning("ModuleNotFoundError: torch is not installed, please install pytorch!")
return False
return True

+ 24
- 0
learnware/utils/module.py View File

@@ -0,0 +1,24 @@
import sys
import re
import importlib
import importlib.util
from typing import Union
from types import ModuleType


def get_module_by_module_path(module_path: Union[str, ModuleType]):
if module_path is None:
raise ModuleNotFoundError("None is passed in as parameters as module_path")

if isinstance(module_path, ModuleType):
module = module_path
else:
if module_path.endswith(".py"):
module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_")))
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module

+ 37
- 26
setup.py View File

@@ -51,31 +51,23 @@ def get_platform():
# What packages are required for this module to be executed?
# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here.
REQUIRED = [
# "numpy>=1.20.0",
# "pandas>=0.25.1",
# "scipy>=1.0.0",
# "matplotlib>=3.1.3",
# "torch>=1.11.0",
# "cvxopt>=1.3.0",
# "tqdm>=4.65.0",
# "scikit-learn>=0.22",
# "joblib>=1.2.0",
# "pyyaml>=6.0",
# "fire>=0.3.1",
# "lightgbm>=3.3.0",
# "psutil>=5.9.4",
# "torchvision>=0.15.1",
# "sqlalchemy>=2.0.21",
# "shortuuid>=1.0.11",
# "geatpy>=2.7.0",
# "docker>=6.1.3",
# "rapidfuzz>=3.4.0",
# "torchtext>=0.16.0",
# "sentence_transformers>=2.2.2",
# "torch-optimizer>=0.3.0",
# "langdetect>=1.0.9",
# "huggingface-hub<0.18",
# "portalocker>=2.0.0",
"numpy>=1.20.0",
"pandas>=0.25.1",
"scipy>=1.0.0",
"cvxopt>=1.3.0",
"tqdm>=4.65.0",
"scikit-learn>=0.22",
"joblib>=1.2.0",
"pyyaml>=6.0",
"fire>=0.3.1",
"psutil>=5.9.4",
"sqlalchemy>=2.0.21",
"shortuuid>=1.0.11",
"docker>=6.1.3",
"rapidfuzz>=3.4.0",
"langdetect>=1.0.9",
"huggingface-hub<0.18",
"portalocker>=2.0.0",
]

if get_platform() != MACOS:
@@ -99,6 +91,23 @@ if __name__ == "__main__":
long_description_content_type="text/markdown",
python_requires=REQUIRES_PYTHON,
install_requires=REQUIRED,
extras_require={
"dev": [
# For documentations
"sphinx",
"sphinx_rtd_theme",
# CI dependencies
"pytest>=3",
"wheel",
"setuptools",
"pylint",
# For static analysis
"mypy<0.981",
"flake8",
"black==23.1.0",
"pre-commit",
],
},
classifiers=[
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
@@ -108,8 +117,10 @@ if __name__ == "__main__":
"Operating System :: POSIX :: Linux",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
],
)

+ 6
- 6
tests/test_workflow/test_workflow.py View File

@@ -30,7 +30,7 @@ user_semantic = {
}


class TestMarket(unittest.TestCase):
class TestWorkflow(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
np.random.seed(2023)
@@ -226,11 +226,11 @@ class TestMarket(unittest.TestCase):

def suite():
_suite = unittest.TestSuite()
_suite.addTest(TestMarket("test_prepare_learnware_randomly"))
_suite.addTest(TestMarket("test_upload_delete_learnware"))
_suite.addTest(TestMarket("test_search_semantics"))
_suite.addTest(TestMarket("test_stat_search"))
_suite.addTest(TestMarket("test_learnware_reuse"))
_suite.addTest(TestWorkflow("test_prepare_learnware_randomly"))
_suite.addTest(TestWorkflow("test_upload_delete_learnware"))
_suite.addTest(TestWorkflow("test_search_semantics"))
_suite.addTest(TestWorkflow("test_stat_search"))
_suite.addTest(TestWorkflow("test_learnware_reuse"))
return _suite




Loading…
Cancel
Save