| @@ -1,5 +1,5 @@ | |||
| ============================== | |||
| Specification evolvement | |||
| Specification Evolvement | |||
| ============================== | |||
| The specification is the core of the learnware paradigm. | |||
| @@ -3,65 +3,52 @@ | |||
| Market | |||
| ================================ | |||
| The ``learnware market`` receives high-performance machine learning models from developers, incorporates them into the system, and provides services to users by identifying and reusing learnware to help users solve current tasks. Developers voluntarily submit various learnwares to the learnware market, and the market conducts quality checks and further organization of these learnwares. When users submit task requirements, the learnware market automatically selects whether to recommend a single learnware or a combination of multiple learnwares. | |||
| Concepts | |||
| ====================================== | |||
| In the learnware paradigm, there are three key players: *developers*, *users*, and the *market*. | |||
| * Developers: Typically machine learning experts who create and aim to share or sell their high-performing trained machine learning models. | |||
| * Users: Require machine learning services but often possess limited data and lack the necessary knowledge and expertise in machine learning. | |||
| * Market: Acquires top-performing trained models from developers, houses them within the marketplace, and offers services to users by identifying and reusing learnwares to help users with their current tasks. | |||
| This process can be broken down into two main stages. | |||
| The ``learnware market`` will receive various kinds of learnwares, and learnwares from different feature/label spaces form numerous islands of specifications. All these islands together constitute the ``specification world`` in the learnware market. The market should discover and establish connections between different islands, and then merge them into a unified specification world. This further organization supports searching across all learnwares, not just among those that share the same feature space and label space as the user's task requirements. | |||
| Submitting Stage | |||
| ------------------------------ | |||
| During the *submitting stage*, developers can voluntarily submit their trained models to the learnware market. The market will then implement a quality assurance mechanism, such as performance validation, to determine if a submitted model is suitable for acceptance. In a learnware market with millions of models, identifying potentially helpful models for a new user is a challenge. | |||
| Requiring users to submit their own data to the market for model testing is impractical: it is time-consuming and costly, and it could lead to data leakage. Straightforward approaches, such as measuring the similarity between user data and the original training data of models, are also infeasible due to privacy and proprietary concerns. Our design operates under the constraint that the learnware market has no access to the original training data from developers or users. Furthermore, it assumes that users have limited knowledge of the models available in the market. | |||
| Framework | |||
| ====================================== | |||
| The solution's crux lies in the *specification*, which is central to the learnware proposal. Once a submitted model is accepted by the learnware market, it is assigned a specification, which conveys the model's specialty and utility without revealing its original training data. For simplicity, consider models as functions that map input domain :math:`\mathcal{X}` to output domain :math:`\mathcal{Y}` with respect to objective 'obj.' These models exist in a functional space :math:`\mathcal{F}: \mathcal{X} \mapsto \mathcal{Y}` with respect to 'obj.' Each model has a specification, and all specifications form a specification space, where those for models that serve similar tasks are situated closely. | |||
| The ``learnware market`` is composed of an ``organizer``, a ``searcher``, and a list of ``checker`` instances. | |||
| In a learnware market, heterogeneous models may have different :math:`\mathcal{X}`, :math:`\mathcal{Y}`, or objectives. If we refer to the specification space that covers all possible models in all possible functional spaces as the 'specification world' analogously, then each specification space corresponding to one possible functional space can be called a 'specification island.' Designing an elegant specification format that encompasses the entire specification world and allows all possible models to be efficiently and adequately identified is a significant challenge. Currently, we adopt a practical design, where each learnware's specification consists of two parts. | |||
| The ``organizer`` can store and organize learnwares in the market. It supports ``add``, ``delete``, and ``update`` operations for learnwares. It also provides the interface for ``searcher`` to search learnwares based on user requirement. | |||
| The ``searcher`` can search learnwares based on user requirement. The implementation of a ``searcher`` depends on the concrete implementation and interface of the ``organizer``, and an ``organizer`` is usually compatible with multiple different ``searcher`` implementations. | |||
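The division of labour between ``organizer`` and ``searcher`` can be illustrated with a small in-memory sketch. ``ToyOrganizer`` and ``ToySearcher`` are hypothetical names invented for this example, and the exact-match rule is an assumption; the package's own classes are not reproduced here.

```python
from typing import Dict, List


class ToyOrganizer:
    """In-memory stand-in for an organizer: add, delete, update, and list."""

    def __init__(self) -> None:
        self._store: Dict[str, dict] = {}

    def add(self, learnware_id: str, semantic_spec: dict) -> None:
        if learnware_id in self._store:
            raise KeyError(f"{learnware_id} already exists")
        self._store[learnware_id] = dict(semantic_spec)

    def delete(self, learnware_id: str) -> None:
        del self._store[learnware_id]

    def update(self, learnware_id: str, semantic_spec: dict) -> None:
        if learnware_id not in self._store:
            raise KeyError(f"{learnware_id} does not exist")
        self._store[learnware_id] = dict(semantic_spec)

    def items(self) -> List[tuple]:
        # The only interface the searcher below relies on.
        return sorted(self._store.items())


class ToySearcher:
    """Searcher that depends only on the organizer's `items()` interface."""

    def __init__(self, organizer: ToyOrganizer) -> None:
        self.organizer = organizer

    def search(self, requirement: dict) -> List[str]:
        # Keep learnwares whose spec agrees with every key in the requirement.
        return [
            lw_id
            for lw_id, spec in self.organizer.items()
            if all(spec.get(k) == v for k, v in requirement.items())
        ]


organizer = ToyOrganizer()
organizer.add("a", {"Data": "Table", "Task": "Classification"})
organizer.add("b", {"Data": "Image", "Task": "Classification"})
organizer.update("b", {"Data": "Table", "Task": "Regression"})
found = ToySearcher(organizer).search({"Data": "Table"})
```

Because the searcher only calls ``items()``, a different searcher (for example, one that ranks instead of filtering) could be attached to the same organizer without changing it.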
| Reusing Stage | |||
| ------------------------------ | |||
| The ``checker`` checks a learnware against a set of standards. It should check the utility of the learnware and return a status together with a message describing the check result. Only learnwares that pass the ``checker`` can be stored and added to the ``learnware market``. | |||
| Creating Learnware Specifications | |||
| ++++++++++++++++++++++++++++++++++++ | |||
| The first part of the learnware specification can be realized by a string consisting of a set of descriptions/tags given by the learnware market. These tags address aspects such as the task, input, output, and objective. Based on the user's provided descriptions/tags, the corresponding specification island can be efficiently and accurately located. The designer of the learnware market can create an initial set of descriptions/tags, which can grow as new models are accepted, and new functional spaces and specification islands are created. | |||
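As a rough illustration of locating an island from descriptions/tags, the sketch below groups learnware ids under a normalized tag tuple. The tag keys, the ``island_key`` helper, and the ``IslandIndex`` class are assumptions made for this example, not the market's actual data structures.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical tag keys; a real market defines its own set of descriptions/tags.
TAG_KEYS = ("task", "input", "output", "objective")


def island_key(tags: Dict[str, str]) -> Tuple[str, ...]:
    """Normalize a tag dictionary into a hashable key identifying one island."""
    return tuple(tags.get(k, "").strip().lower() for k in TAG_KEYS)


class IslandIndex:
    """Toy index that groups learnware ids by their island key."""

    def __init__(self) -> None:
        self._islands: Dict[Tuple[str, ...], List[str]] = defaultdict(list)

    def add(self, learnware_id: str, tags: Dict[str, str]) -> None:
        self._islands[island_key(tags)].append(learnware_id)

    def locate(self, user_tags: Dict[str, str]) -> List[str]:
        """Return the ids on the island matching the user's tags (empty if none)."""
        return list(self._islands.get(island_key(user_tags), []))


index = IslandIndex()
index.add("lw-001", {"task": "Classification", "input": "Table", "output": "Label", "objective": "accuracy"})
index.add("lw-002", {"task": "Regression", "input": "Table", "output": "Real", "objective": "mse"})
index.add("lw-003", {"task": "classification", "input": "table", "output": "label", "objective": "Accuracy"})

matches = index.locate({"task": "classification", "input": "table", "output": "label", "objective": "accuracy"})
```

A new island appears automatically the first time a learnware with an unseen tag combination is added, which mirrors the idea that the tag set can grow as new models are accepted.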
| Merging Specification Islands | |||
| +++++++++++++++++++++++++++++++++ | |||
| Current Checkers | |||
| ====================================== | |||
| Specification islands can merge into larger ones. For example, when a new model about :math:`F: \mathcal{X}_1 \cup \mathcal{X}_2 \mapsto \mathcal{Y}` with respect to 'obj' is accepted by the learnware market, two islands can be merged. This is possible because the market can have synthetic data by randomly generating inputs, feeding them to models, and concatenating each input with its corresponding output to construct a dataset reflecting the function of a model. In principle, specification islands can be merged if there are common ingredients in :math:`\mathcal{X}`, :math:`\mathcal{Y}`, and 'obj.' | |||
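The synthetic-data step described above (generate random inputs, feed them to a model, concatenate each input with its output) might look like the following sketch. The function name ``synthesize_dataset``, the standard-normal sampling, and the ``linear_model`` stand-in are all assumptions chosen for illustration.

```python
import numpy as np


def synthesize_dataset(model, input_dim: int, n_samples: int, seed: int = 0) -> np.ndarray:
    """Generate random inputs, query the model, and concatenate inputs with outputs."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_samples, input_dim))  # randomly generated inputs
    y = np.asarray(model(X)).reshape(n_samples, -1)  # model outputs, one row per input
    return np.concatenate([X, y], axis=1)  # each row is (input, output)


def linear_model(X: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a submitted model: sums the features.
    return X.sum(axis=1)


dataset = synthesize_dataset(linear_model, input_dim=3, n_samples=5)
```

The resulting array reflects the function the model computes without touching the developer's original training data.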
| The ``learnware`` package provides two different implementations of ``market``, both of which share the same ``checker`` list, so we first introduce the details of the ``checker`` classes. | |||
| Deploying Learnware Models | |||
| ++++++++++++++++++++++++++++++ | |||
| The ``checker`` classes check a learnware object from different aspects, including environment configuration (``CondaChecker``), semantic specifications (``EasySemanticChecker``), and statistical specifications (``EasyStatChecker``). The ``__call__`` method of each checker is designed to be invoked as a function to conduct the respective checks on the learnware and return the outcomes. A check yields one of three statuses: ``INVALID_LEARNWARE`` denotes that the learnware does not pass the check, ``NONUSABLE_LEARNWARE`` denotes that the learnware passes the check but cannot make predictions, and ``USABLE_LEARWARE`` denotes that the learnware passes the check and can make predictions. Currently, we have three checkers, which are described below. | |||
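The callable-checker pattern and the three statuses can be sketched as follows. The status names follow the text, but their integer values, the ``ToyShapeChecker`` class, and the ``run_checkers`` helper are assumptions made for this example rather than the package's implementation.

```python
from typing import Callable, Tuple

# Status names follow the text above; the integer values are assumed for this sketch.
INVALID_LEARNWARE = -1
NONUSABLE_LEARNWARE = 0
USABLE_LEARWARE = 1


class ToyShapeChecker:
    """Hypothetical checker: verifies that a model maps inputs to the declared output size."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim

    def __call__(self, predict: Callable) -> Tuple[int, str]:
        try:
            outputs = predict([[0.0] * self.input_dim])
        except Exception as err:  # the model cannot make predictions
            return NONUSABLE_LEARNWARE, f"prediction failed: {err!r}"
        if len(outputs) != 1 or len(outputs[0]) != self.output_dim:
            return INVALID_LEARNWARE, "output shape does not match the declaration"
        return USABLE_LEARWARE, "ok"


def run_checkers(checkers, predict) -> Tuple[int, str]:
    """Run checkers in order and stop at the first one that is not usable."""
    status, message = USABLE_LEARWARE, "ok"
    for checker in checkers:
        status, message = checker(predict)
        if status != USABLE_LEARWARE:
            break
    return status, message


def good(rows):
    return [[sum(r)] for r in rows]


def bad(rows):
    return [[1.0, 2.0] for _ in rows]


good_status, _ = run_checkers([ToyShapeChecker(3, 1)], good)
bad_status, bad_message = run_checkers([ToyShapeChecker(3, 1)], bad)
```

Returning a ``(status, message)`` pair lets the caller decide whether to store the learnware while still reporting why a check failed.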
| In the deploying stage, the user submits their requirement to the learnware market, which then identifies and returns helpful learnwares to the user. There are two issues to address: how to identify learnwares matching the user requirement and how to reuse the returned learnwares. | |||
| The learnware market can house thousands or millions of models. Efficiently identifying helpful learnwares is challenging, especially given that the market has no access to the original training data of learnwares or the current user's data. With the specification design mentioned earlier, the market can request users to describe their intentions using a set of descriptions/tags, through a user interface or a learnware description language. Based on this information, the task becomes identifying helpful learnwares in a specification island. | |||
| ``CondaChecker`` | |||
| ------------------ | |||
| This ``checker`` checks the environment of the learnware object. It creates a ``LearnwaresContainer`` instance to handle the learnware and uses ``inner_checker`` to check it. If an exception occurs, it logs the error and returns the ``NONUSABLE_LEARNWARE`` status and an error message. | |||
| Reusing Learnwares | |||
| ++++++++++++++++++++++ | |||
| Once helpful learnwares are identified and delivered to the user, they can be reused in various ways. Users can apply the received learnware directly to their data, use multiple learnwares to create an ensemble, or adapt and polish the received learnware(s) using their own data. Learnwares can also be used as feature augmentors, with their outputs used as augmented features for building the final model. | |||
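Two of the reuse strategies mentioned above, ensembling and feature augmentation, can be sketched in a few lines. The helper names and the two lambda "learnwares" are hypothetical stand-ins; plain averaging is only one possible way to combine predictions.

```python
import numpy as np


def average_ensemble(models, X: np.ndarray) -> np.ndarray:
    """Reuse several learnwares at once by averaging their predictions."""
    return np.mean([m(X) for m in models], axis=0)


def augment_features(models, X: np.ndarray) -> np.ndarray:
    """Append each learnware's output as an extra feature column."""
    outputs = [np.asarray(m(X)).reshape(len(X), -1) for m in models]
    return np.concatenate([X, *outputs], axis=1)


# Hypothetical stand-ins for received learnwares.
model_a = lambda X: X[:, 0] * 2.0
model_b = lambda X: X[:, 1] + 1.0

X_user = np.array([[1.0, 2.0], [3.0, 4.0]])
ensemble_pred = average_ensemble([model_a, model_b], X_user)
X_augmented = augment_features([model_a, model_b], X_user)
```

The augmented matrix can then be passed to any learner the user trains on their own labelled data.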
| ``EasySemanticChecker`` | |||
| ------------------------- | |||
| This ``checker`` checks the semantic specification of a learnware object. It checks whether the given semantic specification conforms to predefined standards by verifying each key in the predefined dictionary. If the check fails, it logs the error and returns the ``NONUSABLE_LEARNWARE`` status and an error message. | |||
| Helpful learnwares may be trained for tasks that are not exactly the same as the user's current task. In such cases, users can tackle their tasks in a divide-and-conquer way or reuse the learnwares collectively through measuring the utility of each model on each testing instance. If users find it difficult to express their requirements accurately, they can adapt and polish the received learnwares directly using their own data. | |||
| ``EasyStatChecker`` | |||
| --------------------- | |||
| Framework | |||
| ====================================== | |||
| This ``checker`` checks the statistical specification and functionality of a learnware object. It performs multiple checks to validate the learnware: it checks model instantiation, verifies the input shape and statistical specifications, and tests the output shape using randomly generated data. If any exception occurs, it logs the error and returns the ``NONUSABLE_LEARNWARE`` status and an error message. | |||
| Current Markets | |||
| ====================================== | |||
| The ``learnware`` package provides two different implementations of ``market``, i.e., ``Easy Market`` and ``Hetero Market``. They differ in their implementations of ``organizer`` and ``searcher``. | |||
| Easy Market | |||
| ------------- | |||
| @@ -77,6 +64,3 @@ One important case is that models have different feature spaces. In order to ena | |||
| - First, design a method for the market to connect different feature spaces to a common subspace and implement the function ``HeterogeneousFeatureMarket.learn_mapping_functions``. This function uses the specifications of all submitted models to learn mapping functions that map data from the original feature space to the common subspace and vice versa. | |||
| - Second, use learned mapping functions to implement the functions ``HeterogeneousFeatureMarket.transform_original_to_subspace`` and ``HeterogeneousFeatureMarket.transform_subspace_to_original``. | |||
| - Third, use the functions ``HeterogeneousFeatureMarket.transform_original_to_subspace`` and ``HeterogeneousFeatureMarket.transform_subspace_to_original`` to override the methods ``EvolvedMarket.generate_new_stat_specification`` and ``EvolvedMarket.EvolvedMarket.evolve_learnware_list`` of the base class ``EvolvedMarket``. | |||
| Current Checkers | |||
| ====================================== | |||
| @@ -80,6 +80,51 @@ Table Specification | |||
| Image Specification | |||
| -------------------------- | |||
| Image data lives in a much higher-dimensional space than most other data types. In such spaces, metrics based on Euclidean (or similar) distances lose discriminative power because pairwise distances tend to concentrate, which makes measuring the similarity between image samples difficult. | |||
| To address this issue, we use the Neural Tangent Kernel (NTK) based on Convolutional Neural Networks (CNNs) to measure the similarity of image samples. CNNs have greatly advanced the field of computer vision and remain a mainstream deep learning technique. | |||
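The difficulty with Euclidean distances in high dimensions can be seen numerically. The sketch below uses uniformly random points rather than real images (an assumption made to keep it self-contained) and compares the relative contrast between the farthest and nearest neighbour of a query in 2 dimensions and in 3072 dimensions, the size of a flattened 32x32x3 CIFAR-10 image. ``relative_contrast`` is a hypothetical helper name.

```python
import numpy as np


def relative_contrast(dim: int, n_points: int = 200, seed: int = 0) -> float:
    """Return (max - min) / min of distances from one query to random points."""
    rng = np.random.default_rng(seed)
    points = rng.random((n_points, dim))
    query = rng.random(dim)
    dists = np.linalg.norm(points - query, axis=1)
    return float((dists.max() - dists.min()) / dists.min())


low_dim = relative_contrast(dim=2)
high_dim = relative_contrast(dim=3072)  # 32 * 32 * 3, a flattened CIFAR-10 image
```

In the high-dimensional case the nearest and farthest points end up at almost the same distance from the query, so a distance-based kernel has little signal to work with.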
| Usage & Example | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| In this part, we show how to generate an Image Specification for the training set of the CIFAR-10 dataset. | |||
| Note that the Image Specification is generated on a subset of the CIFAR-10 dataset with ``generate_rkme_image_spec``. | |||
| Then, it is saved to file "cifar10.json" using ``spec.save``. | |||
| In many cases, it is difficult to construct Image Specification on the full dataset. | |||
| By randomly sampling a subset of the dataset, we can construct Image Specification based on it efficiently, with a strong enough statistical description of the full dataset. | |||
| .. tip:: | |||
| Typically, sampling 3,000 to 10,000 images is sufficient to generate the Image Specification. | |||
| .. code-block:: python | |||
| import torchvision | |||
| from torch.utils.data import DataLoader | |||
| from learnware.specification import generate_rkme_image_spec | |||
| SAMPLED_SIZE = 5000 | |||
| full_set = torchvision.datasets.CIFAR10( | |||
| root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor()) | |||
| loader = DataLoader(full_set, batch_size=SAMPLED_SIZE, shuffle=True) | |||
| sampled_X, _ = next(iter(loader)) | |||
| spec = generate_rkme_image_spec(sampled_X) | |||
| spec.save("cifar10.json") | |||
| Privacy Protection | |||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ | |||
| In the third row of the figure below, we show the eight pseudo-data with the largest weights :math:`\beta` in the Image Specification generated on the CIFAR-10 dataset. | |||
| Notice that the Image Specification generated with the Neural Tangent Kernel (NTK) protects the user's privacy very well. | |||
| In contrast, the first row of the figure shows the performance of the RBF kernel on image data. | |||
| The RBF kernel not only exposes the real data (plotted in the corresponding position in the second row), but also fails to fully utilise the weights :math:`\beta`. | |||
| .. image:: ../_static/img/image_spec.png | |||
| :align: center | |||
| Text Specification | |||
| -------------------------- | |||
| @@ -92,8 +92,8 @@ html_theme = "sphinx_book_theme" | |||
| html_theme_path = [sphinx_book_theme.get_html_theme_path()] | |||
| html_theme_options = { | |||
| "logo_only": True, | |||
| # "collapse_navigation": False, | |||
| # "display_version": False, | |||
| "collapse_navigation": False, | |||
| # "display_version": False, | |||
| "navigation_depth": 4, | |||
| } | |||
| html_logo = "_static/img/logo/logo1.png" | |||
| @@ -1,4 +1,4 @@ | |||
| __version__ = "0.2.0.7" | |||
| __version__ = "0.2.0.8" | |||
| import os | |||
| import json | |||
| @@ -14,10 +14,11 @@ from typing import Union, List, Optional | |||
| from ..config import C | |||
| from .container import LearnwaresContainer | |||
| from ..market import BaseChecker | |||
| from ..specification import generate_semantic_spec | |||
| from ..logger import get_module_logger | |||
| from ..learnware import get_learnware_from_dirpath | |||
| from ..market import BaseUserInfo | |||
| from ..tests import get_semantic_specification | |||
| CHUNK_SIZE = 1024 * 1024 | |||
| logger = get_module_logger(module_name="LearnwareClient") | |||
| @@ -52,8 +53,8 @@ class SemanticSpecificationKey(Enum): | |||
| DATA_TYPE = "Data" | |||
| TASK_TYPE = "Task" | |||
| LIBRARY_TYPE = "Library" | |||
| LICENSE = "License" | |||
| SENARIOES = "Scenario" | |||
| LICENSE = "License" | |||
| class LearnwareClient: | |||
| @@ -67,8 +68,16 @@ class LearnwareClient: | |||
| self.chunk_size = 1024 * 1024 | |||
| self.tempdir_list = [] | |||
| self.login_status = False | |||
| atexit.register(self.cleanup) | |||
| def is_connected(self): | |||
| url = f"{self.host}/auth/login_by_token" | |||
| response = requests.post(url) | |||
| if response.status_code == 404: | |||
| return False | |||
| return True | |||
| def login(self, email, token): | |||
| url = f"{self.host}/auth/login_by_token" | |||
| @@ -80,6 +89,10 @@ class LearnwareClient: | |||
| token = result["data"]["token"] | |||
| self.headers = {"Authorization": f"Bearer {token}"} | |||
| self.login_status = True | |||
| def is_login(self): | |||
| return self.login_status | |||
| @require_login | |||
| def logout(self): | |||
| @@ -166,7 +179,7 @@ class LearnwareClient: | |||
| if result["code"] != 0: | |||
| raise Exception("update failed: " + json.dumps(result)) | |||
| def download_learnware(self, learnware_id, save_path): | |||
| def download_learnware(self, learnware_id: str, save_path: str): | |||
| url = f"{self.host}/engine/download_learnware" | |||
| response = requests.get( | |||
| @@ -216,55 +229,59 @@ class LearnwareClient: | |||
| else: | |||
| stat_spec = None | |||
| returns = [] | |||
| with tempfile.NamedTemporaryFile(prefix="learnware_stat_", suffix=".json") as ftemp: | |||
| returns = { | |||
| "single": { | |||
| "learnware_ids": [], | |||
| "semantic_specifications": [], | |||
| "matching": [], | |||
| }, | |||
| "multiple": { | |||
| "learnware_ids": [], | |||
| "semantic_specifications": [], | |||
| "matching": None, | |||
| }, | |||
| } | |||
| with tempfile.NamedTemporaryFile(prefix="learnware_stat_", suffix=".json", delete=False) as ftemp: | |||
| temp_file_name = ftemp.name | |||
| if stat_spec is not None: | |||
| stat_spec.save(ftemp.name) | |||
| with open(ftemp.name, "r") as fin: | |||
| semantic_specification = user_info.get_semantic_spec() | |||
| if stat_spec is None: | |||
| files = None | |||
| else: | |||
| files = {"statistical_specification": fin} | |||
| response = requests.post( | |||
| url, | |||
| files=files, | |||
| data={ | |||
| "semantic_specification": json.dumps(semantic_specification), | |||
| "limit": page_size, | |||
| "page": page_index, | |||
| }, | |||
| headers=self.headers, | |||
| ) | |||
| result = response.json() | |||
| if result["code"] != 0: | |||
| raise Exception("search failed: " + json.dumps(result)) | |||
| for learnware in result["data"]["learnware_list_single"]: | |||
| returns.append( | |||
| { | |||
| "type": "single", | |||
| "learnware_id": learnware["learnware_id"], | |||
| "semantic_specification": learnware["semantic_specification"], | |||
| "matching": learnware["matching"], | |||
| } | |||
| ) | |||
| if len(result["data"]["learnware_list_multi"]) > 0: | |||
| multiple_learnware = { | |||
| "type": "multiple", | |||
| "learnware_ids": [], | |||
| "semantic_specifications": [], | |||
| "matching": result["data"]["learnware_list_multi"][0]["matching"], | |||
| } | |||
| for learnware in result["data"]["learnware_list_multi"]: | |||
| multiple_learnware["learnware_ids"].append(learnware["learnware_id"]) | |||
| multiple_learnware["semantic_specifications"].append(learnware["semantic_specification"]) | |||
| returns.append(multiple_learnware) | |||
| stat_spec.save(temp_file_name) | |||
| with open(temp_file_name, "r") as fin: | |||
| semantic_specification = user_info.get_semantic_spec() | |||
| if stat_spec is None: | |||
| files = None | |||
| else: | |||
| files = {"statistical_specification": fin} | |||
| response = requests.post( | |||
| url, | |||
| files=files, | |||
| data={ | |||
| "semantic_specification": json.dumps(semantic_specification), | |||
| "limit": page_size, | |||
| "page": page_index, | |||
| }, | |||
| headers=self.headers, | |||
| ) | |||
| result = response.json() | |||
| if result["code"] != 0: | |||
| raise Exception("search failed: " + json.dumps(result)) | |||
| for learnware in result["data"]["learnware_list_single"]: | |||
| returns["single"]["learnware_ids"].append(learnware["learnware_id"]) | |||
| returns["single"]["semantic_specifications"].append(learnware["semantic_specification"]) | |||
| returns["single"]["matching"].append(learnware["matching"]) | |||
| if len(result["data"]["learnware_list_multi"]) > 0: | |||
| multi_learnware = result["data"]["learnware_list_multi"][0] | |||
| returns["multiple"]["learnware_ids"].append(multi_learnware["learnware_id"]) | |||
| returns["multiple"]["semantic_specifications"].append(multi_learnware["semantic_specification"]) | |||
| returns["multiple"]["matching"] = multi_learnware["matching"] | |||
| # Delete temp json file | |||
| os.remove(temp_file_name) | |||
| return returns | |||
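The temp-file handling in this hunk can be reduced to the following pattern. This is a sketch, not the client's code: ``round_trip_json`` is a hypothetical helper, and the ``try``/``finally`` is an addition so that the file is removed even when a step in between raises. Creating the file with ``delete=False`` and closing it before reopening matters on platforms such as Windows, where a file created by ``NamedTemporaryFile`` cannot be opened a second time while it is still open.

```python
import json
import os
import tempfile


def round_trip_json(payload: dict):
    """Write payload to a named temp file, reopen it by name, and always clean up."""
    with tempfile.NamedTemporaryFile(prefix="learnware_stat_", suffix=".json", delete=False) as ftemp:
        temp_file_name = ftemp.name  # only the name is needed; the handle closes here
    try:
        with open(temp_file_name, "w") as fout:
            json.dump(payload, fout)
        with open(temp_file_name, "r") as fin:
            loaded = json.load(fin)
    finally:
        os.remove(temp_file_name)  # runs even if writing or reading raises
    return loaded, temp_file_name


loaded, used_path = round_trip_json({"type": "demo"})
```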
| @require_login | |||
| @@ -277,41 +294,6 @@ class LearnwareClient: | |||
| if result["code"] != 0: | |||
| raise Exception("delete failed: " + json.dumps(result)) | |||
| def create_semantic_specification( | |||
| self, | |||
| name: Optional[str] = None, | |||
| description: Optional[str] = None, | |||
| data_type: Optional[str] = None, | |||
| task_type: Optional[str] = None, | |||
| library_type: Optional[str] = None, | |||
| scenarios: Optional[Union[str, List[str]]] = None, | |||
| license: Optional[Union[str, List[str]]] = None, | |||
| input_description: Optional[dict] = None, | |||
| output_description: Optional[dict] = None, | |||
| ): | |||
| semantic_specification = dict() | |||
| semantic_specification["Data"] = {"Type": "Class", "Values": [data_type] if data_type is not None else []} | |||
| semantic_specification["Task"] = {"Type": "Class", "Values": [task_type] if task_type is not None else []} | |||
| semantic_specification["Library"] = { | |||
| "Type": "Class", | |||
| "Values": [library_type] if library_type is not None else [], | |||
| } | |||
| license = [license] if isinstance(license, str) else license | |||
| semantic_specification["License"] = {"Type": "Class", "Values": license if license is not None else []} | |||
| scenarios = [scenarios] if isinstance(scenarios, str) else scenarios | |||
| semantic_specification["Scenario"] = {"Type": "Tag", "Values": scenarios if scenarios is not None else []} | |||
| semantic_specification["Name"] = {"Type": "String", "Values": name if name is not None else ""} | |||
| semantic_specification["Description"] = { | |||
| "Type": "String", | |||
| "Values": description if description is not None else "", | |||
| } | |||
| semantic_specification["Input"] = {} if input_description is None else input_description | |||
| semantic_specification["Output"] = {} if output_description is None else output_description | |||
| return semantic_specification | |||
| def list_semantic_specification_values(self, key: SemanticSpecificationKey): | |||
| url = f"{self.host}/engine/semantic_specification" | |||
| response = requests.get(url, headers=self.headers) | |||
| @@ -416,7 +398,7 @@ class LearnwareClient: | |||
| @staticmethod | |||
| def _check_semantic_specification(semantic_spec): | |||
| from ..market import EasySemanticChecker | |||
| check_status, message = EasySemanticChecker.check_semantic_spec(semantic_spec) | |||
| return check_status != BaseChecker.INVALID_LEARNWARE, message | |||
| @@ -430,10 +412,16 @@ class LearnwareClient: | |||
| @staticmethod | |||
| def check_learnware(learnware_zip_path, semantic_specification=None): | |||
| semantic_specification = ( | |||
| get_semantic_specification() if semantic_specification is None else semantic_specification | |||
| ) | |||
| semantic_specification = generate_semantic_spec( | |||
| name="test", | |||
| description="test", | |||
| data_type="Text", | |||
| task_type="Segmentation", | |||
| scenarios="Financial", | |||
| library_type="Scikit-learn", | |||
| license="Apache-2.0", | |||
| ) if semantic_specification is None else semantic_specification | |||
| check_status, message = LearnwareClient._check_semantic_specification(semantic_specification) | |||
| assert check_status, f"Semantic specification check failed due to {message}!" | |||
| @@ -442,12 +430,9 @@ class LearnwareClient: | |||
| z_file.extractall(tempdir) | |||
| learnware = get_learnware_from_dirpath( | |||
| id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir | |||
| id="test", semantic_spec=semantic_specification, learnware_dirpath=tempdir, ignore_error=False | |||
| ) | |||
| if learnware is None: | |||
| raise Exception("The learnware is not valid.") | |||
| check_status, message = LearnwareClient._check_stat_specification(learnware) | |||
| assert check_status is True, message | |||
| @@ -28,8 +28,6 @@ def system_execute(args, timeout=None, env=None, stdout=subprocess.DEVNULL, stde | |||
| def remove_enviroment(conda_env): | |||
| system_execute(args=["conda", "env", "remove", "-n", f"{conda_env}"]) | |||
| logger.info(f"The learnware conda env [{conda_env}] is removed.") | |||
| def install_environment(learnware_dirpath, conda_env): | |||
| """Install environment of a learnware | |||
| @@ -51,7 +49,7 @@ def install_environment(learnware_dirpath, conda_env): | |||
| if "environment.yaml" in os.listdir(learnware_dirpath): | |||
| yaml_path: str = os.path.join(learnware_dirpath, "environment.yaml") | |||
| yaml_path_filter: str = os.path.join(tempdir, "environment_filter.yaml") | |||
| logger.info(f"checking the avaliabe conda packages for {conda_env}") | |||
| logger.info(f"checking the available conda packages for {conda_env}") | |||
| filter_nonexist_conda_packages_file(yaml_file=yaml_path, output_yaml_file=yaml_path_filter) | |||
| # create environment | |||
| logger.info(f"create conda env [{conda_env}] according to .yaml file") | |||
| @@ -60,7 +58,7 @@ def install_environment(learnware_dirpath, conda_env): | |||
| elif "requirements.txt" in os.listdir(learnware_dirpath): | |||
| requirements_path: str = os.path.join(learnware_dirpath, "requirements.txt") | |||
| requirements_path_filter: str = os.path.join(tempdir, "requirements_filter.txt") | |||
| logger.info(f"checking the avaliabe pip packages for {conda_env}") | |||
| logger.info(f"checking the available pip packages for {conda_env}") | |||
| filter_nonexist_pip_packages_file(requirements_file=requirements_path, output_file=requirements_path_filter) | |||
| logger.info(f"create empty conda env [{conda_env}]") | |||
| system_execute(args=["conda", "create", "-y", "--name", f"{conda_env}", "python=3.8"]) | |||
| @@ -1,8 +1,9 @@ | |||
| import os | |||
| import copy | |||
| from typing import Optional | |||
| import traceback | |||
| from .base import Learnware | |||
| from .utils import get_stat_spec_from_config | |||
| from ..specification import Specification | |||
| from ..utils import read_yaml_to_dict | |||
| @@ -12,7 +13,7 @@ from ..config import C | |||
| logger = get_module_logger("learnware.learnware") | |||
| def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Learnware: | |||
| def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, ignore_error=True) -> Optional[Learnware]: | |||
| """Get the learnware object from dirpath, and provide the manage interface for Learnware class | |||
| Parameters | |||
| @@ -45,7 +46,12 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||
| } | |||
| try: | |||
| yaml_config = read_yaml_to_dict(os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"])) | |||
| learnware_yaml_path = os.path.join(learnware_dirpath, C.learnware_folder_config["yaml_file"]) | |||
| assert os.path.exists(learnware_yaml_path), f"learnware.yaml is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| yaml_config = read_yaml_to_dict(learnware_yaml_path) | |||
| if "name" in yaml_config: | |||
| learnware_config["name"] = yaml_config["name"] | |||
| @@ -60,7 +66,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||
| learnware_spec = Specification() | |||
| for _stat_spec in learnware_config["stat_specifications"]: | |||
| stat_spec = _stat_spec.copy() | |||
| stat_spec["file_name"] = os.path.join(learnware_dirpath, stat_spec["file_name"]) | |||
| stat_spec_path = os.path.join(learnware_dirpath, stat_spec["file_name"]) | |||
| assert os.path.exists(stat_spec_path), f"statistical specification file {stat_spec['file_name']} is not found for learnware_{id}, please check the learnware folder or zipfile." | |||
| stat_spec["file_name"] = stat_spec_path | |||
| stat_spec_inst = get_stat_spec_from_config(stat_spec) | |||
| learnware_spec.update_stat_spec(**{stat_spec_inst.type: stat_spec_inst}) | |||
| @@ -69,7 +78,7 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath, | |||
| except Exception as e: | |||
| if not ignore_error: | |||
| raise e | |||
| logger.warning(f"Load Learnware {id} failed! Due to {repr(e)}") | |||
| logger.warning(f"Load Learnware {id} failed! Due to {e}; details:\n{traceback.format_exc()}") | |||
| return None | |||
| return Learnware( | |||
| @@ -1,9 +1,8 @@ | |||
| from typing import List, Dict, Tuple, Any | |||
| from typing import Dict | |||
| from ..easy.organizer import EasyOrganizer | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| from ...specification import BaseStatSpecification | |||
| logger = get_module_logger("anchor_organizer") | |||
| @@ -1,7 +1,6 @@ | |||
| from typing import List, Dict, Tuple, Any, Union | |||
| from typing import List, Tuple, Any | |||
| from .user_info import AnchoredUserInfo | |||
| from ..base import BaseUserInfo | |||
| from ..easy.searcher import EasySearcher | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware | |||
| @@ -3,7 +3,7 @@ from __future__ import annotations | |||
| import traceback | |||
| import zipfile | |||
| import tempfile | |||
| from typing import Tuple, Any, List, Union, Dict, Optional | |||
| from typing import Tuple, Any, List, Union, Optional | |||
| from dataclasses import dataclass | |||
| from ..learnware import Learnware, get_learnware_from_dirpath | |||
| from ..logger import get_module_logger | |||
| @@ -45,7 +45,7 @@ class BaseUserInfo: | |||
| def update_semantic_spec(self, semantic_spec: dict): | |||
| self.semantic_spec = semantic_spec | |||
| def update_stat_info(self, name: str, item: Any): | |||
| """Update stat_info by market | |||
| @@ -64,28 +64,35 @@ class SingleSearchItem: | |||
| learnware: Learnware | |||
| score: Optional[float] = None | |||
| @dataclass | |||
| class MultipleSearchItem: | |||
| learnwares: List[Learnware] | |||
| score: float | |||
| class SearchResults: | |||
| def __init__(self, single_results: Optional[List[SingleSearchItem]] = None, multiple_results: Optional[List[MultipleSearchItem]] = None): | |||
| def __init__( | |||
| self, | |||
| single_results: Optional[List[SingleSearchItem]] = None, | |||
| multiple_results: Optional[List[MultipleSearchItem]] = None, | |||
| ): | |||
| self.update_single_results([] if single_results is None else single_results) | |||
| self.update_multiple_results([] if multiple_results is None else multiple_results) | |||
| def get_single_results(self) -> List[SingleSearchItem]: | |||
| return self.single_results | |||
| def get_multiple_results(self) -> List[MultipleSearchItem]: | |||
| return self.multiple_results | |||
| def update_single_results(self, single_results: List[SingleSearchItem]): | |||
| self.single_results = single_results | |||
| def update_multiple_results(self, multiple_results: List[MultipleSearchItem]): | |||
| self.multiple_results = multiple_results | |||
| class LearnwareMarket: | |||
| """Base interface for market, it provides the interface of search/add/delete/update learnwares""" | |||
| @@ -132,7 +139,7 @@ class LearnwareMarket: | |||
| def check_learnware(self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs) -> bool: | |||
| try: | |||
| final_status = BaseChecker.NONUSABLE_LEARNWARE | |||
| if len(checker_names): | |||
| if checker_names is not None and len(checker_names): | |||
| with tempfile.TemporaryDirectory(prefix="pending_learnware_") as tempdir: | |||
| with zipfile.ZipFile(zip_path, mode="r") as z_file: | |||
| z_file.extractall(tempdir) | |||
| @@ -179,9 +186,7 @@ class LearnwareMarket: | |||
| zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs | |||
| ) | |||
| def search_learnware( | |||
| self, user_info: BaseUserInfo, check_status: int = None, **kwargs | |||
| ) -> SearchResults: | |||
| def search_learnware(self, user_info: BaseUserInfo, check_status: int = None, **kwargs) -> SearchResults: | |||
| """Search learnwares based on user_info from learnwares with check_status | |||
| Parameters | |||
| @@ -1,6 +1,6 @@ | |||
| from sqlalchemy.ext.declarative import declarative_base | |||
| from sqlalchemy import create_engine, text | |||
| from sqlalchemy import Column, Integer, Text, DateTime, String | |||
| from sqlalchemy import Column, Text, String | |||
| import os | |||
| import json | |||
| import traceback | |||
| @@ -245,7 +245,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): | |||
| ret = [] | |||
| for idx in ids: | |||
| spec = self.learnware_list[idx].get_specification() | |||
| if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec()): | |||
| if is_hetero(stat_specs=spec.get_stat_spec(), semantic_spec=spec.get_semantic_spec(), verbose=False): | |||
| ret.append(idx) | |||
| return ret | |||
| @@ -1,15 +1,15 @@ | |||
| from typing import Callable, List, Optional, Union | |||
| from typing import Callable, Union | |||
| import numpy as np | |||
| import pandas as pd | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import Tensor, nn | |||
| from torch import nn | |||
| from .....utils import allocate_cuda_idx, choose_device | |||
| from .....specification import HeteroMapTableSpecification, RKMETableSpecification | |||
| from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | |||
| from .trainer import Trainer, TransTabCollatorForCL | |||
| from .trainer import TransTabCollatorForCL, Trainer | |||
| class HeteroMap(nn.Module): | |||
| @@ -127,6 +127,7 @@ class HeteroMap(nn.Module): | |||
| self.base_temperature = base_temperature | |||
| self.num_partition = num_partition | |||
| self.overlap_ratio = overlap_ratio | |||
| self.max_process_size = 20480 | |||
| self.to(device) | |||
| def to(self, device: Union[str, torch.device]): | |||
| @@ -306,6 +307,10 @@ class HeteroMap(nn.Module): | |||
| """ | |||
| self.eval() | |||
| output_feas_list = [] | |||
| if eval_batch_size * x_test.shape[1] > self.max_process_size: | |||
| eval_batch_size = max(1, self.max_process_size // x_test.shape[1]) | |||
| for i in range(0, len(x_test), eval_batch_size): | |||
| bs_x_test = x_test.iloc[i : i + eval_batch_size] | |||
| with torch.no_grad(): | |||
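Note on the hunk above: the new guard appears to cap the number of table cells (rows times columns) handled per forward pass at `max_process_size`. A minimal standalone sketch of that arithmetic follows. The helper names `cap_batch_size` and `iter_batches` are hypothetical and are not part of the PR; only the constant 20480 and the `max(1, ...)` floor are taken from the diff.

```python
def cap_batch_size(eval_batch_size: int, num_features: int, max_process_size: int = 20480) -> int:
    """Shrink the batch size so that rows * columns stays within max_process_size."""
    if eval_batch_size * num_features > max_process_size:
        # Never drop below one row per batch, even for extremely wide tables.
        eval_batch_size = max(1, max_process_size // num_features)
    return eval_batch_size


def iter_batches(num_rows: int, batch_size: int):
    """Yield (start, stop) row ranges, mirroring the loop over x_test in the hunk."""
    for start in range(0, num_rows, batch_size):
        yield start, min(start + batch_size, num_rows)


# A 100-column table with a requested batch of 1024 rows is reduced to 204 rows per batch.
capped = cap_batch_size(1024, 100)
ranges = list(iter_batches(5, 2))
```

The floor of one row means a table wider than `max_process_size` columns is still processed, just one row at a time.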
| @@ -1,6 +1,6 @@ | |||
| import math | |||
| import os | |||
| from typing import Callable, Dict, List, Union | |||
| from typing import Dict, List, Union | |||
| import numpy as np | |||
| import pandas as pd | |||
| @@ -1,9 +1,10 @@ | |||
| import traceback | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("hetero_utils") | |||
| def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool: | |||
| def is_hetero(stat_specs: dict, semantic_spec: dict, verbose=True) -> bool: | |||
| """Check if user_info satisfies all the criteria required for enabling heterogeneous learnware search | |||
| Parameters | |||
| @@ -35,15 +36,17 @@ def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool: | |||
| semantic_decription_feature_num = len(semantic_input_description["Description"]) | |||
| if semantic_decription_feature_num <= 0: | |||
| logger.warning("At least one of Input.Description in semantic spec should be provides.") | |||
| if verbose: | |||
| logger.warning("At least one of Input.Description in semantic spec should be provided.") | |||
| return False | |||
| if table_input_shape != semantic_description_dim: | |||
| logger.warning("User data feature dimensions mismatch with semantic specification.") | |||
| if verbose: | |||
| logger.warning("User data feature dimensions mismatch with semantic specification.") | |||
| return False | |||
| return True | |||
| except Exception as e: | |||
| logger.warning(f"Invalid heterogeneous search information provided due to {e}. Use homogeneous search instead.") | |||
| except Exception as err: | |||
| if verbose: | |||
| logger.warning(f"Invalid heterogeneous search information provided due to {err}. Use homogeneous search instead.") | |||
| return False | |||
| @@ -1,9 +1,10 @@ | |||
| from .base import LearnwareMarket | |||
| from .classes import CondaChecker | |||
| from .easy import EasyOrganizer, EasySearcher, EasySemanticChecker, EasyStatChecker | |||
| from .heterogeneous import HeteroMapTableOrganizer, HeteroSearcher | |||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None): | |||
| def get_market_component(name, market_id, rebuild, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, conda_checker=False): | |||
| organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs | |||
| searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs | |||
| checker_kwargs = {} if checker_kwargs is None else checker_kwargs | |||
| @@ -11,7 +12,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| if name == "easy": | |||
| easy_organizer = EasyOrganizer(market_id=market_id, rebuild=rebuild) | |||
| easy_searcher = EasySearcher(organizer=easy_organizer) | |||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker()] | |||
| easy_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| market_component = { | |||
| "organizer": easy_organizer, | |||
| "searcher": easy_searcher, | |||
| @@ -20,7 +21,7 @@ def get_market_component(name, market_id, rebuild, organizer_kwargs=None, search | |||
| elif name == "hetero": | |||
| hetero_organizer = HeteroMapTableOrganizer(market_id=market_id, rebuild=rebuild, **organizer_kwargs) | |||
| hetero_searcher = HeteroSearcher(organizer=hetero_organizer) | |||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker()] | |||
| hetero_checker_list = [EasySemanticChecker(), EasyStatChecker() if conda_checker is False else CondaChecker(EasyStatChecker())] | |||
| market_component = { | |||
| "organizer": hetero_organizer, | |||
| @@ -40,9 +41,10 @@ def instantiate_learnware_market( | |||
| organizer_kwargs: dict = None, | |||
| searcher_kwargs: dict = None, | |||
| checker_kwargs: dict = None, | |||
| conda_checker: bool = False, | |||
| **kwargs, | |||
| ): | |||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs) | |||
| market_componets = get_market_component(name, market_id, rebuild, organizer_kwargs, searcher_kwargs, checker_kwargs, conda_checker) | |||
| return LearnwareMarket( | |||
| organizer=market_componets["organizer"], | |||
| searcher=market_componets["searcher"], | |||
| @@ -17,5 +17,12 @@ if not is_torch_available(verbose=False): | |||
| generate_rkme_table_spec = None | |||
| generate_rkme_image_spec = None | |||
| generate_rkme_text_spec = None | |||
| generate_semantic_spec = None | |||
| else: | |||
| from .module import generate_stat_spec, generate_rkme_table_spec, generate_rkme_image_spec, generate_rkme_text_spec | |||
| from .module import ( | |||
| generate_stat_spec, | |||
| generate_rkme_table_spec, | |||
| generate_rkme_image_spec, | |||
| generate_rkme_text_spec, | |||
| generate_semantic_spec, | |||
| ) | |||
| @@ -1,7 +1,7 @@ | |||
| import torch | |||
| import numpy as np | |||
| import pandas as pd | |||
| from typing import Union, List | |||
| from typing import Union, List, Optional | |||
| from .utils import convert_to_numpy | |||
| from .base import BaseStatSpecification | |||
| @@ -175,7 +175,7 @@ def generate_rkme_text_spec( | |||
| def generate_stat_spec( | |||
| type: str, X: Union[np.ndarray, pd.DataFrame, torch.Tensor, List[str]], *args, **kwargs | |||
| ) -> BaseStatSpecification: | |||
| ) -> Union[RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification]: | |||
| """ | |||
| Interface for users to generate statistical specification. | |||
| Return a StatSpecification object, use .save() method to save as npy file. | |||
| @@ -202,3 +202,41 @@ def generate_stat_spec( | |||
| return generate_rkme_image_spec(X=X, *args, **kwargs) | |||
| else: | |||
| raise TypeError(f"type {type} is not supported!") | |||
| def generate_semantic_spec( | |||
| name: Optional[str] = None, | |||
| description: Optional[str] = None, | |||
| data_type: Optional[str] = None, | |||
| task_type: Optional[str] = None, | |||
| library_type: Optional[str] = None, | |||
| scenarios: Optional[Union[str, List[str]]] = None, | |||
| license: Optional[Union[str, List[str]]] = None, | |||
| input_description: Optional[dict] = None, | |||
| output_description: Optional[dict] = None, | |||
| ): | |||
| semantic_specification = dict() | |||
| semantic_specification["Data"] = {"Type": "Class", "Values": [data_type] if data_type is not None else []} | |||
| semantic_specification["Task"] = {"Type": "Class", "Values": [task_type] if task_type is not None else []} | |||
| semantic_specification["Library"] = { | |||
| "Type": "Class", | |||
| "Values": [library_type] if library_type is not None else [], | |||
| } | |||
| license = [license] if isinstance(license, str) else license | |||
| semantic_specification["License"] = {"Type": "Class", "Values": license if license is not None else []} | |||
| scenarios = [scenarios] if isinstance(scenarios, str) else scenarios | |||
| semantic_specification["Scenario"] = {"Type": "Tag", "Values": scenarios if scenarios is not None else []} | |||
| semantic_specification["Name"] = {"Type": "String", "Values": name if name is not None else ""} | |||
| semantic_specification["Description"] = { | |||
| "Type": "String", | |||
| "Values": description if description is not None else "", | |||
| } | |||
| if input_description is not None: | |||
| semantic_specification["Input"] = input_description | |||
| if output_description is not None: | |||
| semantic_specification["Output"] = output_description | |||
| return semantic_specification | |||
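Note on the hunk above: `generate_semantic_spec` accepts either a single string or a list for `scenarios` and `license`, and wraps the scalar fields in one-element lists. The sketch below reproduces only that normalization in a standalone form so it can be run without the package. The names `as_list` and `sketch_semantic_spec` are hypothetical and cover a subset of the real function's arguments.

```python
from typing import List, Optional, Union


def as_list(value: Optional[Union[str, List[str]]]) -> List[str]:
    """Normalize None, a single string, or a list of strings to a list."""
    if value is None:
        return []
    return [value] if isinstance(value, str) else list(value)


def sketch_semantic_spec(
    name: Optional[str] = None,
    data_type: Optional[str] = None,
    task_type: Optional[str] = None,
    scenarios: Optional[Union[str, List[str]]] = None,
    license: Optional[Union[str, List[str]]] = None,
) -> dict:
    """Build a dict with the same key layout as the function in the diff (subset only)."""
    return {
        "Data": {"Type": "Class", "Values": as_list(data_type)},
        "Task": {"Type": "Class", "Values": as_list(task_type)},
        "License": {"Type": "Class", "Values": as_list(license)},
        "Scenario": {"Type": "Tag", "Values": as_list(scenarios)},
        "Name": {"Type": "String", "Values": name if name is not None else ""},
    }


spec = sketch_semantic_spec(
    name="demo", data_type="Table", task_type="Regression", scenarios="Education", license=["MIT"]
)
```

Passing `scenarios="Education"` and `license=["MIT"]` both end up as lists, so downstream checkers can treat the two fields uniformly.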
| @@ -1 +1 @@ | |||
| from .module import get_semantic_specification | |||
| from .utils import parametrize | |||
| @@ -0,0 +1,161 @@ | |||
| import os | |||
| import pickle | |||
| import tempfile | |||
| import zipfile | |||
| from dataclasses import dataclass | |||
| from typing import Tuple, Optional, List, Union | |||
| from .config import BenchmarkConfig, benchmark_configs | |||
| from ..data import GetData | |||
| from ...config import C | |||
| @dataclass | |||
| class Benchmark: | |||
| learnware_ids: List[str] | |||
| user_num: int | |||
| test_X_paths: List[str] | |||
| test_y_paths: List[str] | |||
| train_X_paths: Optional[List[str]] = None | |||
| train_y_paths: Optional[List[str]] = None | |||
| extra_info_path: Optional[str] = None | |||
| def get_test_data(self, user_ids: Union[int, List[int]]): | |||
| if isinstance(user_ids, int): | |||
| user_ids = [user_ids] | |||
| ret = [] | |||
| for user_id in user_ids: | |||
| with open(self.test_X_paths[user_id], "rb") as fin: | |||
| test_X = pickle.load(fin) | |||
| with open(self.test_y_paths[user_id], "rb") as fin: | |||
| test_y = pickle.load(fin) | |||
| ret.append((test_X, test_y)) | |||
| return ret | |||
| def get_train_data(self, user_ids: Union[int, List[int]]): | |||
| if self.train_X_paths is None or self.train_y_paths is None: | |||
| return None | |||
| if isinstance(user_ids, int): | |||
| user_ids = [user_ids] | |||
| ret = [] | |||
| for user_id in user_ids: | |||
| with open(self.train_X_paths[user_id], "rb") as fin: | |||
| train_X = pickle.load(fin) | |||
| with open(self.train_y_paths[user_id], "rb") as fin: | |||
| train_y = pickle.load(fin) | |||
| ret.append((train_X, train_y)) | |||
| return ret | |||
| class LearnwareBenchmark: | |||
| def __init__(self): | |||
| self.benchmark_configs = benchmark_configs | |||
| def list_benchmarks(self): | |||
| return list(self.benchmark_configs.keys()) | |||
| def _check_cache_data_valid(self, benchmark_config: BenchmarkConfig, data_type: str) -> bool: | |||
| """Check if the cache data is valid | |||
| Parameters | |||
| ---------- | |||
| benchmark_config : BenchmarkConfig | |||
| benchmark config | |||
| data_type : str | |||
| "test" for test data or "train" for train data | |||
| Returns | |||
| ------- | |||
| bool | |||
| A flag indicating if the cache data is valid | |||
| """ | |||
| cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") | |||
| if os.path.exists(cache_folder): | |||
| for user_id in range(benchmark_config.user_num): | |||
| X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") | |||
| y_path = os.path.join(cache_folder, f"user{user_id}_y.pkl") | |||
| if not os.path.isfile(X_path) or not os.path.isfile(y_path): | |||
| return False | |||
| return True | |||
| else: | |||
| return False | |||
| def _download_data(self, download_path: str, save_path: str): | |||
| """Download data from backend | |||
| Parameters | |||
| ---------- | |||
| download_path : str | |||
| data path for download in backend | |||
| save_path : str | |||
| local cache path for saving data | |||
| """ | |||
| with tempfile.TemporaryDirectory(prefix="learnware_benchmark_") as tempdir: | |||
| test_data_zippath = os.path.join(tempdir, "benchmark_data.zip") | |||
| GetData().download_file(download_path, test_data_zippath) | |||
| os.makedirs(save_path, exist_ok=True) | |||
| with zipfile.ZipFile(test_data_zippath, "r") as z_file: | |||
| z_file.extractall(save_path) | |||
| def _load_cache_data(self, benchmark_config: BenchmarkConfig, data_type: str) -> Tuple[List[str], List[str]]: | |||
| """Load data from local cache path | |||
| Parameters | |||
| ---------- | |||
| benchmark_config : BenchmarkConfig | |||
| benchmark config | |||
| data_type : str | |||
| "test" for test data or "train" for train data | |||
| """ | |||
| cache_folder = os.path.join(C.cache_path, benchmark_config.name, f"{data_type}_data") | |||
| if not self._check_cache_data_valid(benchmark_config, data_type): | |||
| download_path = getattr(benchmark_config, f"{data_type}_data_path", None) | |||
| self._download_data(download_path, cache_folder) | |||
| X_paths, y_paths = [], [] | |||
| for user_id in range(benchmark_config.user_num): | |||
| user_X_path = os.path.join(cache_folder, f"user{user_id}_X.pkl") | |||
| user_y_path = os.path.join(cache_folder, f"user{user_id}_y.pkl") | |||
| assert os.path.isfile(user_X_path), f"user {user_id} {data_type}_X is not valid!" | |||
| assert os.path.isfile(user_y_path), f"user {user_id} {data_type}_y is not valid!" | |||
| X_paths.append(user_X_path) | |||
| y_paths.append(user_y_path) | |||
| return X_paths, y_paths | |||
| def get_benchmark(self, benchmark_config: Union[str, BenchmarkConfig]): | |||
| if isinstance(benchmark_config, str): | |||
| benchmark_config = self.benchmark_configs[benchmark_config] | |||
| # Load test data | |||
| test_X_paths, test_y_paths = self._load_cache_data(benchmark_config, "test") | |||
| # Load train data | |||
| train_X_paths, train_y_paths = None, None | |||
| if benchmark_config.train_data_path is not None: | |||
| train_X_paths, train_y_paths = self._load_cache_data(benchmark_config, "train") | |||
| # Load extra info | |||
| extra_info_path = None | |||
| if benchmark_config.extra_info_path is not None: | |||
| extra_info_path = os.path.join(C.cache_path, benchmark_config.name, "extra_info") | |||
| if not os.path.exists(extra_info_path): | |||
| self._download_data(benchmark_config.extra_info_path, extra_info_path) | |||
| return Benchmark( | |||
| learnware_ids=benchmark_config.learnware_ids, | |||
| user_num=benchmark_config.user_num, | |||
| test_X_paths=test_X_paths, | |||
| test_y_paths=test_y_paths, | |||
| train_X_paths=train_X_paths, | |||
| train_y_paths=train_y_paths, | |||
| extra_info_path=extra_info_path, | |||
| ) | |||
| @@ -0,0 +1,24 @@ | |||
| from dataclasses import dataclass | |||
| from typing import Optional, List | |||
| @dataclass | |||
| class BenchmarkConfig: | |||
| name: str | |||
| learnware_ids: List[str] | |||
| user_num: int | |||
| test_data_path: str | |||
| train_data_path: Optional[str] = None | |||
| extra_info_path: Optional[str] = None | |||
| benchmark_configs = { | |||
| "example": BenchmarkConfig( | |||
| name="example", | |||
| learnware_ids=["00001951", "00001980", "00001987"], | |||
| user_num=3, | |||
| test_data_path="example_path1", | |||
| train_data_path="example_path2", | |||
| extra_info_path="example_path3", | |||
| ) | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| import json | |||
| import requests | |||
| from tqdm import tqdm | |||
| from ..config import C | |||
| class GetData: | |||
| def __init__(self, host=None, chunk_size=1024 * 1024): | |||
| self.headers = None | |||
| if host is None: | |||
| self.host = C.backend_host | |||
| else: | |||
| self.host = host | |||
| self.chunk_size = chunk_size | |||
| def download_file(self, file_path: str, save_path: str): | |||
| url = f"{self.host}/engine/download" | |||
| response = requests.get( | |||
| url, | |||
| params={ | |||
| "file_path": file_path, | |||
| }, | |||
| stream=True, | |||
| ) | |||
| if response.status_code != 200: | |||
| raise Exception("download failed: " + json.dumps(response.json())) | |||
| num_chunks = int(response.headers["Content-Length"]) // self.chunk_size + 1 | |||
| bar = tqdm(total=num_chunks, desc="Downloading", unit="MB") | |||
| with open(save_path, "wb") as f: | |||
| for chunk in response.iter_content(chunk_size=self.chunk_size): | |||
| f.write(chunk) | |||
| bar.update(1) | |||
| @@ -1,10 +0,0 @@ | |||
| def get_semantic_specification(): | |||
| semantic_specification = dict() | |||
| semantic_specification["Data"] = {"Type": "Class", "Values": ["Text"]} | |||
| semantic_specification["Task"] = {"Type": "Class", "Values": ["Segmentation"]} | |||
| semantic_specification["Library"] = {"Type": "Class", "Values": ["Scikit-learn"]} | |||
| semantic_specification["Scenario"] = {"Type": "Tag", "Values": ["Financial"]} | |||
| semantic_specification["License"] = {"Type": "Class", "Values": ["Apache-2.0"]} | |||
| semantic_specification["Name"] = {"Type": "String", "Values": "test"} | |||
| semantic_specification["Description"] = {"Type": "String", "Values": "test"} | |||
| return semantic_specification | |||
| @@ -0,0 +1,101 @@ | |||
| import os | |||
| import tempfile | |||
| from dataclasses import dataclass, field | |||
| from shutil import copyfile | |||
| from typing import List, Tuple, Union, Optional | |||
| from ...utils import save_dict_to_yaml, convert_folder_to_zipfile | |||
| from ...config import C | |||
| @dataclass | |||
| class ModelTemplate: | |||
| class_name: str = field(init=False) | |||
| template_path: str = field(init=False) | |||
| model_kwargs: dict = field(init=False) | |||
| @dataclass | |||
| class PickleModelTemplate(ModelTemplate): | |||
| model_kwargs: dict | |||
| pickle_filepath: str | |||
| def __post_init__(self): | |||
| self.class_name = "PickleLoadedModel" | |||
| self.template_path = os.path.join(C.package_path, "tests", "templates", "pickle_model.py") | |||
| default_model_kwargs = { | |||
| "predict_method": "predict", | |||
| "fit_method": "fit", | |||
| "finetune_method": "finetune", | |||
| "pickle_filename": "model.pkl", | |||
| } | |||
| default_model_kwargs.update(self.model_kwargs) | |||
| self.model_kwargs = default_model_kwargs | |||
| @dataclass | |||
| class StatSpecTemplate: | |||
| filepath: str | |||
| type: str = field(default="RKMETableSpecification") | |||
| class LearnwareTemplate: | |||
| @staticmethod | |||
| def generate_requirements(filepath, requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None): | |||
| requirements = [] if requirements is None else requirements | |||
| operators = {"==", "~=", ">=", "<=", ">", "<"} | |||
| requirements_str = "" | |||
| for requirement in requirements: | |||
| if isinstance(requirement, str): | |||
| line_str = requirement.strip() + "\n" | |||
| elif isinstance(requirement, tuple): | |||
| assert requirement[1] in operators, "The operator of requirements is not supported." | |||
| line_str = requirement[0].strip() + requirement[1].strip() + requirement[2].strip() + "\n" | |||
| else: | |||
| raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}") | |||
| requirements_str += line_str | |||
| with open(filepath, "w") as fdout: | |||
| fdout.write(requirements_str) | |||
| @staticmethod | |||
| def generate_learnware_yaml(filepath, model_config: Optional[dict] = None, stat_spec_config: Optional[List[dict]] = None): | |||
| learnware_config = {} | |||
| if model_config is not None: | |||
| learnware_config["model"] = model_config | |||
| if stat_spec_config is not None: | |||
| learnware_config["stat_specifications"] = stat_spec_config | |||
| save_dict_to_yaml(learnware_config, filepath) | |||
| @staticmethod | |||
| def generate_learnware_zipfile( | |||
| learnware_zippath: str, | |||
| model_template: ModelTemplate, | |||
| stat_spec_template: StatSpecTemplate, | |||
| requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None, | |||
| ): | |||
| with tempfile.TemporaryDirectory(suffix="learnware_template") as tempdir: | |||
| requirement_filepath = os.path.join(tempdir, "requirements.txt") | |||
| LearnwareTemplate.generate_requirements(requirement_filepath, requirements) | |||
| model_filepath = os.path.join(tempdir, "__init__.py") | |||
| copyfile(model_template.template_path, model_filepath) | |||
| learnware_yaml_filepath = os.path.join(tempdir, "learnware.yaml") | |||
| model_config = { | |||
| "class_name": model_template.class_name, | |||
| "kwargs": model_template.model_kwargs, | |||
| } | |||
| stat_spec_config = { | |||
| "module_path": "learnware.specification", | |||
| "class_name": stat_spec_template.type, | |||
| "file_name": "stat_spec.json", | |||
| "kwargs": {} | |||
| } | |||
| copyfile(stat_spec_template.filepath, os.path.join(tempdir, stat_spec_config["file_name"])) | |||
| LearnwareTemplate.generate_learnware_yaml(learnware_yaml_filepath, model_config, stat_spec_config=[stat_spec_config]) | |||
| if isinstance(model_template, PickleModelTemplate): | |||
| pickle_filepath = os.path.join(tempdir, model_template.model_kwargs["pickle_filename"]) | |||
| copyfile(model_template.pickle_filepath, pickle_filepath) | |||
| convert_folder_to_zipfile(tempdir, learnware_zippath) | |||
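Note on `LearnwareTemplate.generate_requirements` in the hunk above: each requirement is either a ready-made string or a `(package, operator, version)` tuple that is joined into one line. The sketch below isolates that formatting step without touching the filesystem. The name `format_requirements` is hypothetical, and it raises `ValueError` for an unknown operator where the diff uses an `assert`.

```python
from typing import List, Optional, Tuple, Union

# Same operator set as in the diff.
OPERATORS = {"==", "~=", ">=", "<=", ">", "<"}


def format_requirements(requirements: Optional[List[Union[Tuple[str, str, str], str]]] = None) -> str:
    """Render requirements as the text of a requirements.txt file."""
    lines = []
    for requirement in requirements or []:
        if isinstance(requirement, str):
            lines.append(requirement.strip())
        elif isinstance(requirement, tuple):
            package, operator, version = requirement
            if operator not in OPERATORS:
                raise ValueError(f"Unsupported operator: {operator!r}")
            lines.append(package.strip() + operator.strip() + version.strip())
        else:
            raise TypeError(f"requirement must be type str/tuple, rather than {type(requirement)}")
    # One requirement per line, each terminated by a newline.
    return "".join(line + "\n" for line in lines)


text = format_requirements(["numpy", ("learnware", "==", "0.1.0")])
```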
| @@ -0,0 +1,33 @@ | |||
| import os | |||
| import pickle | |||
| import numpy as np | |||
| from learnware.model.base import BaseModel | |||
| class PickleLoadedModel(BaseModel): | |||
| def __init__( | |||
| self, | |||
| input_shape, | |||
| output_shape, | |||
| predict_method="predict", | |||
| fit_method="fit", | |||
| finetune_method="finetune", | |||
| pickle_filename="model.pkl", | |||
| ): | |||
| super(PickleLoadedModel, self).__init__(input_shape=input_shape, output_shape=output_shape) | |||
| dir_path = os.path.dirname(os.path.abspath(__file__)) | |||
| self.pickle_filepath = os.path.join(dir_path, pickle_filename) | |||
| with open(self.pickle_filepath, "rb") as fd: | |||
| self.model = pickle.load(fd) | |||
| self.predict_method = predict_method | |||
| self.fit_method = fit_method | |||
| self.finetune_method = finetune_method | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return getattr(self.model, self.predict_method)(X) | |||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||
| getattr(self.model, self.fit_method)(X, y) | |||
| def finetune(self, X: np.ndarray, y: np.ndarray): | |||
| getattr(self.model, self.finetune_method)(X, y) | |||
| @@ -0,0 +1,9 @@ | |||
| import unittest | |||
| def parametrize(test_class, **kwargs): | |||
| test_loader = unittest.TestLoader() | |||
| test_names = test_loader.getTestCaseNames(test_class) | |||
| _suite = unittest.TestSuite() | |||
| for name in test_names: | |||
| _suite.addTest(test_class(name, **kwargs)) | |||
| return _suite | |||
| @@ -3,7 +3,7 @@ import zipfile | |||
| from .import_utils import is_torch_available | |||
| from .module import get_module_by_module_path | |||
| from .file import read_yaml_to_dict, save_dict_to_yaml | |||
| from .file import read_yaml_to_dict, save_dict_to_yaml, convert_folder_to_zipfile | |||
| from .gpu import setup_seed, choose_device, allocate_cuda_idx | |||
| from ..config import get_platform, SystemType | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| import yaml | |||
| import zipfile | |||
| def save_dict_to_yaml(dict_value: dict, save_path: str): | |||
| """save dict object into yaml file""" | |||
| @@ -12,3 +13,13 @@ def read_yaml_to_dict(yaml_path: str): | |||
| with open(yaml_path, "r") as file: | |||
| dict_value = yaml.load(file.read(), Loader=yaml.FullLoader) | |||
| return dict_value | |||
| def convert_folder_to_zipfile(folder_path, zip_path): | |||
| with zipfile.ZipFile(zip_path, "w") as zip_obj: | |||
| for foldername, subfolders, filenames in os.walk(folder_path): | |||
| for filename in filenames: | |||
| file_path = os.path.join(foldername, filename) | |||
| zip_info = zipfile.ZipInfo(filename) | |||
| zip_info.compress_type = zipfile.ZIP_STORED | |||
| with open(file_path, "rb") as file: | |||
| zip_obj.writestr(zip_info, file.read()) | |||
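Note on `convert_folder_to_zipfile` in the hunk above: because each entry is created with `zipfile.ZipInfo(filename)`, files are stored under their bare file name, so the folder hierarchy is flattened and two files with the same name in different subfolders would end up as duplicate entries. The sketch below repeats the helper and shows the flattening on a throwaway directory; the file names used are illustrative only.

```python
import os
import tempfile
import zipfile


def convert_folder_to_zipfile(folder_path, zip_path):
    # Same logic as the helper in the diff, repeated for a runnable example.
    with zipfile.ZipFile(zip_path, "w") as zip_obj:
        for foldername, subfolders, filenames in os.walk(folder_path):
            for filename in filenames:
                file_path = os.path.join(foldername, filename)
                zip_info = zipfile.ZipInfo(filename)
                zip_info.compress_type = zipfile.ZIP_STORED
                with open(file_path, "rb") as file:
                    zip_obj.writestr(zip_info, file.read())


with tempfile.TemporaryDirectory() as workdir:
    src = os.path.join(workdir, "src")
    os.makedirs(os.path.join(src, "nested"))
    with open(os.path.join(src, "learnware.yaml"), "w") as f:
        f.write("model: {}\n")
    with open(os.path.join(src, "nested", "model.pkl"), "wb") as f:
        f.write(b"\x00")

    # The archive is written outside the source folder so it is not walked itself.
    zip_path = os.path.join(workdir, "out.zip")
    convert_folder_to_zipfile(src, zip_path)
    with zipfile.ZipFile(zip_path) as z:
        archived_names = sorted(z.namelist())
```

The nested file appears as `model.pkl` rather than `nested/model.pkl`, which is harmless for the flat layout that `generate_learnware_zipfile` produces but worth knowing before reusing the helper on deeper trees.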
| @@ -0,0 +1,103 @@ | |||
| import os | |||
| import unittest | |||
| import tempfile | |||
| import logging | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.learnware import Learnware | |||
| from learnware.client import LearnwareClient | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo, EasySemanticChecker | |||
| from learnware.config import C | |||
| class TestSearch(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls): | |||
| cls.market = instantiate_learnware_market(market_id="search_test", name="hetero", rebuild=True) | |||
| if cls.client.is_connected(): | |||
| cls._build_learnware_market() | |||
| @classmethod | |||
| def _build_learnware_market(cls): | |||
| table_learnware_ids = ["00001951", "00001980", "00001987"] | |||
| image_learnware_ids = ["00000851", "00000858", "00000841"] | |||
| text_learnware_ids = ["00000652", "00000637"] | |||
| learnware_ids = table_learnware_ids + image_learnware_ids + text_learnware_ids | |||
| with tempfile.TemporaryDirectory(prefix="learnware_search_test") as tempdir: | |||
| for learnware_id in learnware_ids: | |||
| learnware_zippath = os.path.join(tempdir, f"learnware_{learnware_id}.zip") | |||
| try: | |||
| cls.client.download_learnware(learnware_id=learnware_id, save_path=learnware_zippath) | |||
| semantic_spec = ( | |||
| cls.client.load_learnware(learnware_path=learnware_zippath) | |||
| .get_specification() | |||
| .get_semantic_spec() | |||
| ) | |||
| except Exception: | |||
| print(f"'{learnware_id}' is skipped due to the network problem.") | |||
| continue | |||
| cls.market.add_learnware( | |||
| learnware_zippath, | |||
| learnware_id=learnware_id, | |||
| semantic_spec=semantic_spec, | |||
| checker_names=["EasySemanticChecker"], | |||
| ) | |||
| def _skip_test(self): | |||
| if not self.client.is_connected(): | |||
| print("Client cannot connect!") | |||
| return True | |||
| return False | |||
| def test_image_search(self): | |||
| if not self._skip_test(): | |||
| learnware_id = "00000619" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_image_search' is skipped due to the network problem.") | |||
| return | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def test_text_search(self): | |||
| if not self._skip_test(): | |||
| learnware_id = "00000653" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_text_search' is skipped due to the network problem.") | |||
| return | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def test_table_search(self): | |||
| if not self._skip_test(): | |||
| learnware_id = "00001950" | |||
| try: | |||
| learnware: Learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| except Exception: | |||
| print("'test_table_search' is skipped due to the network problem.") | |||
| return | |||
| user_info = BaseUserInfo(stat_info=learnware.get_specification().get_stat_spec()) | |||
| search_result = self.market.search_learnware(user_info) | |||
| print("Single Search Results:", search_result.get_single_results()) | |||
| print("Multiple Search Results:", search_result.get_multiple_results()) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestSearch("test_image_search")) | |||
| _suite.addTest(TestSearch("test_text_search")) | |||
| _suite.addTest(TestSearch("test_table_search")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| @@ -1,8 +0,0 @@ | |||
| model: | |||
| class_name: MyModel | |||
| kwargs: {} | |||
| stat_specifications: | |||
| - module_path: learnware.specification | |||
| class_name: RKMETableSpecification | |||
| file_name: stat.json | |||
| kwargs: {} | |||
| @@ -1,16 +0,0 @@ | |||
| from learnware.model import BaseModel | |||
| import numpy as np | |||
| import joblib | |||
| import os | |||
| class MyModel(BaseModel): | |||
| def __init__(self): | |||
| super(MyModel, self).__init__(input_shape=(20,), output_shape=(1,)) | |||
| dir_path = os.path.dirname(os.path.abspath(__file__)) | |||
| model_path = os.path.join(dir_path, "ridge.pkl") | |||
| model = joblib.load(model_path) | |||
| self.model = model | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return self.model.predict(X) | |||
| @@ -1,16 +0,0 @@ | |||
| from learnware.model import BaseModel | |||
| import numpy as np | |||
| import joblib | |||
| import os | |||
| class MyModel(BaseModel): | |||
| def __init__(self): | |||
| super(MyModel, self).__init__(input_shape=(30,), output_shape=(1,)) | |||
| dir_path = os.path.dirname(os.path.abspath(__file__)) | |||
| model_path = os.path.join(dir_path, "ridge.pkl") | |||
| model = joblib.load(model_path) | |||
| self.model = model | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return self.model.predict(X) | |||
| @@ -1 +0,0 @@ | |||
| learnware == 0.1.0.999 | |||
| @@ -1,414 +0,0 @@ | |||
| import torch | |||
| import unittest | |||
| import os | |||
| import copy | |||
| import joblib | |||
| import zipfile | |||
| import numpy as np | |||
| import multiprocessing | |||
| from sklearn.linear_model import Ridge | |||
| from sklearn.datasets import make_regression | |||
| from shutil import copyfile, rmtree | |||
| from learnware.client import LearnwareClient | |||
| from sklearn.metrics import mean_squared_error | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec | |||
| from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser | |||
| from example_learnwares.config import ( | |||
| input_shape_list, | |||
| input_description_list, | |||
| output_description_list, | |||
| user_description_list, | |||
| ) | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| user_semantic = { | |||
| "Data": {"Values": ["Table"], "Type": "Class"}, | |||
| "Task": { | |||
| "Values": ["Regression"], | |||
| "Type": "Class", | |||
| }, | |||
| "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, | |||
| "Scenario": {"Values": ["Education"], "Type": "Tag"}, | |||
| "Description": {"Values": "", "Type": "String"}, | |||
| "Name": {"Values": "", "Type": "String"}, | |||
| "License": {"Values": ["MIT"], "Type": "Class"}, | |||
| } | |||
| def check_learnware(learnware_name, dir_path=os.path.join(curr_root, "learnware_pool")): | |||
| print(f"Checking Learnware: {learnware_name}") | |||
| zip_file_path = os.path.join(dir_path, learnware_name) | |||
| client = LearnwareClient() | |||
| # Return True if check_learnware raises no exception; otherwise return False | |||
| try: | |||
| client.check_learnware(zip_file_path) | |||
| return True | |||
| except Exception as e: | |||
| print(f"Learnware {learnware_name} failed the check: {e}") | |||
| return False | |||
| class TestMarket(unittest.TestCase): | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| np.random.seed(2023) | |||
| learnware.init() | |||
| def _init_learnware_market(self, organizer_kwargs=None): | |||
| """initialize learnware market""" | |||
| hetero_market = instantiate_learnware_market( | |||
| market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs | |||
| ) | |||
| return hetero_market | |||
| def test_prepare_learnware_randomly(self, learnware_num=5): | |||
| self.zip_path_list = [] | |||
| for i in range(learnware_num): | |||
| dir_path = os.path.join(curr_root, "learnware_pool", "ridge_%d" % (i)) | |||
| os.makedirs(dir_path, exist_ok=True) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| example_learnware_idx = i % 2 | |||
| input_dim = input_shape_list[example_learnware_idx] | |||
| learnware_example_dir = "example_learnwares" | |||
| X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_dim, noise=0.1, random_state=42) | |||
| clf = Ridge(alpha=1.0) | |||
| clf.fit(X, y) | |||
| joblib.dump(clf, os.path.join(dir_path, "ridge.pkl")) | |||
| spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | |||
| spec.save(os.path.join(dir_path, "stat.json")) | |||
| init_file = os.path.join(dir_path, "__init__.py") | |||
| copyfile( | |||
| os.path.join(curr_root, learnware_example_dir, f"model{example_learnware_idx}.py"), init_file | |||
| ) # cp example_init.py init_file | |||
| yaml_file = os.path.join(dir_path, "learnware.yaml") | |||
| copyfile( | |||
| os.path.join(curr_root, learnware_example_dir, "learnware.yaml"), yaml_file | |||
| ) # cp example.yaml yaml_file | |||
| env_file = os.path.join(dir_path, "requirements.txt") | |||
| copyfile(os.path.join(curr_root, learnware_example_dir, "requirements.txt"), env_file) | |||
| zip_file = dir_path + ".zip" | |||
| # zip -q -r -j zip_file dir_path | |||
| with zipfile.ZipFile(zip_file, "w") as zip_obj: | |||
| for foldername, subfolders, filenames in os.walk(dir_path): | |||
| for filename in filenames: | |||
| file_path = os.path.join(foldername, filename) | |||
| zip_info = zipfile.ZipInfo(filename) | |||
| zip_info.compress_type = zipfile.ZIP_STORED | |||
| with open(file_path, "rb") as file: | |||
| zip_obj.writestr(zip_info, file.read()) | |||
| rmtree(dir_path) # rm -r dir_path | |||
| self.zip_path_list.append(zip_file) | |||
| def test_generated_learnwares(self): | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| dir_path = os.path.join(curr_root, "learnware_pool") | |||
| # Execute multi-process checking using Pool | |||
| mp_context = multiprocessing.get_context("spawn") | |||
| with mp_context.Pool() as pool: | |||
| results = pool.starmap(check_learnware, [(name, dir_path) for name in os.listdir(dir_path)]) | |||
| # Use an assert statement to ensure that all checks return True | |||
| self.assertTrue(all(results), "Not all learnwares passed the check") | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| hetero_market = self._init_learnware_market() | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == 0, f"The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = input_description_list[idx % 2] | |||
| semantic_spec["Output"] = output_description_list[idx % 2] | |||
| hetero_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = hetero_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| if delete: | |||
| for learnware_id in curr_inds: | |||
| hetero_market.delete_learnware(learnware_id) | |||
| self.learnware_num -= 1 | |||
| assert ( | |||
| len(hetero_market) == self.learnware_num | |||
| ), f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = hetero_market.get_learnware_ids() | |||
| print("Available ids After Deleting Learnwares:", curr_inds) | |||
| assert len(curr_inds) == 0, f"The market should be empty!" | |||
| return hetero_market | |||
| def test_train_market_model(self, learnware_num=5): | |||
| hetero_market = self._init_learnware_market( | |||
| organizer_kwargs={"auto_update": False, "auto_update_limit": learnware_num} | |||
| ) | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == 0, f"The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = input_description_list[idx % 2] | |||
| semantic_spec["Output"] = output_description_list[idx % 2] | |||
| hetero_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = hetero_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| # organizer=hetero_market.learnware_organizer | |||
| # organizer.train(hetero_market.learnware_organizer.learnware_list.values()) | |||
| return hetero_market | |||
| def test_search_semantics(self, learnware_num=5): | |||
| hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_result) == 1, f"Exact semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], f"Exact semantic search failed!" | |||
| semantic_spec["Name"]["Values"] = "laernwaer" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| assert len(single_result) == self.learnware_num, f"Fuzzy semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id, semantic_spec1) | |||
| def test_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| print("Total Item:", len(hetero_market)) | |||
| # hetero test | |||
| print("+++++ HETERO TEST ++++++") | |||
| user_dim = 15 | |||
| test_folder = os.path.join(curr_root, "test_stat") | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| unzip_dir = os.path.join(test_folder, f"{idx}") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(unzip_dir): | |||
| rmtree(unzip_dir) | |||
| os.makedirs(unzip_dir, exist_ok=True) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "stat.json")) | |||
| z = user_spec.get_z() | |||
| z = z[:, :user_dim] | |||
| device = user_spec.device | |||
| z = torch.tensor(z, device=device) | |||
| user_spec.z = z | |||
| print(">> normal case test:") | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) | |||
| semantic_spec["Input"]["Dimension"] = user_dim | |||
| # keep only the first user_dim descriptions | |||
| semantic_spec["Input"]["Description"] = { | |||
| str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | |||
| } | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print( | |||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||
| ) | |||
| # improper key "Task" in semantic_spec, use homo search and print invalid semantic_spec | |||
| print(">> test for key 'Task' has empty 'Values':") | |||
| semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" | |||
| print(">> delete key 'Task' test:") | |||
| semantic_spec.pop("Task") | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." | |||
| print(">> mismatch dim test") | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Input"] = copy.deepcopy(input_description_list[idx % 2]) | |||
| semantic_spec["Input"]["Dimension"] = user_dim - 2 | |||
| semantic_spec["Input"]["Description"] = { | |||
| str(key): semantic_spec["Input"]["Description"][str(key)] for key in range(user_dim) | |||
| } | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| rmtree(test_folder) # rm -r test_folder | |||
| # homo test | |||
| print("\n+++++ HOMO TEST ++++++") | |||
| test_folder = os.path.join(curr_root, "test_stat") | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| unzip_dir = os.path.join(test_folder, f"{idx}") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(unzip_dir): | |||
| rmtree(unzip_dir) | |||
| os.makedirs(unzip_dir, exist_ok=True) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "stat.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print(f"mixture_score: {multiple_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| def test_model_reuse(self, learnware_num=5): | |||
| # generate toy regression problem | |||
| X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0) | |||
| # generate rkme | |||
| user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | |||
| # generate specification | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Input"] = user_description_list[0] | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| # learnware market search | |||
| hetero_market = self.test_train_market_model(learnware_num) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| # print search results | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print( | |||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||
| ) | |||
| # single model reuse | |||
| hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| single_predict_y = hetero_learnware.predict(X) | |||
| # multi model reuse | |||
| hetero_learnware_list = [] | |||
| for learnware in multiple_result[0].learnwares: | |||
| hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| hetero_learnware_list.append(hetero_learnware) | |||
| # Use averaging ensemble reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=X) | |||
| # Use ensemble pruning reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression") | |||
| reuse_ensemble.fit(X[:100], y[:100]) | |||
| ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X) | |||
| print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False)) | |||
| print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False)) | |||
| print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False)) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestMarket("test_prepare_learnware_randomly")) | |||
| _suite.addTest(TestMarket("test_generated_learnwares")) | |||
| _suite.addTest(TestMarket("test_upload_delete_learnware")) | |||
| _suite.addTest(TestMarket("test_train_market_model")) | |||
| _suite.addTest(TestMarket("test_search_semantics")) | |||
| _suite.addTest(TestMarket("test_stat_search")) | |||
| _suite.addTest(TestMarket("test_model_reuse")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
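Note on the packaging step in `test_prepare_learnware_randomly` above: it walks the learnware directory and writes every file to the archive root, which the inline comment compares to `zip -q -r -j`. The following is a minimal standalone sketch of that flattening step using only the standard library. The helper names `zip_dir_flat` and `demo` and the placeholder file contents are assumptions made for illustration; they are not part of the removed test file or of the `learnware` package.

```python
import os
import tempfile
import zipfile


def zip_dir_flat(dir_path: str, zip_file: str) -> list:
    """Store every file under dir_path at the archive root (like `zip -j`)."""
    names = []
    with zipfile.ZipFile(zip_file, "w") as zip_obj:
        for foldername, _subfolders, filenames in os.walk(dir_path):
            for filename in sorted(filenames):
                file_path = os.path.join(foldername, filename)
                # arcname=filename drops the directory part, so the archive is flat.
                zip_obj.write(file_path, arcname=filename, compress_type=zipfile.ZIP_STORED)
                names.append(filename)
    return names


def demo() -> list:
    """Build a hypothetical two-file learnware folder, zip it, and list the archive."""
    with tempfile.TemporaryDirectory(prefix="learnware_") as tmp:
        src = os.path.join(tmp, "ridge_0")
        os.makedirs(os.path.join(src, "nested"))
        for rel in ("learnware.yaml", os.path.join("nested", "stat.json")):
            with open(os.path.join(src, rel), "w") as fh:
                fh.write("placeholder")
        # The archive is a sibling of the source folder, so os.walk never sees it.
        archive = os.path.join(tmp, "ridge_0.zip")
        zip_dir_flat(src, archive)
        with zipfile.ZipFile(archive) as zip_obj:
            return sorted(zip_obj.namelist())


if __name__ == "__main__":
    print(demo())
```

Flattening is only safe when file names are unique across subfolders; two files with the same base name would end up as duplicate entries in the archive.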
| @@ -3,17 +3,20 @@ import json | |||
| import zipfile | |||
| import unittest | |||
| import tempfile | |||
| import argparse | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import Specification | |||
| from learnware.specification import generate_semantic_spec | |||
| from learnware.market import BaseUserInfo | |||
| class TestAllLearnware(unittest.TestCase): | |||
| def setUp(self): | |||
| unittest.TestCase.setUpClass() | |||
| dir_path = os.path.dirname(__file__) | |||
| config_path = os.path.join(dir_path, "config.json") | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| config_path = os.path.join(os.path.dirname(__file__), "config.json") | |||
| if not os.path.exists(config_path): | |||
| data = {"email": None, "token": None} | |||
| with open(config_path, "w") as file: | |||
| @@ -21,40 +24,56 @@ class TestAllLearnware(unittest.TestCase): | |||
| with open(config_path, "r") as file: | |||
| data = json.load(file) | |||
| email = data["email"] | |||
| token = data["token"] | |||
| email = data.get("email") | |||
| token = data.get("token") | |||
| if email is None or token is None: | |||
| raise ValueError("Please set email and token in config.json.") | |||
| self.client = LearnwareClient() | |||
| self.client.login(email, token) | |||
| print("Please set email and token in config.json.") | |||
| else: | |||
| cls.client.login(email, token) | |||
| def _skip_test(self): | |||
| if not self.client.is_login(): | |||
| print("Client is not logged in!") | |||
| return True | |||
| return False | |||
| def test_all_learnware(self): | |||
| max_learnware_num = 1000 | |||
| semantic_spec = self.client.create_semantic_specification() | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) | |||
| result = self.client.search_learnware(user_info, page_size=max_learnware_num) | |||
| print(f"result size: {len(result)}") | |||
| print(f"key in result: {[key for key in result[0]]}") | |||
| failed_ids = [] | |||
| learnware_ids = [res["learnware_id"] for res in result] | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| for idx in learnware_ids: | |||
| zip_path = os.path.join(tempdir, f"test_{idx}.zip") | |||
| self.client.download_learnware(idx, zip_path) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| try: | |||
| LearnwareClient.check_learnware(zip_path, semantic_spec) | |||
| print(f"check learnware {idx} succeed") | |||
| except Exception: | |||
| failed_ids.append(idx) | |||
| print(f"check learnware {idx} failed!!!") | |||
| print(f"The currently failed learnware ids: {failed_ids}") | |||
| if not self._skip_test(): | |||
| max_learnware_num = 2000 | |||
| semantic_spec = generate_semantic_spec() | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={}) | |||
| result = self.client.search_learnware(user_info, page_size=max_learnware_num) | |||
| learnware_ids = result["single"]["learnware_ids"] | |||
| keys = [key for key in result["single"]["semantic_specifications"][0]] | |||
| print(f"result size: {len(learnware_ids)}") | |||
| print(f"key in result: {keys}") | |||
| failed_ids = [] | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| for idx in learnware_ids: | |||
| zip_path = os.path.join(tempdir, f"test_{idx}.zip") | |||
| self.client.download_learnware(idx, zip_path) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_file: | |||
| with zip_file.open("semantic_specification.json") as json_file: | |||
| semantic_spec = json.load(json_file) | |||
| try: | |||
| LearnwareClient.check_learnware(zip_path, semantic_spec) | |||
| print(f"check learnware {idx} succeed") | |||
| except Exception: | |||
| failed_ids.append(idx) | |||
| print(f"check learnware {idx} failed!!!") | |||
| print(f"The currently failed learnware ids: {failed_ids}") | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestAllLearnware("test_all_learnware")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| @@ -4,7 +4,6 @@ import zipfile | |||
| import unittest | |||
| import tempfile | |||
| from learnware.client import LearnwareClient | |||
| @@ -13,6 +12,13 @@ class TestCheckLearnware(unittest.TestCase): | |||
| unittest.TestCase.setUpClass() | |||
| self.client = LearnwareClient() | |||
| def test_check_learnware_pip_only_zip(self): | |||
| learnware_id = "00000208" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| self.zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(learnware_id, self.zip_path) | |||
| LearnwareClient.check_learnware(self.zip_path) | |||
| def test_check_learnware_pip(self): | |||
| learnware_id = "00000208" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| @@ -69,5 +75,17 @@ class TestCheckLearnware(unittest.TestCase): | |||
| LearnwareClient.check_learnware(self.zip_path, semantic_spec) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_pip_only_zip")) | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_pip")) | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_conda")) | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_dependency")) | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_image")) | |||
| _suite.addTest(TestCheckLearnware("test_check_learnware_text")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
| @@ -0,0 +1,51 @@ | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| from learnware.client.container import LearnwaresContainer | |||
| class TestContainer(unittest.TestCase): | |||
| def __init__(self, method_name='runTest', mode="all"): | |||
| super(TestContainer, self).__init__(method_name) | |||
| self.modes = [] | |||
| if mode in {"all", "conda"}: | |||
| self.modes.append("conda") | |||
| if mode in {"all", "docker"}: | |||
| self.modes.append("docker") | |||
| def setUp(self): | |||
| self.client = LearnwareClient() | |||
| def _test_container_with_pip(self, mode): | |||
| learnware_id = "00000147" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| with LearnwaresContainer(learnware, ignore_error=False, mode=mode) as env_container: | |||
| learnware = env_container.get_learnwares_with_container()[0] | |||
| input_array = np.random.random(size=(20, 23)) | |||
| print(learnware.predict(input_array)) | |||
| def _test_container_with_conda(self, mode): | |||
| learnware_id = "00000148" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id) | |||
| with LearnwaresContainer(learnware, ignore_error=False, mode=mode) as env_container: | |||
| learnware = env_container.get_learnwares_with_container()[0] | |||
| input_array = np.random.random(size=(20, 204)) | |||
| print(learnware.predict(input_array)) | |||
| def test_container_with_pip(self): | |||
| for mode in self.modes: | |||
| self._test_container_with_pip(mode=mode) | |||
| def test_container_with_conda(self): | |||
| for mode in self.modes: | |||
| self._test_container_with_conda(mode=mode) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestContainer("test_container_with_pip", mode="all")) | |||
| _suite.addTest(TestContainer("test_container_with_conda", mode="all")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
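Note on the `mode` argument introduced in `TestContainer.__init__` above: the extra constructor parameter lets `suite()` decide which backends a test iterates over, while the default value keeps the class loadable by the standard test loader, which passes only the method name. A minimal standard-library sketch of the same pattern follows. The class `TestModes`, the `run` helper, and the trivial assertion inside the test are hypothetical stand-ins and do not call any `learnware` API.

```python
import io
import unittest


class TestModes(unittest.TestCase):
    def __init__(self, method_name="runTest", mode="all"):
        super().__init__(method_name)
        # "all" expands to every backend; otherwise keep only the requested one.
        self.modes = [m for m in ("conda", "docker") if mode in {"all", m}]

    def test_each_mode(self):
        for mode in self.modes:
            # subTest reports a failure per mode instead of stopping at the first one.
            with self.subTest(mode=mode):
                self.assertIn(mode, {"conda", "docker"})


def suite(mode="all"):
    _suite = unittest.TestSuite()
    _suite.addTest(TestModes("test_each_mode", mode=mode))
    return _suite


def run(mode="all"):
    # Send runner output to an in-memory stream so the caller only sees the result object.
    runner = unittest.TextTestRunner(stream=io.StringIO())
    return runner.run(suite(mode))


if __name__ == "__main__":
    print(run("conda").wasSuccessful())
```

Using `subTest` here is a design choice of this sketch; the diff itself loops over `self.modes` directly, which stops at the first failing mode.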
| @@ -1,82 +0,0 @@ | |||
| import os | |||
| import unittest | |||
| import zipfile | |||
| import numpy as np | |||
| import learnware | |||
| from learnware.learnware import get_learnware_from_dirpath | |||
| from learnware.client import LearnwareClient | |||
| from learnware.client.container import ModelCondaContainer, LearnwaresContainer | |||
| from learnware.reuse import AveragingReuser | |||
| class TestLearnwareLoad(unittest.TestCase): | |||
| def setUp(self): | |||
| unittest.TestCase.setUpClass() | |||
| self.client = LearnwareClient() | |||
| root = os.path.dirname(__file__) | |||
| self.learnware_ids = ["00000910", "00000899", "00000900"] | |||
| self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] | |||
| def test_load_single_learnware_by_zippath(self): | |||
| for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): | |||
| self.client.download_learnware(learnware_id, zip_path) | |||
| learnware_list = [ | |||
| self.client.load_learnware(learnware_path=zippath, runnable_option="conda") for zippath in self.zip_paths | |||
| ] | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_multi_learnware_by_zippath(self): | |||
| for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): | |||
| self.client.download_learnware(learnware_id, zip_path) | |||
| learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="conda") | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id(self): | |||
| learnware_list = [ | |||
| self.client.load_learnware(learnware_id=idx, runnable_option="conda") for idx in self.learnware_ids | |||
| ] | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_multi_learnware_by_id(self): | |||
| learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="conda") | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_pip(self): | |||
| learnware_id = "00000147" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda") | |||
| input_array = np.random.random(size=(20, 23)) | |||
| print(learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_conda(self): | |||
| learnware_id = "00000148" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="conda") | |||
| input_array = np.random.random(size=(20, 204)) | |||
| print(learnware.predict(input_array)) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -1,57 +0,0 @@ | |||
| import os | |||
| import unittest | |||
| import zipfile | |||
| import numpy as np | |||
| import learnware | |||
| from learnware.learnware import get_learnware_from_dirpath | |||
| from learnware.client import LearnwareClient | |||
| from learnware.client.container import ModelCondaContainer, LearnwaresContainer | |||
| from learnware.reuse import AveragingReuser | |||
| class TestLearnwareLoad(unittest.TestCase): | |||
| def setUp(self): | |||
| unittest.TestCase.setUpClass() | |||
| self.client = LearnwareClient() | |||
| root = os.path.dirname(__file__) | |||
| self.learnware_ids = ["00000910", "00000899", "00000900"] | |||
| self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] | |||
| def test_load_multi_learnware_by_zippath(self): | |||
| for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): | |||
| self.client.download_learnware(learnware_id, zip_path) | |||
| learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option="docker") | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_multi_learnware_by_id(self): | |||
| learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option="docker") | |||
| reuser = AveragingReuser(learnware_list, mode="mean") | |||
| input_array = np.random.random(size=(20, 40)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_pip(self): | |||
| learnware_id = "00000147" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") | |||
| input_array = np.random.random(size=(20, 23)) | |||
| print(learnware.predict(input_array)) | |||
| def test_load_single_learnware_by_id_conda(self): | |||
| learnware_id = "00000148" | |||
| learnware = self.client.load_learnware(learnware_id=learnware_id, runnable_option="docker") | |||
| input_array = np.random.random(size=(20, 204)) | |||
| print(learnware.predict(input_array)) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -0,0 +1,61 @@ | |||
| import os | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| from learnware.reuse import AveragingReuser | |||
| class TestLearnwareLoad(unittest.TestCase): | |||
| def __init__(self, method_name='runTest', mode="all"): | |||
| super(TestLearnwareLoad, self).__init__(method_name) | |||
| self.runnable_options = [] | |||
| if mode in {"all", "conda"}: | |||
| self.runnable_options.append("conda") | |||
| if mode in {"all", "docker"}: | |||
| self.runnable_options.append("docker") | |||
| def setUp(self): | |||
| self.client = LearnwareClient() | |||
| root = os.path.dirname(__file__) | |||
| self.learnware_ids = ["00000910", "00000899", "00000900"] | |||
| self.zip_paths = [os.path.join(root, x) for x in ["1.zip", "2.zip", "3.zip"]] | |||
| def _test_load_learnware_by_zippath(self, runnable_option): | |||
| for learnware_id, zip_path in zip(self.learnware_ids, self.zip_paths): | |||
| self.client.download_learnware(learnware_id, zip_path) | |||
| learnware_list = self.client.load_learnware(learnware_path=self.zip_paths, runnable_option=runnable_option) | |||
| reuser = AveragingReuser(learnware_list, mode="vote_by_label") | |||
| input_array = np.random.random(size=(20, 13)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def _test_load_learnware_by_id(self, runnable_option): | |||
| learnware_list = self.client.load_learnware(learnware_id=self.learnware_ids, runnable_option=runnable_option) | |||
| reuser = AveragingReuser(learnware_list, mode="vote_by_label") | |||
| input_array = np.random.random(size=(20, 13)) | |||
| print(reuser.predict(input_array)) | |||
| for learnware in learnware_list: | |||
| print(learnware.id, learnware.predict(input_array)) | |||
| def test_load_learnware_by_zippath(self): | |||
| for runnable_option in self.runnable_options: | |||
| self._test_load_learnware_by_zippath(runnable_option=runnable_option) | |||
| def test_load_learnware_by_id(self): | |||
| for runnable_option in self.runnable_options: | |||
| self._test_load_learnware_by_id(runnable_option=runnable_option) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestLearnwareLoad("test_load_learnware_by_zippath", mode="all")) | |||
| _suite.addTest(TestLearnwareLoad("test_load_learnware_by_id", mode="all")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
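The `suite()` helper above passes an extra `mode` argument through the `TestCase` constructor. A minimal stdlib-only sketch of that pattern follows. `TestBackend`, `build_suite`, and `run_suite` are hypothetical names used only for illustration, and the option strings are placeholders rather than calls into the learnware client.

```python
import io
import unittest


class TestBackend(unittest.TestCase):
    """Hypothetical test case that is parametrized through its constructor."""

    def __init__(self, method_name="runTest", mode="all"):
        super().__init__(method_name)
        # Translate the mode into the list of options the test should exercise.
        self.options = []
        if mode in {"all", "conda"}:
            self.options.append("conda")
        if mode in {"all", "docker"}:
            self.options.append("docker")

    def test_options(self):
        # subTest reports every option separately instead of stopping at the first failure.
        for option in self.options:
            with self.subTest(option=option):
                self.assertIn(option, {"conda", "docker"})


def build_suite(mode="all"):
    # Build the suite by hand so the extra constructor argument can be supplied.
    suite = unittest.TestSuite()
    suite.addTest(TestBackend("test_options", mode=mode))
    return suite


def run_suite(mode="all"):
    # Write the report to an in-memory buffer so the helper stays quiet.
    runner = unittest.TextTestRunner(stream=io.StringIO())
    return runner.run(build_suite(mode))
```

One consequence of this design: the default test loader instantiates each `TestCase` with the method name only, so running the file through `unittest.main()` or test discovery always uses the default `mode`.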
| @@ -1,34 +0,0 @@ | |||
| import zipfile | |||
| import numpy as np | |||
| from learnware.learnware import get_learnware_from_dirpath | |||
| from learnware.client.container import LearnwaresContainer | |||
| from learnware.reuse import AveragingReuser | |||
| from learnware.tests.module import get_semantic_specification | |||
| if __name__ == "__main__": | |||
| semantic_specification = get_semantic_specification() | |||
| zip_paths = [ | |||
| "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic.zip", | |||
| "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic.zip", | |||
| ] | |||
| dir_paths = [ | |||
| "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/rf_tic", | |||
| "/home/bixd/workspace/learnware/Learnware/tests/test_learnware_client/svc_tic", | |||
| ] | |||
| learnware_list = [] | |||
| for id, (zip_path, dir_path) in enumerate(zip(zip_paths, dir_paths)): | |||
| with zipfile.ZipFile(zip_path, "r") as z_file: | |||
| z_file.extractall(dir_path) | |||
| learnware = get_learnware_from_dirpath(f"test_id{id}", semantic_specification, dir_path) | |||
| learnware_list.append(learnware) | |||
| with LearnwaresContainer(learnware_list) as env_container: | |||
| learnware_list = env_container.get_learnwares_with_container() | |||
| reuser = AveragingReuser(learnware_list, mode="vote") | |||
| input_array = np.random.randint(0, 3, size=(20, 9)) | |||
| print(reuser.predict(input_array).argmax(axis=1)) | |||
| for id, ind_learner in enumerate(learnware_list): | |||
| print(f"learner_{id}", ind_learner.predict(input_array).argmax(axis=1)) | |||
| @@ -4,13 +4,16 @@ import unittest | |||
| import tempfile | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import generate_semantic_spec | |||
| class TestAllLearnware(unittest.TestCase): | |||
| def setUp(self): | |||
| unittest.TestCase.setUpClass() | |||
| dir_path = os.path.dirname(__file__) | |||
| config_path = os.path.join(dir_path, "config.json") | |||
| class TestUpload(unittest.TestCase): | |||
| client = LearnwareClient() | |||
| @classmethod | |||
| def setUpClass(cls) -> None: | |||
| config_path = os.path.join(os.path.dirname(__file__), "config.json") | |||
| if not os.path.exists(config_path): | |||
| data = {"email": None, "token": None} | |||
| with open(config_path, "w") as file: | |||
| @@ -18,52 +21,65 @@ class TestAllLearnware(unittest.TestCase): | |||
| with open(config_path, "r") as file: | |||
| data = json.load(file) | |||
| email = data["email"] | |||
| token = data["token"] | |||
| email = data.get("email") | |||
| token = data.get("token") | |||
| if email is None or token is None: | |||
| raise ValueError("Please set email and token in config.json.") | |||
| self.client = LearnwareClient() | |||
| self.client.login(email, token) | |||
| print("Please set email and token in config.json.") | |||
| else: | |||
| cls.client.login(email, token) | |||
| def _skip_test(self): | |||
| if not self.client.is_login(): | |||
| print("Client is not logged in!") | |||
| return True | |||
| return False | |||
| def test_upload(self): | |||
| input_description = { | |||
| "Dimension": 13, | |||
| "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, | |||
| } | |||
| output_description = { | |||
| "Dimension": 1, | |||
| "Description": { | |||
| "0": "the probability of being a cat", | |||
| }, | |||
| } | |||
| semantic_spec = self.client.create_semantic_specification( | |||
| name="learnware_example", | |||
| description="Just an example for uploading a learnware", | |||
| data_type="Table", | |||
| task_type="Classification", | |||
| library_type="Scikit-learn", | |||
| scenarios=["Business", "Financial"], | |||
| input_description=input_description, | |||
| output_description=output_description, | |||
| ) | |||
| assert isinstance(semantic_spec, dict) | |||
| download_learnware_id = "00000084" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(download_learnware_id, zip_path) | |||
| learnware_id = self.client.upload_learnware( | |||
| learnware_zip_path=zip_path, semantic_specification=semantic_spec | |||
| if not self._skip_test(): | |||
| input_description = { | |||
| "Dimension": 13, | |||
| "Description": {"0": "age", "1": "weight", "2": "body length", "3": "animal type", "4": "claw length"}, | |||
| } | |||
| output_description = { | |||
| "Dimension": 2, | |||
| "Description": {"0": "cat", "1": "not cat"}, | |||
| } | |||
| semantic_spec = generate_semantic_spec( | |||
| name="learnware_example", | |||
| description="Just an example for uploading a learnware", | |||
| data_type="Table", | |||
| task_type="Classification", | |||
| library_type="Scikit-learn", | |||
| scenarios=["Business", "Financial"], | |||
| license="MIT", | |||
| input_description=input_description, | |||
| output_description=output_description, | |||
| ) | |||
| assert isinstance(semantic_spec, dict) | |||
| download_learnware_id = "00000084" | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| zip_path = os.path.join(tempdir, "test.zip") | |||
| self.client.download_learnware(download_learnware_id, zip_path) | |||
| learnware_id = self.client.upload_learnware( | |||
| learnware_zip_path=zip_path, semantic_specification=semantic_spec | |||
| ) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id in uploaded_ids | |||
| self.client.delete_learnware(learnware_id) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id not in uploaded_ids | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id in uploaded_ids | |||
| self.client.delete_learnware(learnware_id) | |||
| uploaded_ids = [learnware["learnware_id"] for learnware in self.client.list_learnware()] | |||
| assert learnware_id not in uploaded_ids | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| _suite.addTest(TestUpload("test_upload")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| runner = unittest.TextTestRunner() | |||
| runner.run(suite()) | |||
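The `_skip_test` guard above prints a message and lets `test_upload` finish without running its body, so the runner counts it as a pass. One alternative, sketched here under stated assumptions, is `TestCase.skipTest`, which makes the runner record the test as skipped. `FakeClient` is a hypothetical stand-in for the real client and only mimics its `is_login` check.

```python
import io
import unittest


class FakeClient:
    """Hypothetical stand-in for a client object that tracks login state."""

    def __init__(self, logged_in=False):
        self._logged_in = logged_in

    def is_login(self):
        return self._logged_in


class TestNeedsLogin(unittest.TestCase):
    client = FakeClient(logged_in=False)

    def setUp(self):
        # skipTest raises unittest.SkipTest, so the test body never runs
        # and the runner records the test as skipped rather than passed.
        if not self.client.is_login():
            self.skipTest("client is not logged in")

    def test_upload(self):
        self.assertTrue(self.client.is_login())


def run_once():
    # Run the single test quietly and hand back the result object.
    suite = unittest.TestSuite([TestNeedsLogin("test_upload")])
    return unittest.TextTestRunner(stream=io.StringIO()).run(suite)
```

The skip reason then appears in the test report, which makes a missing `config.json` visible instead of silently green.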
| @@ -0,0 +1,43 @@ | |||
| import os | |||
| import json | |||
| import string | |||
| import random | |||
| import torch | |||
| import unittest | |||
| import tempfile | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.market.heterogeneous.organizer import HeteroMap | |||
| class TestTableRKME(unittest.TestCase): | |||
| def setUp(self): | |||
| self.hetero_map = HeteroMap() | |||
| def _test_hetero_spec(self, X): | |||
| rkme: RKMETableSpecification = generate_stat_spec(type="table", X=X) | |||
| hetero_spec = self.hetero_map.hetero_mapping(rkme_spec=rkme, features=dict()) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| hetero_spec.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "HeteroMapTableSpecification" | |||
| rkme2 = HeteroMapTableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "HeteroMapTableSpecification" | |||
| def test_hetero_rkme(self): | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(5, 20))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| self._test_hetero_spec(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -0,0 +1,38 @@ | |||
| import os | |||
| import json | |||
| import torch | |||
| import unittest | |||
| import tempfile | |||
| import numpy as np | |||
| from learnware.specification import RKMEImageSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| class TestImageRKME(unittest.TestCase): | |||
| @staticmethod | |||
| def _test_image_rkme(X): | |||
| image_rkme = generate_stat_spec(type="image", X=X, steps=10) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| image_rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMEImageSpecification" | |||
| rkme2 = RKMEImageSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMEImageSpecification" | |||
| def test_image_rkme(self): | |||
| self._test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | |||
| self._test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | |||
| self._test_image_rkme(np.random.randint(0, 255, size=(50, 3, 128, 128)) / 255) | |||
| self._test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32))) | |||
| self._test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | |||
| self._test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -1,104 +0,0 @@ | |||
| import os | |||
| import json | |||
| import string | |||
| import random | |||
| import torch | |||
| import unittest | |||
| import tempfile | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| class TestRKME(unittest.TestCase): | |||
| def test_rkme(self): | |||
| def _test_table_rkme(X): | |||
| rkme = generate_stat_spec(type="table", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETableSpecification" | |||
| rkme2 = RKMETableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETableSpecification" | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(5, 20))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| _test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| def test_image_rkme(self): | |||
| def _test_image_rkme(X): | |||
| image_rkme = generate_stat_spec(type="image", X=X, steps=10) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| image_rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMEImageSpecification" | |||
| rkme2 = RKMEImageSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMEImageSpecification" | |||
| _test_image_rkme(np.random.randint(0, 255, size=(2000, 3, 32, 32))) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(100, 1, 128, 128))) | |||
| _test_image_rkme(np.random.randint(0, 255, size=(50, 3, 128, 128)) / 255) | |||
| _test_image_rkme(torch.randint(0, 255, (2000, 3, 32, 32))) | |||
| _test_image_rkme(torch.randint(0, 255, (20, 3, 128, 128))) | |||
| _test_image_rkme(torch.randint(0, 255, (1, 1, 128, 128)) / 255) | |||
| def test_text_rkme(self): | |||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| def _test_text_rkme(X): | |||
| rkme = generate_stat_spec(type="text", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETextSpecification" | |||
| rkme2 = RKMETextSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETextSpecification" | |||
| return rkme2.get_z().shape[1] | |||
| dim1 = _test_text_rkme(generate_random_text_list(3000, "en")) | |||
| dim2 = _test_text_rkme(generate_random_text_list(100, "en")) | |||
| dim3 = _test_text_rkme(generate_random_text_list(50, "zh")) | |||
| dim4 = _test_text_rkme(generate_random_text_list(5000, "zh")) | |||
| dim5 = _test_text_rkme(generate_random_text_list(1, "zh")) | |||
| assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 and dim4 == dim5 | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -0,0 +1,36 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import tempfile | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| class TestTableRKME(unittest.TestCase): | |||
| @staticmethod | |||
| def _test_table_rkme(X): | |||
| rkme = generate_stat_spec(type="table", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETableSpecification" | |||
| rkme2 = RKMETableSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETableSpecification" | |||
| def test_table_rkme(self): | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5000, 200))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(10000, 100))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(5, 20))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(1, 50))) | |||
| self._test_table_rkme(np.random.uniform(-10000, 10000, size=(100, 150))) | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
| @@ -0,0 +1,58 @@ | |||
| import os | |||
| import json | |||
| import string | |||
| import random | |||
| import unittest | |||
| import tempfile | |||
| from learnware.specification import RKMETextSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| class TestTextRKME(unittest.TestCase): | |||
| @staticmethod | |||
| def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000): | |||
| text_list = [] | |||
| for i in range(num): | |||
| length = random.randint(min_len, max_len) | |||
| if text_type == "en": | |||
| characters = string.ascii_letters + string.digits + string.punctuation | |||
| result_str = "".join(random.choice(characters) for i in range(length)) | |||
| text_list.append(result_str) | |||
| elif text_type == "zh": | |||
| result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length)) | |||
| text_list.append(result_str) | |||
| else: | |||
| raise ValueError("Type should be en or zh") | |||
| return text_list | |||
| @staticmethod | |||
| def _test_text_rkme(X): | |||
| rkme = generate_stat_spec(type="text", X=X) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: | |||
| rkme_path = os.path.join(tempdir, "rkme.json") | |||
| rkme.save(rkme_path) | |||
| with open(rkme_path, "r") as f: | |||
| data = json.load(f) | |||
| assert data["type"] == "RKMETextSpecification" | |||
| rkme2 = RKMETextSpecification() | |||
| rkme2.load(rkme_path) | |||
| assert rkme2.type == "RKMETextSpecification" | |||
| return rkme2.get_z().shape[1] | |||
| def test_text_rkme(self): | |||
| dim1 = self._test_text_rkme(self.generate_random_text_list(3000, "en")) | |||
| dim2 = self._test_text_rkme(self.generate_random_text_list(100, "en")) | |||
| dim3 = self._test_text_rkme(self.generate_random_text_list(50, "zh")) | |||
| dim4 = self._test_text_rkme(self.generate_random_text_list(5000, "zh")) | |||
| dim5 = self._test_text_rkme(self.generate_random_text_list(1, "zh")) | |||
| assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 and dim4 == dim5 | |||
| if __name__ == "__main__": | |||
| unittest.main() | |||
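The text generator in this test draws from the global `random` module without a seed, so an input that triggers a failure cannot be replayed. A possible variant, offered only as a sketch, threads an optional `seed` through a local `random.Random` instance. The `seed` parameter is an assumption and is not part of the original helper. Unlike the original, this version validates `text_type` even when `num` is 0.

```python
import random
import string


def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000, seed=None):
    # Local generator: reproducible when seeded, and the global random state stays untouched.
    rng = random.Random(seed)

    if text_type == "en":
        alphabet = string.ascii_letters + string.digits + string.punctuation

        def pick():
            return rng.choice(alphabet)

    elif text_type == "zh":

        def pick():
            # Same CJK Unified Ideographs range as the original helper.
            return chr(rng.randint(0x4E00, 0x9FFF))

    else:
        raise ValueError("Type should be en or zh")

    texts = []
    for _ in range(num):
        length = rng.randint(min_len, max_len)
        texts.append("".join(pick() for _ in range(length)))
    return texts
```

Calling the function twice with the same `seed` yields identical lists, which lets a failing case be reproduced by logging the seed.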
| @@ -1,10 +0,0 @@ | |||
| ## How to Generate Environment Yaml | |||
| * Create the environment config for conda: | |||
| ```shell | |||
| conda env export | grep -v "^prefix: " > environment.yml | |||
| ``` | |||
| * Recreate the environment from the config: | |||
| ```shell | |||
| conda env create -f environment.yml | |||
| ``` | |||
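The `grep -v "^prefix: "` step drops the machine-specific `prefix:` line that `conda env export` appends, so the file can be reused on another machine. The same filter can be sketched in Python for cases where the exported text is already in memory. `strip_prefix_line` is a hypothetical helper, and the sample text is an abbreviated, made-up export.

```python
def strip_prefix_line(exported_yaml: str) -> str:
    """Drop lines that start with 'prefix: ', mirroring grep -v "^prefix: "."""
    kept = [line for line in exported_yaml.splitlines() if not line.startswith("prefix: ")]
    return "\n".join(kept) + "\n"


# Abbreviated sample export; the prefix path is a placeholder.
sample = (
    "name: learnware_example_env\n"
    "channels:\n"
    "  - defaults\n"
    "dependencies:\n"
    "  - python=3.8.16\n"
    "prefix: /path/to/envs/learnware_example_env\n"
)
cleaned = strip_prefix_line(sample)
```

The filter works line by line on plain text, so it needs no YAML parser and leaves ordering and comments untouched.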
| @@ -1,27 +0,0 @@ | |||
| name: learnware_example_env | |||
| channels: | |||
| - defaults | |||
| dependencies: | |||
| - _libgcc_mutex=0.1=main | |||
| - _openmp_mutex=5.1=1_gnu | |||
| - ca-certificates=2023.01.10=h06a4308_0 | |||
| - ld_impl_linux-64=2.38=h1181459_1 | |||
| - libffi=3.4.2=h6a678d5_6 | |||
| - libgcc-ng=11.2.0=h1234567_1 | |||
| - libgomp=11.2.0=h1234567_1 | |||
| - libstdcxx-ng=11.2.0=h1234567_1 | |||
| - ncurses=6.4=h6a678d5_0 | |||
| - openssl=1.1.1t=h7f8727e_0 | |||
| - pip=23.0.1=py38h06a4308_0 | |||
| - python=3.8.16=h7a1cb2a_3 | |||
| - readline=8.2=h5eee18b_0 | |||
| - setuptools=66.0.0=py38h06a4308_0 | |||
| - sqlite=3.41.2=h5eee18b_0 | |||
| - tk=8.6.12=h1ccaba5_0 | |||
| - wheel=0.38.4=py38h06a4308_0 | |||
| - xz=5.2.10=h5eee18b_1 | |||
| - zlib=1.2.13=h5eee18b_0 | |||
| - pip: | |||
| - joblib==1.2.0 | |||
| - learnware==0.0.1.99 | |||
| - numpy==1.19.5 | |||
| @@ -1,8 +0,0 @@ | |||
| model: | |||
| class_name: SVM | |||
| kwargs: {} | |||
| stat_specifications: | |||
| - module_path: learnware.specification | |||
| class_name: RKMETableSpecification | |||
| file_name: svm.json | |||
| kwargs: {} | |||
| @@ -1,20 +0,0 @@ | |||
| import os | |||
| import joblib | |||
| import numpy as np | |||
| from learnware.model import BaseModel | |||
| class SVM(BaseModel): | |||
| def __init__(self): | |||
| super(SVM, self).__init__(input_shape=(64,), output_shape=(10,)) | |||
| dir_path = os.path.dirname(os.path.abspath(__file__)) | |||
| self.model = joblib.load(os.path.join(dir_path, "svm.pkl")) | |||
| def fit(self, X: np.ndarray, y: np.ndarray): | |||
| pass | |||
| def predict(self, X: np.ndarray) -> np.ndarray: | |||
| return self.model.predict_proba(X) | |||
| def finetune(self, X: np.ndarray, y: np.ndarray): | |||
| pass | |||
| @@ -0,0 +1,321 @@ | |||
| import torch | |||
| import pickle | |||
| import unittest | |||
| import os | |||
| import logging | |||
| import tempfile | |||
| import zipfile | |||
| from sklearn.linear_model import Ridge | |||
| from sklearn.datasets import make_regression | |||
| from shutil import copyfile, rmtree | |||
| from sklearn.metrics import mean_squared_error | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec | |||
| from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser | |||
| from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate | |||
| from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| class TestHeteroWorkflow(unittest.TestCase): | |||
| universal_semantic_config = { | |||
| "data_type": "Table", | |||
| "task_type": "Regression", | |||
| "library_type": "Scikit-learn", | |||
| "scenarios": "Education", | |||
| "license": "MIT", | |||
| } | |||
| def _init_learnware_market(self, organizer_kwargs=None): | |||
| """Initialize the learnware market.""" | |||
| hetero_market = instantiate_learnware_market( | |||
| market_id="hetero_toy", name="hetero", rebuild=True, organizer_kwargs=organizer_kwargs | |||
| ) | |||
| return hetero_market | |||
| def test_prepare_learnware_randomly(self, learnware_num=5): | |||
| self.zip_path_list = [] | |||
| for i in range(learnware_num): | |||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool_hetero") | |||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | |||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "ridge_%d.zip" % (i)) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_shape_list[i % 2], noise=0.1, random_state=42) | |||
| clf = Ridge(alpha=1.0) | |||
| clf.fit(X, y) | |||
| pickle_filepath = os.path.join(learnware_pool_dirpath, "ridge.pkl") | |||
| with open(pickle_filepath, "wb") as fout: | |||
| pickle.dump(clf, fout) | |||
| spec = generate_rkme_table_spec(X=X, gamma=0.1) | |||
| spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") | |||
| spec.save(spec_filepath) | |||
| LearnwareTemplate.generate_learnware_zipfile( | |||
| learnware_zippath=learnware_zippath, | |||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape": (input_shape_list[i % 2],), "output_shape": (1,)}), | |||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | |||
| requirements=["scikit-learn==0.22"], | |||
| ) | |||
| self.zip_path_list.append(learnware_zippath) | |||
| def _upload_delete_learnware(self, hetero_market, learnware_num, delete): | |||
| self.test_prepare_learnware_randomly(learnware_num) | |||
| self.learnware_num = learnware_num | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == 0, "The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = generate_semantic_spec( | |||
| name=f"learnware_{idx}", | |||
| description=f"test_learnware_number_{idx}", | |||
| input_description=input_description_list[idx % 2], | |||
| output_description=output_description_list[idx % 2], | |||
| **self.universal_semantic_config | |||
| ) | |||
| hetero_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = hetero_market.get_learnware_ids() | |||
| print("Available ids After Uploading Learnwares:", curr_inds) | |||
| assert len(curr_inds) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| if delete: | |||
| for learnware_id in curr_inds: | |||
| hetero_market.delete_learnware(learnware_id) | |||
| self.learnware_num -= 1 | |||
| assert ( | |||
| len(hetero_market) == self.learnware_num | |||
| ), f"The number of learnwares must be {self.learnware_num}!" | |||
| curr_inds = hetero_market.get_learnware_ids() | |||
| print("Available ids After Deleting Learnwares:", curr_inds) | |||
| assert len(curr_inds) == 0, "The market should be empty!" | |||
| return hetero_market | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| hetero_market = self._init_learnware_market() | |||
| return self._upload_delete_learnware(hetero_market, learnware_num, delete) | |||
| def test_train_market_model(self, learnware_num=5, delete=False): | |||
| hetero_market = self._init_learnware_market( | |||
| organizer_kwargs={"auto_update": True, "auto_update_limit": learnware_num} | |||
| ) | |||
| hetero_market = self._upload_delete_learnware(hetero_market, learnware_num, delete) | |||
| # organizer=hetero_market.learnware_organizer | |||
| # organizer.train(hetero_market.learnware_organizer.learnware_list.values()) | |||
| return hetero_market | |||
| def test_search_semantics(self, learnware_num=5): | |||
| hetero_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| assert len(hetero_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| semantic_spec = generate_semantic_spec( | |||
| name=f"learnware_{learnware_num - 1}", | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("Search result1:") | |||
| assert len(single_result) == 1, "Exact semantic search failed!" | |||
| for search_item in single_result: | |||
| semantic_spec1 = search_item.learnware.get_specification().get_semantic_spec() | |||
| print("Choose learnware:", search_item.learnware.id) | |||
| assert semantic_spec1["Name"]["Values"] == semantic_spec["Name"]["Values"], "Exact semantic search failed!" | |||
| semantic_spec["Name"]["Values"] = "laernwaer" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("Search result2:") | |||
| assert len(single_result) == self.learnware_num, "Fuzzy semantic search failed!" | |||
| for search_item in single_result: | |||
| print("Choose learnware:", search_item.learnware.id) | |||
| def test_hetero_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| user_dim = 15 | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(test_folder, "stat_spec.json")) | |||
| z = user_spec.get_z() | |||
| z = z[:, :user_dim] | |||
| device = user_spec.device | |||
| z = torch.tensor(z, device=device) | |||
| user_spec.z = z | |||
| print(">> normal case test:") | |||
| semantic_spec = generate_semantic_spec( | |||
| input_description={ | |||
| "Dimension": user_dim, | |||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||
| }, | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| print(f"search result of user{idx}:") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print( | |||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||
| ) | |||
| # improper key "Task" in semantic_spec, use homo search and print invalid semantic_spec | |||
| print(">> test for key 'Task' has empty 'Values':") | |||
| semantic_spec["Task"] = {"Values": ["Segmentation"], "Type": "Class"} | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, "Statistical search failed!" | |||
| # delete key "Task" in semantic_spec, use homo search and print WARNING INFO with "User doesn't provide correct task type" | |||
| print(">> delete key 'Task' test:") | |||
| semantic_spec.pop("Task") | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, "Statistical search failed!" | |||
| # modify semantic info with mismatch dim, use homo search and print "User data feature dimensions mismatch with semantic specification." | |||
| print(">> mismatch dim test") | |||
| semantic_spec = generate_semantic_spec( | |||
| input_description={ | |||
| "Dimension": user_dim - 2, | |||
| "Description": {str(key): input_description_list[idx % 2]["Description"][str(key)] for key in range(user_dim)}, | |||
| }, | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| assert len(single_result) == 0, f"Statistical search failed!" | |||
| def test_homo_stat_search(self, learnware_num=5): | |||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | |||
| print("Total Item:", len(hetero_market)) | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_hetero") as test_folder: | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(test_folder, "stat_spec.json")) | |||
| user_semantic = generate_semantic_spec(**self.universal_semantic_config) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print(f"mixture_score: {multiple_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in multiple_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| def test_model_reuse(self, learnware_num=5): | |||
| # generate toy regression problem | |||
| X, y = make_regression(n_samples=5000, n_informative=10, n_features=15, noise=0.1, random_state=0) | |||
| # generate rkme | |||
| user_spec = generate_rkme_table_spec(X=X, gamma=0.1, cuda_idx=0) | |||
| # generate specification | |||
| semantic_spec = generate_semantic_spec(input_description=user_description_list[0], **self.universal_semantic_config) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec, stat_info={"RKMETableSpecification": user_spec}) | |||
| # learnware market search | |||
| hetero_market = self.test_train_market_model(learnware_num, delete=False) | |||
| search_result = hetero_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| multiple_result = search_result.get_multiple_results() | |||
| # print search results | |||
| for single_item in single_result: | |||
| print(f"score: {single_item.score}, learnware_id: {single_item.learnware.id}") | |||
| for multiple_item in multiple_result: | |||
| print( | |||
| f"mixture_score: {multiple_item.score}, mixture_learnware_ids: {[item.id for item in multiple_item.learnwares]}" | |||
| ) | |||
| # single model reuse | |||
| hetero_learnware = HeteroMapAlignLearnware(single_result[0].learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| single_predict_y = hetero_learnware.predict(X) | |||
| # multi model reuse | |||
| hetero_learnware_list = [] | |||
| for learnware in multiple_result[0].learnwares: | |||
| hetero_learnware = HeteroMapAlignLearnware(learnware, mode="regression") | |||
| hetero_learnware.align(user_spec, X[:100], y[:100]) | |||
| hetero_learnware_list.append(hetero_learnware) | |||
| # Use averaging ensemble reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = AveragingReuser(learnware_list=hetero_learnware_list, mode="mean") | |||
| ensemble_predict_y = reuse_ensemble.predict(user_data=X) | |||
| # Use ensemble pruning reuser to reuse the searched learnwares to make prediction | |||
| reuse_ensemble = EnsemblePruningReuser(learnware_list=hetero_learnware_list, mode="regression") | |||
| reuse_ensemble.fit(X[:100], y[:100]) | |||
| ensemble_pruning_predict_y = reuse_ensemble.predict(user_data=X) | |||
| print("Single model RMSE by finetune:", mean_squared_error(y, single_predict_y, squared=False)) | |||
| print("Averaging Reuser RMSE:", mean_squared_error(y, ensemble_predict_y, squared=False)) | |||
| print("Ensemble Pruning Reuser RMSE:", mean_squared_error(y, ensemble_pruning_predict_y, squared=False)) | |||
| def suite(): | |||
| _suite = unittest.TestSuite() | |||
| #_suite.addTest(TestHeteroWorkflow("test_prepare_learnware_randomly")) | |||
| #_suite.addTest(TestHeteroWorkflow("test_upload_delete_learnware")) | |||
| #_suite.addTest(TestHeteroWorkflow("test_train_market_model")) | |||
| _suite.addTest(TestHeteroWorkflow("test_search_semantics")) | |||
| _suite.addTest(TestHeteroWorkflow("test_hetero_stat_search")) | |||
| _suite.addTest(TestHeteroWorkflow("test_homo_stat_search")) | |||
| _suite.addTest(TestHeteroWorkflow("test_model_reuse")) | |||
| return _suite | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner(verbosity=2) | |||
| runner.run(suite()) | |||
| @@ -1,37 +1,34 @@ | |||
| import sys | |||
| import unittest | |||
| import os | |||
| import copy | |||
| import joblib | |||
| import logging | |||
| import tempfile | |||
| import pickle | |||
| import zipfile | |||
| import numpy as np | |||
| from sklearn import svm | |||
| from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| from shutil import copyfile, rmtree | |||
| import learnware | |||
| learnware.init(logging_level=logging.WARNING) | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser | |||
| from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| user_semantic = { | |||
| "Data": {"Values": ["Table"], "Type": "Class"}, | |||
| "Task": { | |||
| "Values": ["Classification"], | |||
| "Type": "Class", | |||
| }, | |||
| "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, | |||
| "Scenario": {"Values": ["Education"], "Type": "Tag"}, | |||
| "Description": {"Values": "", "Type": "String"}, | |||
| "Name": {"Values": "", "Type": "String"}, | |||
| "License": {"Values": ["MIT"], "Type": "Class"}, | |||
| } | |||
| class TestWorkflow(unittest.TestCase): | |||
| universal_semantic_config = { | |||
| "data_type": "Table", | |||
| "task_type": "Classification", | |||
| "library_type": "Scikit-learn", | |||
| "scenarios": "Education", | |||
| "license": "MIT", | |||
| } | |||
| def _init_learnware_market(self): | |||
| """initialize learnware market""" | |||
| easy_market = instantiate_learnware_market(market_id="sklearn_digits_easy", name="easy", rebuild=True) | |||
| @@ -42,45 +39,30 @@ class TestWorkflow(unittest.TestCase): | |||
| X, y = load_digits(return_X_y=True) | |||
| for i in range(learnware_num): | |||
| dir_path = os.path.join(curr_root, "learnware_pool", "svm_%d" % (i)) | |||
| os.makedirs(dir_path, exist_ok=True) | |||
| learnware_pool_dirpath = os.path.join(curr_root, "learnware_pool") | |||
| os.makedirs(learnware_pool_dirpath, exist_ok=True) | |||
| learnware_zippath = os.path.join(learnware_pool_dirpath, "svm_%d.zip" % (i)) | |||
| print("Preparing Learnware: %d" % (i)) | |||
| data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True) | |||
| clf = svm.SVC(kernel="linear", probability=True) | |||
| clf.fit(data_X, data_y) | |||
| joblib.dump(clf, os.path.join(dir_path, "svm.pkl")) | |||
| pickle_filepath = os.path.join(learnware_pool_dirpath, "model.pkl") | |||
| with open(pickle_filepath, "wb") as fout: | |||
| pickle.dump(clf, fout) | |||
| spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| spec.save(os.path.join(dir_path, "svm.json")) | |||
| init_file = os.path.join(dir_path, "__init__.py") | |||
| copyfile( | |||
| os.path.join(curr_root, "learnware_example/example_init.py"), init_file | |||
| ) # cp example_init.py init_file | |||
| yaml_file = os.path.join(dir_path, "learnware.yaml") | |||
| copyfile(os.path.join(curr_root, "learnware_example/example.yaml"), yaml_file) # cp example.yaml yaml_file | |||
| env_file = os.path.join(dir_path, "environment.yaml") | |||
| copyfile(os.path.join(curr_root, "learnware_example/environment.yaml"), env_file) | |||
| zip_file = dir_path + ".zip" | |||
| # zip -q -r -j zip_file dir_path | |||
| with zipfile.ZipFile(zip_file, "w") as zip_obj: | |||
| for foldername, subfolders, filenames in os.walk(dir_path): | |||
| for filename in filenames: | |||
| file_path = os.path.join(foldername, filename) | |||
| zip_info = zipfile.ZipInfo(filename) | |||
| zip_info.compress_type = zipfile.ZIP_STORED | |||
| with open(file_path, "rb") as file: | |||
| zip_obj.writestr(zip_info, file.read()) | |||
| rmtree(dir_path) # rm -r dir_path | |||
| self.zip_path_list.append(zip_file) | |||
| spec_filepath = os.path.join(learnware_pool_dirpath, "stat_spec.json") | |||
| spec.save(spec_filepath) | |||
| LearnwareTemplate.generate_learnware_zipfile( | |||
| learnware_zippath=learnware_zippath, | |||
| model_template=PickleModelTemplate(pickle_filepath=pickle_filepath, model_kwargs={"input_shape": (64,), "output_shape": (10,), "predict_method": "predict_proba"}), | |||
| stat_spec_template=StatSpecTemplate(filepath=spec_filepath, type="RKMETableSpecification"), | |||
| requirements=["scikit-learn==0.22"], | |||
| ) | |||
| self.zip_path_list.append(learnware_zippath) | |||
| def test_upload_delete_learnware(self, learnware_num=5, delete=True): | |||
| easy_market = self._init_learnware_market() | |||
| @@ -91,20 +73,22 @@ class TestWorkflow(unittest.TestCase): | |||
| assert len(easy_market) == 0, f"The market should be empty!" | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = "learnware_%d" % (idx) | |||
| semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx) | |||
| semantic_spec["Input"] = { | |||
| "Dimension": 64, | |||
| "Description": { | |||
| f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit." | |||
| for i in range(64) | |||
| semantic_spec = generate_semantic_spec( | |||
| name=f"learnware_{idx}", | |||
| description=f"test_learnware_number_{idx}", | |||
| input_description={ | |||
| "Dimension": 64, | |||
| "Description": { | |||
| f"{i}": f"The value in the grid {i // 8}{i % 8} of the image of hand-written digit." | |||
| for i in range(64) | |||
| }, | |||
| }, | |||
| output_description={ | |||
| "Dimension": 10, | |||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | |||
| }, | |||
| } | |||
| semantic_spec["Output"] = { | |||
| "Dimension": 10, | |||
| "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)}, | |||
| } | |||
| **self.universal_semantic_config | |||
| ) | |||
| easy_market.add_learnware(zip_path, semantic_spec) | |||
| print("Total Item:", len(easy_market)) | |||
| @@ -129,70 +113,52 @@ class TestWorkflow(unittest.TestCase): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| assert len(easy_market) == self.learnware_num, f"The number of learnwares must be {self.learnware_num}!" | |||
| test_folder = os.path.join(curr_root, "test_semantics") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(test_folder): | |||
| rmtree(test_folder) | |||
| os.makedirs(test_folder, exist_ok=True) | |||
| with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| semantic_spec = copy.deepcopy(user_semantic) | |||
| semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}" | |||
| semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}" | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| print("User info:", user_info.get_semantic_spec()) | |||
| print(f"Search result:") | |||
| for search_item in single_result: | |||
| print( | |||
| "Choose learnware:", | |||
| search_item.learnware.id, | |||
| search_item.learnware.get_specification().get_semantic_spec(), | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: | |||
| with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| semantic_spec = generate_semantic_spec( | |||
| name=f"learnware_{learnware_num - 1}", | |||
| description=f"test_learnware_number_{learnware_num - 1}", | |||
| **self.universal_semantic_config, | |||
| ) | |||
| user_info = BaseUserInfo(semantic_spec=semantic_spec) | |||
| search_result = easy_market.search_learnware(user_info) | |||
| single_result = search_result.get_single_results() | |||
| rmtree(test_folder) # rm -r test_folder | |||
| print(f"Search result:") | |||
| for search_item in single_result: | |||
| print("Choose learnware:", search_item.learnware.id) | |||
| def test_stat_search(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| print("Total Item:", len(easy_market)) | |||
| test_folder = os.path.join(curr_root, "test_stat") | |||
| with tempfile.TemporaryDirectory(prefix="learnware_test_workflow") as test_folder: | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=test_folder) | |||
| for idx, zip_path in enumerate(self.zip_path_list): | |||
| unzip_dir = os.path.join(test_folder, f"{idx}") | |||
| # unzip -o -q zip_path -d unzip_dir | |||
| if os.path.exists(unzip_dir): | |||
| rmtree(unzip_dir) | |||
| os.makedirs(unzip_dir, exist_ok=True) | |||
| with zipfile.ZipFile(zip_path, "r") as zip_obj: | |||
| zip_obj.extractall(path=unzip_dir) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(unzip_dir, "svm.json")) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_results = easy_market.search_learnware(user_info) | |||
| user_spec = RKMETableSpecification() | |||
| user_spec.load(os.path.join(test_folder, "stat_spec.json")) | |||
| user_semantic = generate_semantic_spec(**self.universal_semantic_config) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec}) | |||
| search_results = easy_market.search_learnware(user_info) | |||
| single_result = search_results.get_single_results() | |||
| multiple_result = search_results.get_multiple_results() | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for search_item in single_result: | |||
| print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") | |||
| single_result = search_results.get_single_results() | |||
| multiple_result = search_results.get_multiple_results() | |||
| for mixture_item in multiple_result: | |||
| print(f"mixture_score: {mixture_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| assert len(single_result) >= 1, f"Statistical search failed!" | |||
| print(f"search result of user{idx}:") | |||
| for search_item in single_result: | |||
| print(f"score: {search_item.score}, learnware_id: {search_item.learnware.id}") | |||
| rmtree(test_folder) # rm -r test_folder | |||
| for mixture_item in multiple_result: | |||
| print(f"mixture_score: {mixture_item.score}\n") | |||
| mixture_id = " ".join([learnware.id for learnware in mixture_item.learnwares]) | |||
| print(f"mixture_learnware: {mixture_id}\n") | |||
| def test_learnware_reuse(self, learnware_num=5): | |||
| easy_market = self.test_upload_delete_learnware(learnware_num, delete=False) | |||
| @@ -202,6 +168,7 @@ class TestWorkflow(unittest.TestCase): | |||
| train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True) | |||
| stat_spec = generate_rkme_table_spec(X=data_X, gamma=0.1, cuda_idx=0) | |||
| user_semantic = generate_semantic_spec(**self.universal_semantic_config) | |||
| user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec}) | |||
| search_results = easy_market.search_learnware(user_info) | |||
| @@ -243,5 +210,5 @@ def suite(): | |||
| if __name__ == "__main__": | |||
| runner = unittest.TextTestRunner() | |||
| runner = unittest.TextTestRunner(verbosity=2) | |||
| runner.run(suite()) | |||