diff --git a/README.md b/README.md
index a01c2d7..9b02df8 100644
--- a/README.md
+++ b/README.md
@@ -106,7 +106,7 @@ We also demonstrate the detail format of learnware zipfile in [DOC link], and al
Users can start an ``Learnware`` workflow according to the following steps:
-### Initialize a Learware Market
+### Initialize a Learnware Market
The ``EasyMarket`` class implements the most basic set of functions in a ``Learnware``.
You can use the following code snippet to initialize a basic ``Learnware`` named "demo":
diff --git a/docs/_static/img/image_labeled.png b/docs/_static/img/image_labeled.png
deleted file mode 100644
index b375f19..0000000
Binary files a/docs/_static/img/image_labeled.png and /dev/null differ
diff --git a/docs/_static/img/image_labeled.svg b/docs/_static/img/image_labeled.svg
new file mode 100644
index 0000000..0de61b1
--- /dev/null
+++ b/docs/_static/img/image_labeled.svg
@@ -0,0 +1,1423 @@
+
+
+
diff --git a/docs/_static/img/text_labeled.svg b/docs/_static/img/text_labeled.svg
new file mode 100644
index 0000000..669247e
--- /dev/null
+++ b/docs/_static/img/text_labeled.svg
@@ -0,0 +1,1358 @@
+
+
+
diff --git a/docs/_static/img/text_labeled_curves.png b/docs/_static/img/text_labeled_curves.png
deleted file mode 100644
index bfb9d7e..0000000
Binary files a/docs/_static/img/text_labeled_curves.png and /dev/null differ
diff --git a/docs/advanced/anchor.rst b/docs/advanced/anchor.rst
index 75fdcc1..db39bba 100644
--- a/docs/advanced/anchor.rst
+++ b/docs/advanced/anchor.rst
@@ -4,15 +4,15 @@ Anchor learnware
Anchor learnwares are a small fraction of representative learnwares that helps locate user's requirements through user feedback. The learnware market can choose or generate several learnwares as anchor learnwares corresponding to the specification island. If the user does not have sufficient training data for constructing an RKME requirement, the learnware market can send several anchor learnwares to the user. By feeding her own data to these anchor learnwares, some information such as (precision, recall) or other performance indicators, can be generated and returned to the market. These information could help the market identify potentially helpful models, e.g., by identifying models that are far from anchors exhibiting poor performance whereas close to anchors exhibiting relatively better performance in the specification island.
-To fulfill the anchor learnware method, you need to implement the following functions in ``anchor.py``:
+To fulfill the anchor learnware method, you need to implement the following functions in ``learnware/market/anchor/``:
-- First, you should design how the market chooses or generates anchor learnwares. This can be realized by selecting prototype models through functional space clustering, and more interesting designs can be explored. The function ``AnchoredMarket.update_anchor_learnware_list`` is reserved for it. The functions ``AnchoredMarket._update_anchor_learnware`` and ``AnchoredMarket._delete_anchor_learnware`` have been completed as auxiliary.
+- First, you should design how the market chooses or generates anchor learnwares. This can be realized by selecting prototype models through functional space clustering, and more interesting designs can be explored. The function ``AnchoredOrganizer.update_anchor_learnware_list`` is reserved for it. The functions ``AnchoredOrganizer._update_anchor_learnware`` and ``AnchoredOrganizer._delete_anchor_learnware`` have been completed as auxiliary.
-- Second, when a user comes with no RKME(or other statistical) specifications, the market should choose several anchor learnwares and send them to the user. This process is done by ``AnchoredMarket.search_anchor_learnware``, and the chosen anchors are stored in ``AnchoredUserInfo`` by ``AnchoredUserInfo.add_anchor_learnware``.
+- Second, when a user comes with no RKME(or other statistical) specifications, the market should choose several anchor learnwares and send them to the user. This process is done by ``AnchoredSearcher.search_anchor_learnware``, and the chosen anchors are stored in ``AnchoredUserInfo`` by ``AnchoredUserInfo.add_anchor_learnware_ids``.
- Third, the market should specify which performance indicator should the user return. By feeding the user's data to these anchor learnwares, the returned information is calculated and stored in ``AnchoredUserInfo`` by ``AnchoredUserInfo.update_stat_info``.
-- Fourth, according to the returned information from the user, the market should identify the helpful learnwares for the user. This process is done in ``AnchoredMarket.search_learnware``.
+- Fourth, according to the returned information from the user, the market should identify the helpful learnwares for the user. This process is done in ``AnchoredSearcher.search_learnware``.
diff --git a/docs/advanced/evolve.rst b/docs/advanced/evolve.rst
index 85201fa..4847c7b 100644
--- a/docs/advanced/evolve.rst
+++ b/docs/advanced/evolve.rst
@@ -1,5 +1,5 @@
==============================
-Specification Evolvement
+Evolvable Specification
==============================
The specification is the core of the learnware paradigm.
@@ -8,16 +8,16 @@ As the number of learnwares in the market increases, the knowledge held in the l
This growth makes it possible for specification evolvement, enabling the market to generate new specifications for each learnware that more accurately characterize the properties of each model and its relationships with others.
As a result, the learnware market can more effectively identify learnwares beneficial for user tasks.
-To achieve the evolvement of specifications, you need to implement the class ``EvolvedMarket`` in the following way:
+To achieve evolvable specifications, you need to implement the class ``EvolvedOrganizer`` in ``learnware/market/evolve/``:
-- First, design a method for the learnware market to generate new statistical specifications for learnwares and implement the function ``EvolvedMarket.generate_new_stat_specification``.
-- Second, use the function ``EvolvedMarket.generate_new_stat_specification`` to implement the function ``EvolvedMarket.evolve_learnware_list``, which enables learnwares to evolve by assigning new statistical specifications.
+- First, design a method for the learnware market to generate new statistical specifications for learnwares and implement the function ``EvolvedOrganizer.generate_new_stat_specification``.
+- Second, use the function ``EvolvedOrganizer.generate_new_stat_specification`` to implement the function ``EvolvedOrganizer.evolve_learnware_list``, which enables learnwares to evolve by assigning new statistical specifications.
When implementing the anchor design, it is essential to develop an appropriate evolvement method for anchor learnwares based on the specific anchor selection method.
In the anchor design, the learnware market sends anchor learnware to users, who then provide statistical information about the anchor learnwares on their tasks to the market.
Based on this statistical feedback from users, the market can more accurately characterize anchor learnwares and continuously evolve them.
-To realize specification evolvement, including anchor learnwares, you need to additionally implement the class ``EvolvedAnchoredMarket`` in the following way:
+To realize evolvable specifications, including anchor learnwares, you need to additionally implement the class ``EvolvedAnchoredOrganizer`` in ``learnware/market/evolve_anchor/``:
-- First, based on the specific anchor selection method, design an appropriate evolvement method for anchor learnwares and implement the function ``EvolvedAnchoredMarket.evolve_anchor_learnware_list``.
-- Second, utilize the statistical feedback from users to implement the function ``EvolvedAnchoredMarket.evolve_anchor_learnware_by_user``, which enables anchor learnwares to evolve continually as users interact with the learnware market.
\ No newline at end of file
+- First, based on the specific anchor selection method, design an appropriate evolvement method for anchor learnwares and implement the function ``EvolvedAnchoredOrganizer.evolve_anchor_learnware_list``.
+- Second, utilize the statistical feedback from users to implement the function ``EvolvedAnchoredOrganizer.evolve_anchor_learnware_by_user``, which enables anchor learnwares to evolve continually as users interact with the learnware market.
\ No newline at end of file
diff --git a/docs/components/market.rst b/docs/components/market.rst
index e08c20f..d512c44 100644
--- a/docs/components/market.rst
+++ b/docs/components/market.rst
@@ -26,7 +26,7 @@ Current Checkers
The ``learnware`` package provide two different implementation of ``market`` where both of them share the same ``checker`` list. So we first introduce the details of ``checker``\ s.
-The ``checker``s check a learnware object in different aspects, including environment configuration (``CondaChecker``), semantic specifications (``EasySemanticChecker``), and statistical specifications (``EasyStatChecker``). The ``__call__`` method of each checker is designed to be invoked as a function to conduct the respective checks on the learnware and return the outcomes. It defines three types of learnwares: ``INVALID_LEARNWARE`` denotes the learnware does not pass the check, ``NONUSABLE_LEARNWARE`` denotes the learnware pass the check but cannot make prediction, ``USABLE_LEARWARE`` denotes the leanrware pass the check and can make prediction. Currently, we have three ``checker``\ s, which are described below.
+The ``checker``s check a learnware object in different aspects, including environment configuration (``CondaChecker``), semantic specifications (``EasySemanticChecker``), and statistical specifications (``EasyStatChecker``). The ``__call__`` method of each checker is designed to be invoked as a function to conduct the respective checks on the learnware and return the outcomes. It defines three types of learnwares: ``INVALID_LEARNWARE`` denotes the learnware does not pass the check, ``NONUSABLE_LEARNWARE`` denotes the learnware pass the check but cannot make prediction, ``USABLE_LEARNWARE`` denotes the leanrware pass the check and can make prediction. Currently, we have three ``checker``\ s, which are described below.
``CondaChecker``
diff --git a/docs/components/model.rst b/docs/components/model.rst
index ba4e48b..76952ea 100644
--- a/docs/components/model.rst
+++ b/docs/components/model.rst
@@ -3,15 +3,55 @@
Model
================================
+A learnware is a well-performed trained model with a specification, where the model is an indispensable component of the learnware.
+
+
+In this section, we will first introduce the ``BaseModel``, which defines the standard format for models in the learnware package.
+Following that, we will introduce the ``ModelContainer``, which implements model deployment in conda virtual environments and Docker containers.
+
BaseModel
======================================
+The ``BaseModel`` class is a fundamental component of the learnware package and serves as a standard interface for defining machine learning models.
+This class is created to make it easier for users to submit learnwares to the market.
+It helps ensure that submitted models follow a clear set of rules and requirements.
+
+The model in a learnware should inherit the ``BaseModel`` class.
+Here's a more detailed explanation of key components:
+
+- ``input_shape``: Specify the shape of the input features your model expects.
+- ``output_shape``: Define the shape of the output predictions generated by your model.
+- ``predict``: Implement the predict method to make predictions using your model.
+- ``fit`` (optional): Use the fit method for training a model with input data and labels.
+- ``finetune`` (optional): Utilize the finetune method for further adjusting pre-existing models sourced from the market.
+
+By adhering to these standards, the compatibility and quality of submitted learnwares in the market are ensured.
+
ModelContainer
======================================
-CondaContainer
+The ``ModelContainer`` class is an essential component of the learnware package, designed to facilitate the management, deployment, and execution of machine learning models within a containerized environment.
+It inherits from the ``BaseModel`` class and extends its functionality to encapsulate model deployment and execution.
+
+ModelCondaContainer
+---------------------
+
+The ``ModelCondaContainer`` class is an extension of the ``ModelContainer`` class within the learnware package.
+Its primary purpose is to enable the management, deployment, and execution of machine learning models in a containerized environment, with a specific focus on using Conda virtual environments.
+This class inherits functionality from ``ModelContainer`` while providing additional capabilities related to Conda-based model execution.
+
+Specifically, the ``ModelCondaContainer`` supports the automatic creation of new Conda virtual environments based on the ``requirements.txt`` file (for pip installation) or ``environment.yaml`` file (for Conda installation) included within the learnware itself.
+It also installs the environment dependencies of the learnware, enabling it to run.
+
+ModelDockerContainer
---------------------
+The ``ModelDockerContainer`` class is a specialized extension of the ``ModelContainer`` class within the learnware package.
+It is designed to manage, deploy, and execute machine learning models within a containerized environment, specifically using Docker containers.
+This class inherits functionality from ``ModelContainer`` and enhances it with features related to Docker-based model execution.
+
+Compared to ``ModelCondaContainer``, ``ModelDockerContainer`` confines the model's execution within a Docker container.
+It installs the learnware's virtual environment inside the Docker container, isolating the learnware's execution from the host machine, thus enhancing the security of the learnware.
-DockerContainer
----------------------
\ No newline at end of file
+Similar to the ``ModelCondaContainer`` class, the ``ModelDockerContainer`` class also supports both types of environment dependency files for learnware: ``requirements.txt`` for pip-based installation and ``environment.yaml`` for conda-based installation.
+It automates the creation of Docker containers and the installation of learnware's environment dependencies within the container, enabling the learnware to run.
\ No newline at end of file
diff --git a/docs/components/spec.rst b/docs/components/spec.rst
index ea801ac..5c4a3d5 100644
--- a/docs/components/spec.rst
+++ b/docs/components/spec.rst
@@ -86,17 +86,18 @@ By randomly sampling a subset of the dataset, we can construct Image Specificati
.. code-block:: python
- import torchvision
- from torch.utils.data import DataLoader
- from learnware.specification import generate_rkme_image_spec
+ import torchvision
+ from torch.utils.data import DataLoader
+ from learnware.specification import generate_rkme_image_spec
- cifar10 = torchvision.datasets.CIFAR10(
- root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
- X, _ = next(iter(DataLoader(cifar10, batch_size=len(cifar10))))
+ cifar10 = torchvision.datasets.CIFAR10(
+ root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor()
+ )
+ X, _ = next(iter(DataLoader(cifar10, batch_size=len(cifar10))))
- spec = generate_rkme_image_spec(X, sample_size=5000)
- spec.save("cifar10.json")
+ spec = generate_rkme_image_spec(X, sample_size=5000)
+ spec.save("cifar10.json")
Privacy Protection
^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -113,6 +114,7 @@ The RBF not only exposes the real data (plotted in the corresponding position in
Text Specification
--------------------------
+Different from tabular data, each text input is a string of different length, so we should first transform them to equal-length arrays. Sentence embedding is used here to complete this transformation. We choose the model ``paraphrase-multilingual-MiniLM-L12-v2``, a lightweight multilingual embedding model. Then, we calculate the RKME specification on the embedding, just like we do with tabular data. Besides, we use the package ``langdetect`` to detect and store the language of the text inputs for further search. We hope to search for the learnware which supports the language of the user task.
System Specification
======================================
diff --git a/docs/index.rst b/docs/index.rst
index 33441a3..d1962eb 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -51,7 +51,7 @@ Document Structure
:caption: ADVANCED TOPICS:
Anchor Learnware
- Specification Evolvement
+ Evolvable Specification
.. toctree::
:maxdepth: 3
diff --git a/docs/start/exp.rst b/docs/start/exp.rst
index 34db05e..3931e1f 100644
--- a/docs/start/exp.rst
+++ b/docs/start/exp.rst
@@ -49,18 +49,18 @@ Homo Experiments
In homogeneous experiments, the 55 stores in the Corporacion dataset are considered as 55 users. Each store uses the same feature engineering method
and their own test set as user data. These users then search for and reuse homogeneous learnwares in the market which exactly match the feature spaces of their tasks.
-The Mean Squared Error (MSE) of search and reuse is presented in the table below:
+The Mean Squared Error (MSE) of search and reuse across all users is presented in the table below:
+-----------------------------------+---------------------+
-| Mean in Market (Single) | 0.323 ± 0.041 |
+| Mean in Market (Single) | 0.331 |
+-----------------------------------+---------------------+
-| Best in Market (Single) | 0.302 ± 0.036 |
+| Best in Market (Single) | 0.151 |
+-----------------------------------+---------------------+
-| Top-1 Reuse (Single) | 0.307 ± 0.037 |
+| Top-1 Reuse (Single) | 0.280 |
+-----------------------------------+---------------------+
-| Job Selector Reuse (Multiple) | 0.308 ± 0.038 |
+| Job Selector Reuse (Multiple) | 0.274 |
+-----------------------------------+---------------------+
-| Average Ensemble Reuse (Multiple) | 0.304 ± 0.036 |
+| Average Ensemble Reuse (Multiple) | 0.267 |
+-----------------------------------+---------------------+
When users have both test data and limited training data derived from their original data, reusing single or multiple searched learnwares from the market can often yield
@@ -91,15 +91,15 @@ we tested various heterogeneous learnware reuse methods (without using user's la
The average MSE performance across 41 users are as follows:
+-----------------------------------+---------------------+
-| Mean in Market (Single) | 1.459 ± 1.066 |
+| Mean in Market (Single) | 1.459 |
+-----------------------------------+---------------------+
-| Best in Market (Single) | 1.226 ± 1.032 |
+| Best in Market (Single) | 1.226 |
+-----------------------------------+---------------------+
-| Top-1 Reuse (Single) | 1.407 ± 1.061 |
+| Top-1 Reuse (Single) | 1.407 |
+-----------------------------------+---------------------+
-| Average Ensemble Reuse (Multiple) | 1.312 ± 1.099 |
+| Average Ensemble Reuse (Multiple) | 1.312 |
+-----------------------------------+---------------------+
-| User model with 50 labeled data | 1.267 ± 1.055 |
+| User model with 50 labeled data | 1.267 |
+-----------------------------------+---------------------+
From the results, it is noticeable that the learnware market still perform quite well even when users lack labeled data,
@@ -122,6 +122,33 @@ The average results across 10 users are depicted in the figure below:
We can observe that heterogeneous learnwares are beneficial when there's a limited amount of the user's labeled training data available,
aiding in better alignment with the user's specific task. This underscores the potential of learnwares to be applied to tasks beyond their original purpose.
+Image Experiment
+====================
+
+For the CIFAR-10 dataset, we sampled the training set unevenly by category and constructed unbalanced training datasets for the 50 learnwares that contained only some of the categories. This makes it unlikely that there exists any learnware in the learnware market that can accurately handle all categories of data; only the learnware whose training data is closest to the data distribution of the target task is likely to perform well on the target task. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with a non-zero probability of sampling on only 4 categories, and the sampling ratio is 0.4: 0.4: 0.1: 0.1. Ultimately, the training set for each learnware contains 12,000 samples covering the data of 4 categories in CIFAR-10.
+
+We constructed 50 target tasks using data from the test set of CIFAR-10. Similar to constructing the training set for the learnwares, in order to allow for some variation between tasks, we sampled the test set unevenly. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with non-zero sampling probability on 6 categories, and the sampling ratio is 0.3: 0.3: 0.1: 0.1: 0.1: 0.1. Ultimately, each target task contains 3000 samples covering the data of 6 categories in CIFAR-10.
+
+With this experimental setup, we evaluated the performance of RKME Image by calculating the mean accuracy across all users.
+
++-----------------------------------+---------------------+
+| Mean in Market (Single) | 0.346 |
++-----------------------------------+---------------------+
+| Best in Market (Single) | 0.688 |
++-----------------------------------+---------------------+
+| Top-1 Reuse (Single) | 0.534 |
++-----------------------------------+---------------------+
+| Job Selector Reuse (Multiple) | 0.534 |
++-----------------------------------+---------------------+
+| Average Ensemble Reuse (Multiple) | 0.676 |
++-----------------------------------+---------------------+
+
+In some specific settings, the user will have a small number of labeled samples. In such settings, learning the weight of selected learnwares on a limited number of labeled samples can result in a better performance than training directly on a limited number of labeled samples.
+
+.. image:: ../_static/img/image_labeled.svg
+ :align: center
+
+
Text Experiment
====================
@@ -147,69 +174,42 @@ Results
* ``unlabeled_text_example``:
-The accuracy of search and reuse is presented in the table below:
+The table below presents the mean accuracy of search and reuse across all users:
+-----------------------------------+---------------------+
-| Mean in Market (Single) | 0.507 ± 0.030 |
+| Mean in Market (Single) | 0.507 |
+-----------------------------------+---------------------+
-| Best in Market (Single) | 0.859 ± 0.051 |
+| Best in Market (Single) | 0.859 |
+-----------------------------------+---------------------+
-| Top-1 Reuse (Single) | 0.846 ± 0.054 |
+| Top-1 Reuse (Single) | 0.846 |
+-----------------------------------+---------------------+
-| Job Selector Reuse (Multiple) | 0.845 ± 0.053 |
+| Job Selector Reuse (Multiple) | 0.845 |
+-----------------------------------+---------------------+
-| Average Ensemble Reuse (Multiple) | 0.862 ± 0.051 |
+| Average Ensemble Reuse (Multiple) | 0.862 |
+-----------------------------------+---------------------+
* ``labeled_text_example``:
We present the change curves in classification error rates for both the user's self-trained model and the multiple learnware reuse(EnsemblePrune), showcasing their performance on the user's test data as the user's training data increases. The average results across 10 users are depicted below:
-.. image:: ../_static/img/text_labeled_curves.png
+.. image:: ../_static/img/text_labeled.svg
:align: center
- :alt: Text Limited Labeled Data
+ :alt: Results on Text Experimental Scenario
From the figure above, it is evident that when the user's own training data is limited, the performance of multiple learnware reuse surpasses that of the user's own model. As the user's training data grows, it is expected that the user's model will eventually outperform the learnware reuse. This underscores the value of reusing learnware to significantly conserve training data and achieve superior performance when user training data is limited.
-
-Image Experiment
-====================
-
-For the CIFAR-10 dataset, we sampled the training set unevenly by category and constructed unbalanced training datasets for the 50 learnwares that contained only some of the categories. This makes it unlikely that there exists any learnware in the learnware market that can accurately handle all categories of data; only the learnware whose training data is closest to the data distribution of the target task is likely to perform well on the target task. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with a non-zero probability of sampling on only 4 categories, and the sampling ratio is 0.4: 0.4: 0.1: 0.1. Ultimately, the training set for each learnware contains 12,000 samples covering the data of 4 categories in CIFAR-10.
-
-We constructed 50 target tasks using data from the test set of CIFAR-10. Similar to constructing the training set for the learnwares, in order to allow for some variation between tasks, we sampled the test set unevenly. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with non-zero sampling probability on 6 categories, and the sampling ratio is 0.3: 0.3: 0.1: 0.1: 0.1: 0.1. Ultimately, each target task contains 3000 samples covering the data of 6 categories in CIFAR-10.
-
-With this experimental setup, we evaluated the performance of RKME Image using 1 - Accuracy as the loss.
-
-+-----------------------------------+---------------------+
-| Mean in Market (Single) | 0.655 ± 0.021 |
-+-----------------------------------+---------------------+
-| Best in Market (Single) | 0.304 ± 0.046 |
-+-----------------------------------+---------------------+
-| Top-1 Reuse (Single) | 0.406 ± 0.128 |
-+-----------------------------------+---------------------+
-| Job Selector Reuse (Multiple) | 0.406 ± 0.128 |
-+-----------------------------------+---------------------+
-| Average Ensemble Reuse (Multiple) | 0.310 ± 0.112 |
-+-----------------------------------+---------------------+
-
-In some specific settings, the user will have a small number of labelled samples. In such settings, learning the weight of selected learnwares on a limited number of labelled samples can result in a better performance than training directly on a limited number of labelled samples.
-
-.. image:: ../_static/img/image_labeled.png
- :align: center
-
Get Start Examples
=========================
-Examples for `PFS, M5` and `CIFAR10` are available at [xxx]. You can run { main.py } directly to reproduce related experiments.
-The test code is mainly composed of three parts, namely data preparation (optional), specification generation and market construction, and search test.
-You can load data prepared by as and skip the data preparation step.
+We utilize the `fire` module to construct our experiments, including table, image and text scenario.
+
+Examples for `Image` are available at [examples/dataset_image_workflow].
+You can execute the experiment with the following commands:
+* `python workflow.py image_example`: Run both the unlabeled_image_example and labeled_image_example experiments. The results will be printed in the terminal, and the curves will be automatically saved in the `figs` directory.
-Examples for the `20-newsgroup` dataset are available at [examples/dataset_text_workflow].
-We utilize the `fire` module to construct our experiments. You can execute the experiment with the following commands:
+Examples for `Text` are available at [examples/dataset_text_workflow].
+You can execute the experiment with the following commands:
-* `python main.py prepare_market`: Prepares the market.
-* `python main.py unlabeled_text_example`: Executes the unlabeled_text_example experiment; the results will be printed in the terminal.
-* `python main.py labeled_text_example`: Executes the labeled_text_example experiment; result curves will be automatically saved in the `figs` directory.
-* Additionally, you can use `python main.py unlabeled_text_example True` to combine steps 1 and 2. The same approach applies to running labeled_text_example directly.
\ No newline at end of file
+* `python workflow.py unlabeled_text_example`: Run the unlabeled_text_example experiment. The results will be printed in the terminal.
+* `python workflow.py labeled_text_example`: Run the labeled_text_example experiment. The result curves will be automatically saved in the `figs` directory.
\ No newline at end of file
diff --git a/docs/start/quick.rst b/docs/start/quick.rst
index c33bf6b..d6e4715 100644
--- a/docs/start/quick.rst
+++ b/docs/start/quick.rst
@@ -92,7 +92,7 @@ Learnware Market Workflow
Users can start a ``Learnware Market`` workflow according to the following steps:
-Initialize a Learware Market
+Initialize a Learnware Market
-------------------------------
The ``EasyMarket`` class provides the core functions of a ``Learnware Market``.
diff --git a/docs/workflows/client.rst b/docs/workflows/client.rst
index b76fb8e..69114ad 100644
--- a/docs/workflows/client.rst
+++ b/docs/workflows/client.rst
@@ -19,7 +19,7 @@ How to Use Client
============================
-Initialize a Learware Client
+Initialize a Learnware Client
-------------------------------
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/dataset_image_workflow/README.md b/examples/dataset_image_workflow/README.md
new file mode 100644
index 0000000..35d1f16
--- /dev/null
+++ b/examples/dataset_image_workflow/README.md
@@ -0,0 +1,33 @@
+# Image Dataset Workflow Example
+
+## Introduction
+
+For the CIFAR-10 dataset, we sampled the training set unevenly by category and constructed unbalanced training datasets for the 50 learnwares that contained only some of the categories. This makes it unlikely that there exists any learnware in the learnware market that can accurately handle all categories of data; only the learnware whose training data is closest to the data distribution of the target task is likely to perform well on the target task. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with a non-zero probability of sampling on only 4 categories, and the sampling ratio is 0.4: 0.4: 0.1: 0.1. Ultimately, the training set for each learnware contains 12,000 samples covering the data of 4 categories in CIFAR-10.
+
+We constructed 50 target tasks using data from the test set of CIFAR-10. Similar to constructing the training set for the learnwares, in order to allow for some variation between tasks, we sampled the test set unevenly. Specifically, the probability of each category being sampled obeys a random multinomial distribution, with non-zero sampling probability on 6 categories, and the sampling ratio is 0.3: 0.3: 0.1: 0.1: 0.1: 0.1. Ultimately, each target task contains 3000 samples covering the data of 6 categories in CIFAR-10.
+
+## Run the code
+
+Run the following command to start the ``image_example``.
+
+```bash
+python workflow.py image_example
+```
+
+## Results
+
+With the experimental setup above, we evaluated the performance of RKME Image by calculating the mean accuracy across all users.
+
+| Metric | Value |
+|--------------------------------------|---------------------|
+| Mean in Market (Single) | 0.346 |
+| Best in Market (Single) | 0.688 |
+| Top-1 Reuse (Single) | 0.534 |
+| Job Selector Reuse (Multiple) | 0.534 |
+| Average Ensemble Reuse (Multiple) | 0.676 |
+
+In some specific settings, the user will have a small number of labeled samples. In such settings, learning the weight of selected learnwares on a limited number of labeled samples can result in a better performance than training directly on a limited number of labeled samples.
+
+
From the figure above, it is evident that when the user's own training data is limited, the performance of multiple learnware reuse surpasses that of the user's own model. As the user's training data grows, it is expected that the user's model will eventually outperform the learnware reuse. This underscores the value of reusing learnware to significantly conserve training data and achieve superior performance when user training data is limited.
diff --git a/examples/dataset_text_workflow/workflow.py b/examples/dataset_text_workflow/workflow.py
index dfbacda..21ad9b0 100644
--- a/examples/dataset_text_workflow/workflow.py
+++ b/examples/dataset_text_workflow/workflow.py
@@ -49,25 +49,25 @@ class TextDatasetWorkflow:
]
labels = ["User Model", "Multiple Learnware Reuse (EnsemblePrune)"]
- user_mat, pruning_mat = all_user_curves_data
- user_mat, pruning_mat = np.array(user_mat), np.array(pruning_mat)
- for mat, style, label in zip([user_mat, pruning_mat], styles, labels):
- mean_curve, std_curve = 1 - np.mean(mat, axis=0), np.std(mat, axis=0)
+ user_array, pruning_array = all_user_curves_data
+ for array, style, label in zip([user_array, pruning_array], styles, labels):
+ mean_curve = np.array([item[0] for item in array])
+ std_curve = np.array([item[1] for item in array])
plt.plot(mean_curve, **style, label=label)
plt.fill_between(
range(len(mean_curve)),
- mean_curve - 0.5 * std_curve,
- mean_curve + 0.5 * std_curve,
+ mean_curve - std_curve,
+ mean_curve + std_curve,
color=style["color"],
alpha=0.2,
)
- plt.xlabel("Labeled Data Size")
- plt.ylabel("1 - Accuracy")
- plt.title(f"Text Limited Labeled Data")
- plt.legend()
+ plt.xlabel("Amout of Labeled User Data", fontsize=14)
+ plt.ylabel("1 - Accuracy", fontsize=14)
+ plt.title(f"Results on Text Experimental Scenario", fontsize=16)
+ plt.legend(fontsize=14)
plt.tight_layout()
- plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.png"), bbox_inches="tight", dpi=700)
+ plt.savefig(os.path.join(self.fig_path, "text_labeled_curves.svg"), bbox_inches="tight", dpi=700)
def _prepare_market(self, rebuild=False):
client = LearnwareClient()
@@ -189,9 +189,9 @@ class TextDatasetWorkflow:
self.root_path = os.path.dirname(os.path.abspath(__file__))
self.fig_path = os.path.join(self.root_path, "figs")
self.curve_path = os.path.join(self.root_path, "curves")
+ self._prepare_market(rebuild)
if train_flag:
- self._prepare_market(rebuild)
os.makedirs(self.fig_path, exist_ok=True)
os.makedirs(self.curve_path, exist_ok=True)
@@ -230,7 +230,6 @@ class TextDatasetWorkflow:
mixture_learnware_list = multiple_result[0].learnwares
else:
mixture_learnware_list = [single_result[0].learnware]
- print(len(train_x))
for n_label, repeated in zip(self.n_labeled_list, self.repeated_list):
user_model_score_list, reuse_pruning_score_list = [], []
@@ -257,7 +256,9 @@ class TextDatasetWorkflow:
single_score_mat.append([best_acc] * repeated)
user_model_score_mat.append(user_model_score_list)
pruning_score_mat.append(reuse_pruning_score_list)
- print(n_label, np.mean(user_model_score_mat[-1]), np.mean(pruning_score_mat[-1]))
+ print(
+ f"user_label_num: {n_label}, user_acc: {np.mean(user_model_score_mat[-1])}, pruning_acc: {np.mean(pruning_score_mat[-1])}"
+ )
logger.info(f"Saving Curves for User_{i}")
user_curves_data = (single_score_mat, user_model_score_mat, pruning_score_mat)
@@ -265,19 +266,25 @@ class TextDatasetWorkflow:
pickle.dump(user_curves_data, f)
pruning_curves_data, user_model_curves_data = [], []
- for i in range(self.text_benchmark.user_num):
- with open(os.path.join(self.curve_path, f"curve{str(i)}.pkl"), "rb") as f:
+ total_user_model_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
+ total_pruning_score_mat = [np.zeros(self.repeated_list[i]) for i in range(len(self.n_labeled_list))]
+ for user_idx in range(self.text_benchmark.user_num):
+ with open(os.path.join(self.curve_path, f"curve{str(user_idx)}.pkl"), "rb") as f:
user_curves_data = pickle.load(f)
(single_score_mat, user_model_score_mat, pruning_score_mat) = user_curves_data
- for i in range(len(single_score_mat)):
- user_model_score_mat[i] = np.mean(user_model_score_mat[i])
- pruning_score_mat[i] = np.mean(pruning_score_mat[i])
- if len(user_model_score_mat) < 6:
- for i in range(6 - len(user_model_score_mat)):
- user_model_score_mat.append(user_model_score_mat[-1])
- pruning_score_mat.append(pruning_score_mat[-1])
- user_model_curves_data.append(user_model_score_mat[:6])
- pruning_curves_data.append(pruning_score_mat[:6])
+
+ for i in range(len(self.n_labeled_list)):
+ total_user_model_score_mat[i] += 1 - np.array(user_model_score_mat[i])
+ total_pruning_score_mat[i] += 1 - np.array(pruning_score_mat[i])
+
+ for i in range(len(self.n_labeled_list)):
+ total_user_model_score_mat[i] /= self.text_benchmark.user_num
+ total_pruning_score_mat[i] /= self.text_benchmark.user_num
+ user_model_curves_data.append(
+ (np.mean(total_user_model_score_mat[i]), np.std(total_user_model_score_mat[i]))
+ )
+ pruning_curves_data.append((np.mean(total_pruning_score_mat[i]), np.std(total_pruning_score_mat[i])))
+
self._plot_labeled_peformance_curves([user_model_curves_data, pruning_curves_data])
diff --git a/learnware/market/base.py b/learnware/market/base.py
index 76e1f71..a7d3123 100644
--- a/learnware/market/base.py
+++ b/learnware/market/base.py
@@ -410,7 +410,7 @@ class BaseOrganizer:
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of targer learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
@@ -428,7 +428,7 @@ class BaseOrganizer:
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of targer learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
@@ -503,7 +503,7 @@ class BaseSearcher:
class BaseChecker:
INVALID_LEARNWARE = -1
NONUSABLE_LEARNWARE = 0
- USABLE_LEARWARE = 1
+ USABLE_LEARNWARE = 1
def reset(self, **kwargs):
pass
diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py
index 95c0f1a..8c1b6b1 100644
--- a/learnware/market/easy/checker.py
+++ b/learnware/market/easy/checker.py
@@ -217,4 +217,4 @@ class EasyStatChecker(BaseChecker):
logger.warning(message)
return self.INVALID_LEARNWARE, message
- return self.USABLE_LEARWARE, "EasyStatChecker Success"
+ return self.USABLE_LEARNWARE, "EasyStatChecker Success"
diff --git a/learnware/market/easy/organizer.py b/learnware/market/easy/organizer.py
index dfcfb42..6cba25f 100644
--- a/learnware/market/easy/organizer.py
+++ b/learnware/market/easy/organizer.py
@@ -233,7 +233,7 @@ class EasyOrganizer(BaseOrganizer):
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of target learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
@@ -265,7 +265,7 @@ class EasyOrganizer(BaseOrganizer):
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of targer learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
@@ -297,7 +297,7 @@ class EasyOrganizer(BaseOrganizer):
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of targer learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
@@ -340,11 +340,11 @@ class EasyOrganizer(BaseOrganizer):
"""
if check_status is None:
filtered_ids = list(self.use_flags.keys())
- elif check_status in [BaseChecker.NONUSABLE_LEARNWARE, BaseChecker.USABLE_LEARWARE]:
+ elif check_status in [BaseChecker.NONUSABLE_LEARNWARE, BaseChecker.USABLE_LEARNWARE]:
filtered_ids = [key for key, value in self.use_flags.items() if value == check_status]
else:
logger.warning(
- f"check_status must be in [{BaseChecker.NONUSABLE_LEARNWARE}, {BaseChecker.USABLE_LEARWARE}]!"
+ f"check_status must be in [{BaseChecker.NONUSABLE_LEARNWARE}, {BaseChecker.USABLE_LEARNWARE}]!"
)
return None
diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py
index 0d15b3e..1046d2c 100644
--- a/learnware/market/heterogeneous/organizer/__init__.py
+++ b/learnware/market/heterogeneous/organizer/__init__.py
@@ -39,7 +39,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
logger.info(f"Reload market mapping from checkpoint {self.market_mapping_path}")
self.market_mapping = HeteroMap.load(checkpoint=self.market_mapping_path)
if not rebuild:
- usable_ids = self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)
+ usable_ids = self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARNWARE)
hetero_ids = self._get_hetero_learnware_ids(usable_ids)
for hetero_id in hetero_ids:
self._reload_learnware_hetero_spec(hetero_id)
@@ -95,14 +95,14 @@ class HeteroMapTableOrganizer(EasyOrganizer):
zip_path, semantic_spec, check_status, learnware_id
)
- if learnwere_status == BaseChecker.USABLE_LEARWARE and len(self._get_hetero_learnware_ids(learnware_id)):
- self._update_learware_hetero_spec(learnware_id)
+ if learnwere_status == BaseChecker.USABLE_LEARNWARE and len(self._get_hetero_learnware_ids(learnware_id)):
+ self._update_learnware_hetero_spec(learnware_id)
if self.auto_update:
self.count_down -= 1
if self.count_down == 0:
training_learnware_ids = self._get_hetero_learnware_ids(
- self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)
+ self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARNWARE)
)
training_learnwares = self.get_learnware_by_ids(training_learnware_ids)
logger.info(f"Verified leanwares for training: {training_learnware_ids}")
@@ -113,7 +113,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
f"Market mapping train completed. Now update HeteroMapTableSpecification for {training_learnware_ids}"
)
self.market_mapping = updated_market_mapping
- self._update_learware_hetero_spec(training_learnware_ids)
+ self._update_learnware_hetero_spec(training_learnware_ids)
self.count_down = self.auto_update_limit
@@ -167,9 +167,9 @@ class HeteroMapTableOrganizer(EasyOrganizer):
"""
old_semantic_spec = self.learnware_list[id].get_specification().get_semantic_spec()
final_status = super(HeteroMapTableOrganizer, self).update_learnware(id, zip_path, semantic_spec, check_status)
- if final_status == BaseChecker.USABLE_LEARWARE and len(self._get_hetero_learnware_ids(id)):
+ if final_status == BaseChecker.USABLE_LEARNWARE and len(self._get_hetero_learnware_ids(id)):
if zip_path is not None or old_semantic_spec.get("Input", {}) != semantic_spec.get("Input", {}):
- self._update_learware_hetero_spec(id)
+ self._update_learnware_hetero_spec(id)
return final_status
def _reload_learnware_hetero_spec(self, learnware_id):
@@ -180,7 +180,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
hetero_spec.load(hetero_spec_path)
self.learnware_list[learnware_id].update_stat_spec(hetero_spec.type, hetero_spec)
else:
- self._update_learware_hetero_spec(learnware_id)
+ self._update_learnware_hetero_spec(learnware_id)
logger.info(f"Reload HeteroMapTableSpecification for hetero spec {learnware_id} succeed!")
except Exception as err:
logger.error(f"Reload HeteroMapTableSpecification for hetero spec {learnware_id} failed! due to {err}.")
@@ -198,14 +198,14 @@ class HeteroMapTableOrganizer(EasyOrganizer):
if len(self._get_hetero_learnware_ids(learnware_id)):
self._reload_learnware_hetero_spec(learnware_id)
- def _update_learware_hetero_spec(self, ids: Union[str, List[str]]):
+ def _update_learnware_hetero_spec(self, ids: Union[str, List[str]]):
"""Update learnware by ids, attempting to generate HeteroMapTableSpecification for them.
Parameters
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of target learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
"""
if isinstance(ids, str):
@@ -233,7 +233,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
----------
ids : Union[str, List[str]]
Give a id or a list of ids
- str: id of target learware
+ str: id of target learnware
List[str]: A list of ids of target learnwares
Returns
diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py
index b2f39fe..8968793 100644
--- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py
+++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py
@@ -287,6 +287,9 @@ class HeteroMap(nn.Module):
# go through transformers, get the first cls embedding
encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim
output_features = encoder_output[:, 0, :]
+
+ del inputs, outputs, encoder_output
+ torch.cuda.empty_cache()
return output_features
@@ -316,6 +319,8 @@ class HeteroMap(nn.Module):
with torch.no_grad():
output_features = self._extract_features(bs_x_test).detach().cpu().numpy()
output_feas_list.append(output_features)
+ del output_features
+ torch.cuda.empty_cache()
all_output_features = np.concatenate(output_feas_list, 0)
return all_output_features
diff --git a/learnware/reuse/ensemble_pruning.py b/learnware/reuse/ensemble_pruning.py
index 49c65b5..cf1ffb7 100644
--- a/learnware/reuse/ensemble_pruning.py
+++ b/learnware/reuse/ensemble_pruning.py
@@ -148,7 +148,9 @@ class EnsemblePruningReuser(BaseReuser):
import geatpy as ea
except ModuleNotFoundError:
raise ModuleNotFoundError(f"EnsemblePruningReuser is not available because 'geatpy' is not installed! Please install it manually (only support python_version<3.11).")
-
+
+ if torch.is_tensor(v_true):
+ v_true = v_true.detach().cpu().numpy()
model_num = v_predict.shape[1]
diff --git a/learnware/reuse/job_selector.py b/learnware/reuse/job_selector.py
index 467e063..277a01e 100644
--- a/learnware/reuse/job_selector.py
+++ b/learnware/reuse/job_selector.py
@@ -59,8 +59,11 @@ class JobSelectorReuser(BaseReuser):
for idx in range(len(self.learnware_list)):
data_idx_list = np.where(select_result == idx)[0]
if len(data_idx_list) > 0:
- # pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list])
- pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list])
+ if isinstance(raw_user_data, list):
+ pred_y = self.learnware_list[idx].predict([raw_user_data[i] for i in data_idx_list])
+ else:
+ pred_y = self.learnware_list[idx].predict(raw_user_data[data_idx_list])
+
if isinstance(pred_y, torch.Tensor):
pred_y = pred_y.detach().cpu().numpy()
# elif isinstance(pred_y, tf.Tensor):
@@ -89,6 +92,9 @@ class JobSelectorReuser(BaseReuser):
user_data : np.ndarray
User's raw data.
"""
+ if torch.is_tensor(user_data):
+ user_data = user_data.detach().cpu().numpy()
+
if len(self.learnware_list) == 1:
# user_data_num = user_data.shape[0]
user_data_num = len(user_data)
@@ -118,9 +124,9 @@ class JobSelectorReuser(BaseReuser):
task_spec = learnware_rkme_spec_list[i]
if self.use_herding:
task_herding_num = max(5, int(self.herding_num * task_mixture_weight[i]))
- herding_X_i = task_spec.herding(task_herding_num).detach().cpu().numpy()
+ herding_X_i = task_spec.herding(task_herding_num)
else:
- herding_X_i = task_spec.z.detach().cpu().numpy()
+ herding_X_i = task_spec.get_z()
task_herding_num = herding_X_i.shape[0]
task_val_num = task_herding_num // 5
@@ -172,7 +178,7 @@ class JobSelectorReuser(BaseReuser):
user_data : np.ndarray
Raw user data.
task_rkme_list : List[RKMETableSpecification]
- The list of learwares' rkmes whose mixture approximates the user's rkme
+ The list of learnwares' rkmes whose mixture approximates the user's rkme
task_rkme_matrix : np.ndarray
Inner product matrix calculated from task_rkme_list.
"""
@@ -223,8 +229,10 @@ class JobSelectorReuser(BaseReuser):
try:
from lightgbm import LGBMClassifier, early_stopping
except ModuleNotFoundError:
- raise ModuleNotFoundError(f"JobSelectorReuser is not available because 'lightgbm' is not installed! Please install it manually.")
-
+ raise ModuleNotFoundError(
+ f"JobSelectorReuser is not available because 'lightgbm' is not installed! Please install it manually."
+ )
+
score_best = -1
learning_rate = [0.01]
max_depth = [66]
diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py
index 3ce9ad5..f89014a 100644
--- a/learnware/specification/regular/image/rkme.py
+++ b/learnware/specification/regular/image/rkme.py
@@ -366,7 +366,7 @@ class RKMEImageSpecification(RegularStatSpecification):
indices = torch.multinomial(self.beta, T, replacement=True)
mock = self.z[indices] + torch.randn_like(self.z[indices]) * 0.01
- return mock.numpy()
+ return mock.detach().cpu().numpy()
def _sampling_candidates(self, N: int) -> np.ndarray:
raise NotImplementedError()
diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py
index abecf6f..ade034d 100644
--- a/learnware/specification/regular/table/rkme.py
+++ b/learnware/specification/regular/table/rkme.py
@@ -411,7 +411,7 @@ class RKMETableSpecification(RegularStatSpecification):
S_shape = tuple([S.shape[0]] + list(Z_shape)[1:])
S = S.reshape(S_shape)
- return S
+ return S.detach().cpu().numpy()
def save(self, filepath: str):
"""Save the computed RKME specification to a specified path in JSON format.
@@ -457,7 +457,9 @@ class RKMETableSpecification(RegularStatSpecification):
for d in self.get_states():
if d in rkme_load.keys():
if d == "type" and rkme_load[d] != self.type:
- raise TypeError(f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!")
+ raise TypeError(
+ f"The type of loaded RKME ({rkme_load[d]}) is different from the expected type ({self.type})!"
+ )
setattr(self, d, rkme_load[d])