Table benchmark updatetags/v0.3.2
| @@ -44,4 +44,5 @@ tmp/ | |||
| learnware_pool/ | |||
| PFS/ | |||
| data/ | |||
| examples/results/ | |||
| examples/results/ | |||
| examples/*/results/ | |||
| @@ -1,6 +1,5 @@ | |||
| from learnware.tests.benchmarks import BenchmarkConfig | |||
| image_benchmark_config = BenchmarkConfig( | |||
| name="CIFAR-10", | |||
| user_num=100, | |||
| @@ -1,6 +1,6 @@ | |||
| import torch | |||
| import numpy as np | |||
| from torch import optim, nn | |||
| import torch | |||
| from torch import nn, optim | |||
| from torch.utils.data import DataLoader, Dataset | |||
| from learnware.utils import choose_device | |||
| @@ -1,24 +1,25 @@ | |||
| import os | |||
| import fire | |||
| import time | |||
| import torch | |||
| import pickle | |||
| import random | |||
| import tempfile | |||
| import numpy as np | |||
| import time | |||
| import fire | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import torch | |||
| from config import image_benchmark_config | |||
| from model import ConvModel | |||
| from torch.utils.data import TensorDataset | |||
| from utils import evaluate, train_model | |||
| from learnware.utils import choose_device | |||
| from learnware.client import LearnwareClient | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import BaseUserInfo, instantiate_learnware_market | |||
| from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.tests.benchmarks import LearnwareBenchmark | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser | |||
| from model import ConvModel | |||
| from utils import train_model, evaluate | |||
| from config import image_benchmark_config | |||
| from learnware.utils import choose_device | |||
| logger = get_module_logger("image_workflow", level="INFO") | |||
| @@ -0,0 +1,104 @@ | |||
| # Tabular Dataset Workflow Example | |||
| ## Introduction | |||
| On various tabular datasets, we initially evaluate the performance of identifying and reusing learnwares from the learnware market that share the same feature space as the user's tasks. Additionally, since tabular tasks often come from heterogeneous feature spaces, we also assess the identification and reuse of learnwares from different feature spaces. | |||
| ### Settings | |||
| Our study utilizes three public datasets in the field of sales forecasting: [Predict Future Sales (PFS)](https://www.kaggle.com/c/competitive-data-science-predict-future-sales/data), [M5 Forecasting (M5)](https://www.kaggle.com/competitions/m5-forecasting-accuracy/data), and [Corporacion](https://www.kaggle.com/competitions/favorita-grocery-sales-forecasting/data). To enrich the data, we apply diverse feature engineering methods to these datasets. Then we divide each dataset by store and further split the data for each store into training and test sets. A LightGBM model is trained on each Corporacion and PFS training set, while the test sets and M5 datasets are reserved to construct user tasks. This results in an experimental market consisting of 265 learnwares, encompassing five types of feature spaces and two types of label spaces. All these learnwares have been uploaded to the [Beimingwu system](https://bmwu.cloud/). | |||
| ### Baseline algorithms | |||
| The most basic way to reuse a learnware is the Top-1 reuser, which directly uses the single learnware chosen by the RKME specification. Besides, we implement two data-free reusers and two data-dependent reusers that work on single or multiple helpful learnwares identified from the market. When users have no labeled data, the JobSelector reuser selects different learnwares for different samples by training a job selector classifier; the AverageEnsemble reuser uses an ensemble method to make predictions. In cases where users possess both test data and limited labeled training data, the EnsemblePruning reuser selectively ensembles a subset of learnwares to choose the ones that are most suitable for the user’s task; the FeatureAugment reuser regards each received learnware as a feature augmentor, taking its output as a new feature and then building a simple model on the augmented feature set. JobSelector and FeatureAugment are only effective for tabular data, while the others are also useful for text and image data. | |||
| ## Homogeneous Cases | |||
| In the homogeneous cases, the 53 stores within the PFS dataset function as 53 individual users. Each store utilizes its own test data as user data and applies the same feature engineering approach used in the learnware market. These users can subsequently search for homogeneous learnwares within the market that possess the same feature spaces as their tasks. | |||
| We conduct a comparison among different baseline algorithms when the users have no labeled data or limited amounts of labeled data. The average losses over all users are illustrated in the table below. It shows that unlabeled methods are much better than randomly choosing and deploying one learnware from the market. | |||
| <div align=center> | |||
| | Setting | MSE | | |||
| |-----------------------------------|--------| | |||
| | Mean in Market (Single) | 0.897 | | |||
| | Best in Market (Single) | 0.756 | | |||
| | Top-1 Reuse (Single) | 0.830 | | |||
| | Job Selector Reuse (Multiple) | 0.848 | | |||
| | Average Ensemble Reuse (Multiple) | 0.816 | | |||
| </div> | |||
| The figure below showcases the results for different amounts of labeled data provided by the user; for each user, we conducted multiple experiments repeatedly and calculated the mean and standard deviation of the losses; the average losses over all users are illustrated in the figure. It illustrates that when users have limited training data, identifying and reusing single or multiple learnwares yields superior performance compared to the user's self-trained models. | |||
| <div align=center> | |||
| <img src="../../docs/_static/img/Homo_labeled_curves.svg" width="500" height="auto" style="max-width: 100%;"/> | |||
| </div> | |||
| ## Heterogeneous Cases | |||
| Based on the similarity of tasks between the market's learnwares and the users, the heterogeneous cases can be further categorized into different feature engineering and different task scenarios. | |||
| ### Different Feature Engineering Scenarios | |||
| We consider the 41 stores within the PFS dataset as users, generating their user data using a unique feature engineering approach that differs from the methods employed by the learnwares in the market. As a result, while some learnwares in the market are also designed for the PFS dataset, the feature spaces do not align exactly. | |||
| In this experimental setup, we examine various data-free reusers. The results in the following table indicate that even when users lack labeled data, the market exhibits strong performance, particularly with the AverageEnsemble method that reuses multiple learnwares. | |||
| <div align=center> | |||
| | Setting | MSE | | |||
| |-----------------------------------|--------| | |||
| | Mean in Market (Single) | 1.149 | | |||
| | Best in Market (Single) | 1.038 | | |||
| | Top-1 Reuse (Single) | 1.105 | | |||
| | Average Ensemble Reuse (Multiple) | 1.081 | | |||
| </div> | |||
| ### Different Task Scenarios | |||
| We employ three distinct feature engineering methods on all the ten stores from the M5 dataset, resulting in a total of 30 users. Although the overall task of sales forecasting aligns with the tasks addressed by the learnwares in the market, there are no learnwares specifically designed to satisfy the M5 sales forecasting requirements. | |||
| In the following figure, we present the loss curves for the user's self-trained model and several learnware reuse methods. It is evident that heterogeneous learnwares prove beneficial with a limited amount of the user's labeled data, facilitating better alignment with the user's specific task. | |||
| <div align=center> | |||
| <img src="../../docs/_static/img/Hetero_labeled_curves.svg" width="500" height="auto" style="max-width: 100%;"/> | |||
| </div> | |||
| ## Reproduction | |||
| ### Installation | |||
| To reproduce the above experiment, you need to install the necessary dependencies on top of the environment of the `learnware` package. The specific commands are as follows: | |||
| ```bash | |||
| python -m pip install -r requirements.txt | |||
| ``` | |||
| ### Run the code | |||
| Run the following command to get the table results in `Homogeneous Cases`: | |||
| ```bash | |||
| python workflow.py unlabeled_homo_table_example | |||
| ``` | |||
| Run the following command to get the figure results in `Homogeneous Cases`: | |||
| ```bash | |||
| python workflow.py labeled_homo_table_example | |||
| ``` | |||
| Run the following command to get the table results in `Heterogeneous Cases`: | |||
| ```bash | |||
| python workflow.py cross_feat_eng_hetero_table_example | |||
| ``` | |||
| Run the following command to get the figure results in `Heterogeneous Cases`: | |||
| ```bash | |||
| python workflow.py cross_task_hetero_table_example | |||
| ``` | |||
| @@ -0,0 +1,119 @@ | |||
| import os | |||
| import random | |||
| import tempfile | |||
| import time | |||
| import traceback | |||
| import numpy as np | |||
| import requests | |||
| from config import market_mapping_params | |||
| from methods import loss_func_rmse, test_methods | |||
| from utils import set_seed | |||
| from learnware.client import LearnwareClient | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import instantiate_learnware_market | |||
| from learnware.reuse.utils import fill_data_with_mean | |||
| from learnware.tests.benchmarks import LearnwareBenchmark | |||
| logger = get_module_logger("base_table", level="INFO") | |||
class TableWorkflow:
    """Shared scaffolding for the tabular benchmark workflows.

    Constructing an instance creates the ``results`` and ``results/curves``
    directories next to this file and then prepares the learnware market for
    the given benchmark (see ``_prepare_market``).
    """

    def __init__(self, benchmark_config, name="easy", rebuild=False, retrain=False):
        """Create the output directories and prepare the market.

        Args:
            benchmark_config: Passed to ``LearnwareBenchmark().get_benchmark``.
            name: Forwarded as ``name`` to ``instantiate_learnware_market``.
            rebuild: Forwarded to the market; when ``True`` the benchmark
                learnwares are (re-)added even if the market is not empty.
            retrain: When true, extra organizer options are passed to the
                market and ``set_seed(0)`` is called before learnwares are added.
        """
        self.root_path = os.path.abspath(os.path.join(__file__, ".."))
        self.result_path = os.path.join(self.root_path, "results")
        self.curves_result_path = os.path.join(self.result_path, "curves")
        os.makedirs(self.result_path, exist_ok=True)
        os.makedirs(self.curves_result_path, exist_ok=True)
        self._prepare_market(benchmark_config, name, rebuild, retrain)

    @staticmethod
    def _limited_data(method, test_info, loss_func):
        """Score ``method`` on every labeled subset in ``test_info["train_subsets"]``.

        For each sample, ``method(x_train, y_train, test_info)`` must return an
        object with ``predict``; its predictions on ``test_info["test_x"]`` are
        scored with ``loss_func(prediction, test_info["test_y"])``.

        Returns:
            A list with one inner list of scores per entry of
            ``test_info["train_subsets"]``, in the same order.
        """
        all_scores = []
        for subset in test_info["train_subsets"]:
            subset_scores = []
            for sample in subset:
                x_train, y_train = sample["x_train"], sample["y_train"]
                model = method(x_train, y_train, test_info)
                subset_scores.append(loss_func(model.predict(test_info["test_x"]), test_info["test_y"]))
            all_scores.append(subset_scores)
        return all_scores

    @staticmethod
    def get_train_subsets(n_labeled_list, n_repeat_list, train_x, train_y):
        """Draw repeated random labeled subsets of the training data.

        ``n_labeled_list`` and ``n_repeat_list`` are consumed pairwise: for each
        pair ``(n_label, repeated)``, ``repeated`` subsets of ``n_label`` rows
        are sampled without replacement. ``n_label`` is capped at the number of
        training rows.

        Note: this re-seeds the global ``numpy.random`` and ``random``
        generators with ``1`` on every call, so the subsets are reproducible
        but the caller's global RNG state is reset as a side effect.

        Returns:
            A list (one entry per ``n_label``) of lists of
            ``{"x_train": ndarray, "y_train": ndarray}`` dicts.
        """
        np.random.seed(1)
        random.seed(1)
        # presumably replaces missing feature values by column means -- confirm in learnware.reuse.utils
        train_x = fill_data_with_mean(train_x)
        train_subsets = []
        for n_label, repeated in zip(n_labeled_list, n_repeat_list):
            n_label = min(n_label, len(train_x))
            subsets = []
            for _ in range(repeated):
                subset_idxs = np.random.choice(len(train_x), n_label, replace=False)
                subsets.append(
                    {"x_train": np.array(train_x[subset_idxs]), "y_train": np.array(train_y[subset_idxs])}
                )
            train_subsets.append(subsets)
        return train_subsets

    def _add_learnware_with_retry(self, client, learnware_id, max_attempts=20):
        """Download one learnware and add it to the market, retrying on failure.

        Every failed attempt is logged and followed by a one-second pause. If
        all ``max_attempts`` attempts fail, a warning is logged so that the
        missing learnware does not go unnoticed; no exception is raised.
        """
        with tempfile.TemporaryDirectory(prefix="table_benchmark_") as tempdir:
            zip_path = os.path.join(tempdir, f"{learnware_id}.zip")
            for _ in range(max_attempts):
                try:
                    semantic_spec = client.get_semantic_specification(learnware_id)
                    client.download_learnware(learnware_id, zip_path)
                    self.market.add_learnware(zip_path, semantic_spec)
                    break
                except Exception as e:  # best-effort: network, I/O and market errors are all retried
                    logger.info(
                        f"An error occurred when downloading {learnware_id}: {e}\n{traceback.format_exc()}, retrying..."
                    )
                    time.sleep(1)
            else:
                # Only reached when no attempt succeeded (the loop never hit ``break``).
                logger.warning(
                    f"Giving up on learnware {learnware_id} after {max_attempts} failed attempts; "
                    "it was not added to the market."
                )

    def _prepare_market(self, benchmark_config, name, rebuild, retrain):
        """Load the benchmark, instantiate the market and fill it if needed.

        Sets ``self.benchmark``, ``self.market`` and ``self.user_semantic``
        (the semantic specification of the first benchmark learnware with its
        name blanked). Learnwares are downloaded and added when the market is
        empty or ``rebuild`` is ``True``.
        """
        client = LearnwareClient()
        self.benchmark = LearnwareBenchmark().get_benchmark(benchmark_config)

        # Organizer options are only supplied when retraining is requested.
        organizer_kwargs = None
        if retrain:
            organizer_kwargs = {
                "auto_update": True,
                "auto_update_limit": len(self.benchmark.learnware_ids),
                **market_mapping_params,
            }
        self.market = instantiate_learnware_market(
            market_id=self.benchmark.name,
            name=name,
            rebuild=rebuild,
            organizer_kwargs=organizer_kwargs,
        )

        self.user_semantic = client.get_semantic_specification(self.benchmark.learnware_ids[0])
        self.user_semantic["Name"]["Values"] = ""

        if len(self.market) == 0 or rebuild is True:
            if retrain:
                set_seed(0)
            for learnware_id in self.benchmark.learnware_ids:
                self._add_learnware_with_retry(client, learnware_id)

    def test_method(self, test_info, recorders, loss_func=loss_func_rmse):
        """Evaluate one method for one user and persist its scores.

        ``test_info["method_name"]`` selects both the callable in
        ``test_methods`` and the recorder in ``recorders``. The result file is
        ``<curves>/<user>/<user>_<idx>/<method>.json``, where ``<method>`` is the
        full method name with its first ``_``-separated token removed (the name
        ``"user_model"`` is kept as is). The evaluation only runs when
        ``recorder.should_test_method(user, idx, save_path)`` is truthy.
        """
        method_name_full = test_info["method_name"]
        method_name = (
            method_name_full if method_name_full == "user_model" else "_".join(method_name_full.split("_")[1:])
        )
        method = test_methods[method_name_full]
        user, idx = test_info["user"], test_info["idx"]
        recorder = recorders[method_name_full]

        save_root_path = os.path.join(self.curves_result_path, f"{user}/{user}_{idx}")
        os.makedirs(save_root_path, exist_ok=True)
        save_path = os.path.join(save_root_path, f"{method_name}.json")

        if recorder.should_test_method(user, idx, save_path):
            scores = self._limited_data(method, test_info, loss_func)
            recorder.record(user, scores)
            recorder.save(save_path)
            logger.info(f"Method {method_name} on {user}_{idx} finished")
| @@ -0,0 +1,732 @@ | |||
| from learnware.tests.benchmarks import BenchmarkConfig | |||
# Labeled-sample budgets and repeat counts for the homogeneous / heterogeneous
# experiments. Each *_n_labeled_list has nine entries, as does its *_n_repeat_list;
# presumably they are consumed pairwise -- confirm against the callers.
homo_n_labeled_list = [100, 200, 500, 1000, 2000, 4000, 6000, 8000, 10000]
homo_n_repeat_list = [10] * 3 + [3] * 6
hetero_n_labeled_list = [10, 30, 50, 75, 100, 200, 500, 1000, 2000]
hetero_n_repeat_list = [10] * 6 + [3] * 3

# Semantic specification template for a user task.
user_semantic = dict(
    Data=dict(Values=["Table"], Type="Class"),
    Task=dict(Values=["Regression"], Type="Class"),
    Library=dict(Values=["Others"], Type="Class"),
    Scenario=dict(Values=["Business"], Type="Tag"),
    Description=dict(Values="", Type="String"),
    Name=dict(Values="", Type="String"),
)

# Per-method plot style, keyed by method name: (color, marker, linestyle).
styles = {
    method: {"color": color, "marker": marker, "linestyle": linestyle}
    for method, color, marker, linestyle in (
        ("user_model", "navy", "o", "-"),
        ("select_score", "gold", "s", "--"),
        ("oracle_score", "darkorange", "^", "-."),
        ("mean_score", "gray", "x", ":"),
        ("single_aug", "gold", "s", "--"),
        ("multiple_avg", "blue", "*", "-"),
        ("multiple_aug", "purple", "d", "--"),
        ("ensemble_pruning", "magenta", "d", "-."),
    )
}

# Human-readable label per method name.
labels = dict(
    user_model="User Model",
    single_aug="Single Learnware Reuse (FeatAug)",
    select_score="Single Learnware Reuse (FeatAug)",
    multiple_aug="Multiple Learnware Reuse (FeatAug)",
    ensemble_pruning="Multiple Learnware Reuse (EnsemblePrune)",
    multiple_avg="Multiple Learnware Reuse (Averaging)",
)

# Keyword arguments for the feature-alignment model.
# network_type options: ["ArbitraryMapping", "BaseMapping", "BaseMapping_BN", "BaseMapping_Dropout"]
align_model_params = dict(
    network_type="ArbitraryMapping",
    num_epoch=50,
    lr=1e-5,
    dropout_ratio=0.2,
    activation="relu",
    use_bn=True,
    hidden_dims=[128, 256, 128, 256],
)
| market_mapping_params = { | |||
| "lr": 1e-4, | |||
| "num_epoch": 50, | |||
| "batch_size": 64, | |||
| "num_partition": 2, # num of column partitions for pos/neg sampling | |||
| "overlap_ratio": 0.7, # specify the overlap ratio of column partitions during the CL | |||
| "hidden_dim": 256, # the dimension of hidden embeddings | |||
| "num_layer": 6, # the number of transformer layers used in the encoder | |||
| "num_attention_head": 8, # the number of heads of the multi-head self-attention layer in the transformers; hidden_dim should be divisible by this value | |||
| "hidden_dropout_prob": 0.5, # the dropout ratio in the transformer encoder | |||
| "ffn_dim": 512, # the dimension of feed-forward layer in the transformer layer | |||
| "activation": "leakyrelu", | |||
| } | |||
def _lgb_entry(params):
    """Wrap ``params`` in the ``"lgb"`` entry layout with the shared round limits."""
    return {"lgb": {"params": params, "MAX_ROUNDS": 1000, "early_stopping_rounds": 1000}}


def _pfs_lgb_params():
    """Return a fresh copy of the settings shared by ``PFS`` and ``PFS_HOMO``."""
    return {
        "boosting_type": "gbdt",
        "num_leaves": 2**7 - 1,
        "learning_rate": 0.01,
        "objective": "rmse",
        "metric": "rmse",
        "feature_fraction": 0.75,
        "bagging_fraction": 0.75,
        "bagging_freq": 5,
        "seed": 1,
        "verbose": -100,
        "n_estimators": 100000,
    }


# Settings for the user's self-trained model, keyed by benchmark name.
# PFS and PFS_HOMO use equal settings but get independent dict objects.
user_model_params = {
    "M5": _lgb_entry(
        {
            "boosting_type": "gbdt",
            "objective": "rmse",
            "metric": "rmse",
            "learning_rate": 0.015,
            "num_leaves": 300,
            "max_depth": 500,
            "n_estimators": 100000,
            "boost_from_average": False,
            "num_threads": 32,
            "verbose": -1,
        }
    ),
    "PFS": _lgb_entry(_pfs_lgb_params()),
    "PFS_HOMO": _lgb_entry(_pfs_lgb_params()),
}
# Homogeneous PFS benchmark: 53 users and the 53 consecutive learnware ids
# "00002265" .. "00002317" (zero-padded to eight digits).
homo_table_benchmark_config = BenchmarkConfig(
    name="PFS_HOMO",
    user_num=53,
    learnware_ids=[f"{learnware_no:08d}" for learnware_no in range(2265, 2318)],
    test_data_path="PFS_HOMO/test_data.zip",
    train_data_path="PFS_HOMO/train_data.zip",
    extra_info_path="PFS_HOMO/extra_info.zip",
)
# Heterogeneous "different feature engineering" benchmark on PFS: 41 users.
# The 265 learnware ids (zero-padded to eight digits) form four consecutive
# runs, listed below as inclusive (first, last) pairs: 103 + 54 + 54 + 54 ids.
hetero_cross_feat_eng_benchmark_config = BenchmarkConfig(
    name="PFS",
    user_num=41,
    learnware_ids=[
        f"{learnware_no:08d}"
        for first, last in ((342, 444), (730, 783), (786, 839), (859, 912))
        for learnware_no in range(first, last + 1)
    ],
    test_data_path="PFS/test_data.zip",
    train_data_path="PFS/train_data.zip",
    extra_info_path="PFS/extra_info.zip",
)
# Heterogeneous "different task" benchmark on M5: 30 users.
# The 265 learnware ids (zero-padded to eight digits) form four consecutive
# runs, listed below as inclusive (first, last) pairs: 103 + 54 + 54 + 54 ids.
hetero_cross_task_benchmark_config = BenchmarkConfig(
    name="M5",
    user_num=30,
    learnware_ids=[
        f"{learnware_no:08d}"
        for first, last in ((342, 444), (730, 783), (786, 839), (859, 912))
        for learnware_no in range(first, last + 1)
    ],
    test_data_path="M5/test_data.zip",
    train_data_path="M5/train_data.zip",
    extra_info_path="M5/extra_info.zip",
)
| @@ -0,0 +1,174 @@ | |||
| import os | |||
| import warnings | |||
| import numpy as np | |||
| from base import TableWorkflow | |||
| from config import align_model_params, hetero_n_labeled_list, hetero_n_repeat_list, user_semantic | |||
| from methods import loss_func_rmse | |||
| from utils import Recorder, plot_performance_curves, set_seed | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import BaseUserInfo | |||
| from learnware.reuse import AveragingReuser, FeatureAlignLearnware | |||
| from learnware.specification import generate_stat_spec | |||
| warnings.filterwarnings("ignore") | |||
| logger = get_module_logger("hetero_test", level="INFO") | |||
class HeterogeneousDatasetWorkflow(TableWorkflow):
    """Search-and-reuse experiments on table benchmarks whose learnwares need feature alignment.

    Both entry points iterate over ``self.benchmark.user_num`` users, search
    ``self.market`` with the user's statistical and semantic specification,
    and score the reused learnwares with RMSE.
    """

    @staticmethod
    def _align_learnware(learnware, user_stat_spec):
        """Return ``learnware`` wrapped in a ``FeatureAlignLearnware`` aligned to ``user_stat_spec``."""
        aligned = FeatureAlignLearnware(learnware, **align_model_params)
        aligned.align(user_rkme=user_stat_spec)
        return aligned

    @staticmethod
    def _build_user_info(feature_descriptions, user_stat_spec):
        """Build the ``BaseUserInfo`` used to query the market for one user.

        NOTE(review): this writes the "Input" entry into the module-level
        ``user_semantic`` dict imported from ``config`` (exactly as the
        original inline code did), so that dict is shared state across users.
        """
        user_semantic["Input"] = {
            "Dimension": len(feature_descriptions),
            "Description": {str(i): name for i, name in enumerate(feature_descriptions)},
        }
        return BaseUserInfo(semantic_spec=user_semantic, stat_info={user_stat_spec.type: user_stat_spec})

    def unlabeled_hetero_table_example(self):
        """Evaluate search and reuse for every benchmark user using test features only.

        For each user the method logs the RMSE of the top-1 searched
        learnware, collects the RMSE of every learnware in the market (for the
        average / oracle baselines) and the RMSE of an averaging ensemble over
        the searched mixture. Aggregated mean +/- std values are logged at the
        end; nothing is returned.
        """
        set_seed(0)
        logger.info("Total Item: %d" % len(self.market))
        learnware_rmse_list = []
        single_score_list = []
        ensemble_score_list = []
        all_learnwares = self.market.get_learnwares()

        user = self.benchmark.name
        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y, feature_descriptions = test_x.values, test_y.values, test_x.columns
            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = self._build_user_info(feature_descriptions, user_stat_spec)

            logger.info(f"Searching Market for user: {user}_{idx}")
            search_result = self.market.search_learnware(user_info, search_method="auto")
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"hetero search result of user {user}_{idx}: {single_result[0].learnware.id}")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            single_hetero_learnware = self._align_learnware(single_result[0].learnware, user_stat_spec)
            pred_y = single_hetero_learnware.predict(test_x)
            single_score_list.append(loss_func_rmse(pred_y, test_y))

            # Score every learnware in the market to obtain the average / oracle baselines.
            rmse_list = []
            for learnware in all_learnwares:
                hetero_learnware = self._align_learnware(learnware, user_stat_spec)
                pred_y = hetero_learnware.predict(test_x)
                rmse_list.append(loss_func_rmse(pred_y, test_y))
            # Report the score of the *current* user (last appended); the
            # previous code indexed [0] and therefore always logged user 0.
            logger.info(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[-1]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = [
                    self._align_learnware(learnware, user_stat_spec) for learnware in multiple_result[0].learnwares
                ]
            else:
                # No mixture returned: fall back to the (freshly aligned) top-1 learnware.
                mixture_learnware_list = [self._align_learnware(single_result[0].learnware, user_stat_spec)]

            # test reuse (ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
            ensemble_score_list.append(ensemble_score)
            logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

            learnware_rmse_list.append(rmse_list)

        # Per-user average and best (oracle) RMSE over all market learnwares.
        avg_score_list = [np.mean(scores) for scores in learnware_rmse_list]
        oracle_score_list = [np.min(scores) for scores in learnware_rmse_list]
        logger.info(
            "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
            % (
                np.mean(single_score_list),
                np.std(single_score_list),
                np.mean(avg_score_list),
                np.std(avg_score_list),
                np.mean(oracle_score_list),
                np.std(oracle_score_list),
            )
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )

    def labeled_hetero_table_example(self, skip_test):
        """Compare reuse methods when the user also owns labeled training data.

        Args:
            skip_test: when true, the search/evaluation loop and the saving of
                recorder files are skipped and only the performance curves are
                plotted (``plot_performance_curves`` reloads the recorder
                files from ``self.curves_result_path``).
        """
        set_seed(0)
        logger.info("Total Items: %d" % len(self.market))
        methods = ["user_model", "hetero_single_aug", "hetero_multiple_avg", "hetero_ensemble_pruning"]
        recorders = {method: Recorder() for method in methods}
        user = self.benchmark.name

        if not skip_test:
            for idx in range(self.benchmark.user_num):
                test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
                test_x, test_y = test_x.values, test_y.values

                train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
                train_x, train_y, feature_descriptions = train_x.values, train_y.values, train_x.columns
                train_subsets = self.get_train_subsets(hetero_n_labeled_list, hetero_n_repeat_list, train_x, train_y)

                user_stat_spec = generate_stat_spec(type="table", X=test_x)
                user_info = self._build_user_info(feature_descriptions, user_stat_spec)

                logger.info(f"Searching Market for user: {user}_{idx}")
                search_result = self.market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    logger.info(f"Mixture score: {multiple_result[0].score}, Mixture learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    mixture_learnware_list = [single_result[0].learnware]

                logger.info(
                    f"Hetero search result of user {user}_{idx}: mixture learnware num: {len(mixture_learnware_list)}"
                )

                test_info = {
                    "user": user,
                    "idx": idx,
                    "train_subsets": train_subsets,
                    "test_x": test_x,
                    "test_y": test_y,
                    "n_labeled_list": hetero_n_labeled_list,
                }
                common_config = {"user_rkme": user_stat_spec, "learnwares": mixture_learnware_list}
                method_configs = {
                    "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
                    "hetero_single_aug": {"user_rkme": user_stat_spec, "single_learnware": single_result[0].learnware},
                    "hetero_multiple_avg": common_config,
                    "hetero_ensemble_pruning": common_config,
                }

                for method_name in methods:
                    logger.info(f"Testing method {method_name}")
                    # test_info is reused across methods; keys set for an earlier method stay in place.
                    test_info["method_name"] = method_name
                    test_info.update(method_configs[method_name])
                    self.test_method(test_info, recorders, loss_func=loss_func_rmse)

            for method, recorder in recorders.items():
                recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

        plot_performance_curves(
            self.curves_result_path, user, recorders, task="Hetero", n_labeled_list=hetero_n_labeled_list
        )
| @@ -0,0 +1,167 @@ | |||
| import os | |||
| import warnings | |||
| import numpy as np | |||
| from base import TableWorkflow | |||
| from config import homo_n_labeled_list, homo_n_repeat_list | |||
| from methods import loss_func_rmse | |||
| from utils import Recorder, plot_performance_curves | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import BaseUserInfo | |||
| from learnware.reuse import AveragingReuser, JobSelectorReuser | |||
| from learnware.specification import generate_stat_spec | |||
| warnings.filterwarnings("ignore") | |||
| logger = get_module_logger("homo_table", level="INFO") | |||
class HomogeneousDatasetWorkflow(TableWorkflow):
    """Search-and-reuse experiments on the homogeneous table benchmark.

    Both entry points iterate over ``self.benchmark.user_num`` users, search
    ``self.market`` with the user's statistical specification and
    ``self.user_semantic``, and score the reused learnwares with RMSE.
    """

    def unlabeled_homo_table_example(self):
        """Evaluate search and reuse for every benchmark user using test features only.

        Per user, the method scores the top-1 searched learnware, every market
        learnware whose declared input dimension matches the user's data, a
        ``JobSelectorReuser`` and an ``AveragingReuser`` built from the
        searched mixture. Aggregated mean +/- std values are logged; nothing
        is returned.
        """
        logger.info("Total Item: %d" % (len(self.market)))
        learnware_rmse_list = []
        single_score_list = []
        job_selector_score_list = []
        ensemble_score_list = []
        all_learnwares = self.market.get_learnwares()

        user = self.benchmark.name
        for idx in range(self.benchmark.user_num):
            test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
            test_x, test_y = test_x.values, test_y.values
            user_stat_spec = generate_stat_spec(type="table", X=test_x)
            user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec})

            logger.info(f"Searching Market for user: {user}_{idx}")
            search_result = self.market.search_learnware(user_info, max_search_num=2)
            single_result = search_result.get_single_results()
            multiple_result = search_result.get_multiple_results()

            logger.info(f"search result of user {user}_{idx}:")
            logger.info(
                f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
            )

            pred_y = single_result[0].learnware.predict(test_x)
            single_score_list.append(loss_func_rmse(pred_y, test_y))

            # Only learnwares whose declared input dimension matches can be applied directly,
            # so the number of scores collected here may differ from user to user.
            rmse_list = []
            for learnware in all_learnwares:
                semantic_spec = learnware.specification.get_semantic_spec()
                if semantic_spec["Input"]["Dimension"] == test_x.shape[1]:
                    pred_y = learnware.predict(test_x)
                    rmse_list.append(loss_func_rmse(pred_y, test_y))
            logger.info(
                f"Top1-score: {single_result[0].score}, learnware_id: {single_result[0].learnware.id}, rmse: {single_score_list[-1]}"
            )

            if len(multiple_result) > 0:
                mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                mixture_learnware_list = multiple_result[0].learnwares
            else:
                mixture_learnware_list = [single_result[0].learnware]

            # test reuse (job selector)
            reuse_baseline = JobSelectorReuser(learnware_list=mixture_learnware_list, herding_num=100)
            reuse_predict = reuse_baseline.predict(user_data=test_x)
            reuse_score = loss_func_rmse(reuse_predict, test_y)
            job_selector_score_list.append(reuse_score)
            logger.info(f"mixture reuse rmse (job selector): {reuse_score}")

            # test reuse (ensemble)
            reuse_ensemble = AveragingReuser(learnware_list=mixture_learnware_list, mode="mean")
            ensemble_predict_y = reuse_ensemble.predict(user_data=test_x)
            ensemble_score = loss_func_rmse(ensemble_predict_y, test_y)
            ensemble_score_list.append(ensemble_score)
            logger.info(f"mixture reuse rmse (ensemble): {ensemble_score}")

            learnware_rmse_list.append(rmse_list)

        # Aggregate per user directly on the Python lists: the per-user lists can have
        # different lengths, and stacking ragged lists with np.array() fails on recent
        # NumPy. Users for which no learnware was applicable are left out of the
        # average / oracle baselines instead of crashing np.min on an empty list.
        evaluated = [scores for scores in learnware_rmse_list if len(scores) > 0]
        avg_score_list = [np.mean(scores) for scores in evaluated]
        oracle_score_list = [np.min(scores) for scores in evaluated]
        logger.info(
            "RMSE of selected learnware: %.3f +/- %.3f, Average performance: %.3f +/- %.3f, Oracle performace: %.3f +/- %.3f"
            % (
                np.mean(single_score_list),
                np.std(single_score_list),
                np.mean(avg_score_list),
                np.std(avg_score_list),
                np.mean(oracle_score_list),
                np.std(oracle_score_list),
            )
        )
        logger.info(
            "Average Job Selector Reuse Performance: %.3f +/- %.3f"
            % (np.mean(job_selector_score_list), np.std(job_selector_score_list))
        )
        logger.info(
            "Averaging Ensemble Reuse Performance: %.3f +/- %.3f"
            % (np.mean(ensemble_score_list), np.std(ensemble_score_list))
        )

    def labeled_homo_table_example(self, skip_test):
        """Compare reuse methods when the user also owns labeled training data.

        Args:
            skip_test: when true, the search/evaluation loop and the saving of
                recorder files are skipped and only the performance curves are
                plotted (``plot_performance_curves`` reloads the recorder
                files from ``self.curves_result_path``).
        """
        logger.info("Total Item: %d" % (len(self.market)))
        methods = ["user_model", "homo_single_aug", "homo_ensemble_pruning"]
        recorders = {method: Recorder() for method in methods}
        user = self.benchmark.name

        if not skip_test:
            for idx in range(self.benchmark.user_num):
                test_x, test_y = self.benchmark.get_test_data(user_ids=idx)
                test_x, test_y = test_x.values, test_y.values

                train_x, train_y = self.benchmark.get_train_data(user_ids=idx)
                train_x, train_y = train_x.values, train_y.values
                train_subsets = self.get_train_subsets(homo_n_labeled_list, homo_n_repeat_list, train_x, train_y)

                user_stat_spec = generate_stat_spec(type="table", X=test_x)
                # Key the spec by its own type name, consistent with the unlabeled workflow,
                # instead of repeating the class name as a string literal.
                user_info = BaseUserInfo(
                    semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
                )

                logger.info(f"Searching Market for user: {user}_{idx}")
                search_result = self.market.search_learnware(user_info)
                single_result = search_result.get_single_results()
                multiple_result = search_result.get_multiple_results()

                logger.info(f"search result of user {user}_{idx}:")
                logger.info(
                    f"single model num: {len(single_result)}, max_score: {single_result[0].score}, min_score: {single_result[-1].score}"
                )

                if len(multiple_result) > 0:
                    mixture_id = " ".join([learnware.id for learnware in multiple_result[0].learnwares])
                    logger.info(f"mixture_score: {multiple_result[0].score}, mixture_learnware: {mixture_id}")
                    mixture_learnware_list = multiple_result[0].learnwares
                else:
                    mixture_learnware_list = [single_result[0].learnware]

                test_info = {
                    "user": user,
                    "idx": idx,
                    "train_subsets": train_subsets,
                    "test_x": test_x,
                    "test_y": test_y,
                }
                common_config = {"learnwares": mixture_learnware_list}
                method_configs = {
                    "user_model": {"dataset": self.benchmark.name, "model_type": "lgb"},
                    "homo_single_aug": {"single_learnware": [single_result[0].learnware]},
                    "homo_ensemble_pruning": common_config,
                }

                for method_name in methods:
                    logger.info(f"Testing method {method_name}")
                    # test_info is reused across methods; keys set for an earlier method stay in place.
                    test_info["method_name"] = method_name
                    test_info.update(method_configs[method_name])
                    self.test_method(test_info, recorders, loss_func=loss_func_rmse)

            for method, recorder in recorders.items():
                recorder.save(os.path.join(self.curves_result_path, f"{user}/{user}_{method}_performance.json"))

        plot_performance_curves(
            self.curves_result_path, user, recorders, task="Homo", n_labeled_list=homo_n_labeled_list
        )
| @@ -0,0 +1,110 @@ | |||
| import numpy as np | |||
| from config import align_model_params | |||
| from sklearn.metrics import mean_squared_error | |||
| from sklearn.model_selection import train_test_split | |||
| from train import train_model | |||
| from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, HeteroMapAlignLearnware | |||
def loss_func_rmse(y_true, y_pred):
    """Return the root mean squared error between ``y_true`` and ``y_pred``."""
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)
def user_model_score(x_train, y_train, test_info):
    """Train the user's own model on the labeled data and return it.

    The data is split 80/20 (fixed ``random_state=42``); the 20% part is
    handed to ``train_model`` as the validation set. ``test_info`` is
    forwarded to ``train_model`` unchanged.
    """
    fit_x, holdout_x, fit_y, holdout_y = train_test_split(
        x_train, y_train, test_size=0.2, random_state=42
    )
    return train_model(fit_x, fit_y, holdout_x, holdout_y, test_info)
class HomoScoringMethods:
    """Factories that build a fitted reuser from the learnwares stored in ``test_info``.

    Every method takes ``(x_train, y_train, test_info)`` and returns the
    object that was built; the learnwares are used as they are (no alignment).
    """

    @staticmethod
    def single_aug_score(x_train, y_train, test_info):
        """Fit a ``FeatureAugmentReuser`` on ``test_info["single_learnware"]``."""
        augmenter = FeatureAugmentReuser(test_info["single_learnware"], mode="regression")
        augmenter.fit(x_train=x_train, y_train=y_train)
        return augmenter

    @staticmethod
    def multiple_aug_score(x_train, y_train, test_info):
        """Fit a ``FeatureAugmentReuser`` on ``test_info["learnwares"]``."""
        augmenter = FeatureAugmentReuser(test_info["learnwares"], mode="regression")
        augmenter.fit(x_train=x_train, y_train=y_train)
        return augmenter

    @staticmethod
    def multiple_avg_score(x_train, y_train, test_info):
        """Return a mean ``AveragingReuser`` over ``test_info["learnwares"]`` (training data unused)."""
        return AveragingReuser(test_info["learnwares"], mode="mean")

    @staticmethod
    def multiple_ensemble_pruning_score(x_train, y_train, test_info):
        """Fit an ``EnsemblePruningReuser``; a single learnware is returned as is."""
        candidates = test_info["learnwares"]
        if len(candidates) != 1:
            pruner = EnsemblePruningReuser(candidates, mode="regression")
            pruner.fit(val_X=x_train, val_y=y_train)
            return pruner
        # Nothing to prune with exactly one candidate.
        return candidates[0]
class HeteroMethods:
    """Factories that align learnwares to the user's task before reusing them.

    Every scoring method takes ``(x_train, y_train, test_info)``, reads
    ``test_info["user_rkme"]`` plus the learnware(s) it needs, and returns the
    object that was built.
    """

    @staticmethod
    def create_hetero_learnware_list(learnware_list, user_rkme, x_train, y_train):
        """Wrap each learnware in an aligned ``HeteroMapAlignLearnware``, preserving order."""

        def _aligned(learnware):
            wrapped = HeteroMapAlignLearnware(learnware, mode="regression", **align_model_params)
            wrapped.align(user_rkme, x_train, y_train)
            return wrapped

        return [_aligned(learnware) for learnware in learnware_list]

    @staticmethod
    def single_aug_score(x_train, y_train, test_info):
        """Return ``test_info["single_learnware"]`` wrapped and aligned with the labeled data."""
        user_rkme = test_info["user_rkme"]
        target = test_info["single_learnware"]
        aligned = HeteroMapAlignLearnware(target, mode="regression", **align_model_params)
        aligned.align(user_rkme=user_rkme, x_train=x_train, y_train=y_train)
        return aligned

    @staticmethod
    def multiple_aug_score(x_train, y_train, test_info):
        """Fit a ``FeatureAugmentReuser`` over the aligned ``test_info["learnwares"]``."""
        user_rkme = test_info["user_rkme"]
        candidates = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(candidates, user_rkme, x_train, y_train)
        augmenter = FeatureAugmentReuser(aligned, mode="regression")
        augmenter.fit(x_train=x_train, y_train=y_train)
        return augmenter

    @staticmethod
    def multiple_ensemble_pruning_score(x_train, y_train, test_info):
        """Fit an ``EnsemblePruningReuser`` over the aligned learnwares; a single one is returned as is."""
        user_rkme = test_info["user_rkme"]
        candidates = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(candidates, user_rkme, x_train, y_train)
        if len(aligned) != 1:
            pruner = EnsemblePruningReuser(aligned, mode="regression")
            pruner.fit(val_X=x_train, val_y=y_train)
            return pruner
        # Nothing to prune with exactly one candidate.
        return aligned[0]

    @staticmethod
    def multiple_avg_score(x_train, y_train, test_info):
        """Return a mean ``AveragingReuser`` over the aligned ``test_info["learnwares"]``."""
        user_rkme = test_info["user_rkme"]
        candidates = test_info["learnwares"]
        aligned = HeteroMethods.create_hetero_learnware_list(candidates, user_rkme, x_train, y_train)
        return AveragingReuser(aligned, mode="mean")
# Registry: method name -> factory with signature (x_train, y_train, test_info).
test_methods = {
    # Baseline trained only on the user's labeled data.
    "user_model": user_model_score,
    # Reuse with alignment (HeteroMethods).
    "hetero_single_aug": HeteroMethods.single_aug_score,
    "hetero_multiple_aug": HeteroMethods.multiple_aug_score,
    "hetero_multiple_avg": HeteroMethods.multiple_avg_score,
    "hetero_ensemble_pruning": HeteroMethods.multiple_ensemble_pruning_score,
    # Reuse without alignment (HomoScoringMethods).
    "homo_single_aug": HomoScoringMethods.single_aug_score,
    "homo_multiple_aug": HomoScoringMethods.multiple_aug_score,
    "homo_multiple_avg": HomoScoringMethods.multiple_avg_score,
    "homo_ensemble_pruning": HomoScoringMethods.multiple_ensemble_pruning_score,
}
| @@ -0,0 +1 @@ | |||
| lightgbm==3.3.5 | |||
| @@ -0,0 +1,41 @@ | |||
| import lightgbm as lgb | |||
| from config import user_model_params | |||
| from lightgbm import early_stopping | |||
| from learnware.logger import get_module_logger | |||
| logger = get_module_logger("train_table", level="INFO") | |||
def train_lgb(X_train, y_train, X_val, y_val, dataset):
    """Train a LightGBM booster with early stopping and return it.

    Hyper-parameters are read from ``user_model_params[dataset]["lgb"]``
    (keys ``params``, ``MAX_ROUNDS`` and ``early_stopping_rounds``).

    Args:
        X_train, y_train: training features and labels.
        X_val, y_val: validation features and labels used for early stopping.
        dataset: key into ``user_model_params``; for ``"Corporacion"`` the
            training set is additionally passed in ``valid_sets``.

    Returns:
        The booster returned by ``lgb.train``.
    """
    model_param = user_model_params[dataset]["lgb"]
    params = model_param["params"]
    max_rounds = model_param["MAX_ROUNDS"]

    # No feature is declared categorical.
    cate_vars = []
    dtrain = lgb.Dataset(X_train, label=y_train, categorical_feature=cate_vars)
    dval = lgb.Dataset(X_val, label=y_val, reference=dtrain, categorical_feature=cate_vars)
    valid_sets = [dtrain, dval] if dataset == "Corporacion" else [dval]

    # The previous version also predicted on X_val and stored the result in a
    # local list that was never read; that dead computation has been removed.
    return lgb.train(
        params,
        dtrain,
        num_boost_round=max_rounds,
        valid_sets=valid_sets,
        callbacks=[early_stopping(model_param["early_stopping_rounds"], verbose=False)],
    )
def train_ridge(X_train, y_train, X_val, y_val, dataset):
    """Placeholder for a ridge-regression user model (not implemented).

    Raises:
        NotImplementedError: always. The former ``pass`` body returned
            ``None``, which a caller expecting a fitted model would only
            notice later, at prediction time.
    """
    raise NotImplementedError("Ridge user model training is not implemented.")
def train_model(X_train, y_train, X_val, y_val, test_info):
    """Train the user model selected by ``test_info["model_type"]`` and return it.

    Args:
        X_train, y_train: training features and labels.
        X_val, y_val: validation features and labels.
        test_info: mapping providing ``"dataset"`` and ``"model_type"``.

    Returns:
        The trained model (currently only the LightGBM booster from ``train_lgb``).

    Raises:
        AssertionError: if ``model_type`` is neither ``"lgb"`` nor ``"ridge"``.
        NotImplementedError: if ``model_type`` is ``"ridge"``, which is accepted
            by name but has no training routine yet. Previously this case fell
            through and silently returned ``None``.
    """
    dataset = test_info["dataset"]
    model_type = test_info["model_type"]

    assert model_type in ["lgb", "ridge"]
    if model_type == "lgb":
        return train_lgb(X_train, y_train, X_val, y_val, dataset)
    # Fail loudly instead of handing the caller an implicit None as the "model".
    raise NotImplementedError(f"Training for model_type {model_type!r} is not implemented.")
| @@ -0,0 +1,94 @@ | |||
| import json | |||
| import os | |||
| import random | |||
| from collections import defaultdict | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import torch | |||
| from config import labels, styles | |||
| from learnware.logger import get_module_logger | |||
| logger = get_module_logger("base_table", level="INFO") | |||
class Recorder:
    """Accumulate score records per user and persist them as JSON.

    ``data`` maps a user key to the list of records appended for that user.
    """

    def __init__(self, headers=None, formats=None):
        """Create an empty recorder.

        Args:
            headers: column headers; defaults to ``["Mean", "Std Dev"]``.
            formats: format strings, one per header; defaults to
                ``["{:.2f}", "{:.2f}"]``.
        """
        # None sentinels: list literals as defaults would be shared by every instance.
        headers = ["Mean", "Std Dev"] if headers is None else headers
        formats = ["{:.2f}", "{:.2f}"] if formats is None else formats
        assert len(headers) == len(formats), "Headers and formats length must match."
        self.data = defaultdict(list)
        self.headers = headers
        self.formats = formats

    def record(self, user, scores):
        """Append ``scores`` to the records of ``user``."""
        self.data[user].append(scores)

    def get_performance_data(self, user):
        """Return the records of ``user``, or an empty list if there are none."""
        return self.data.get(user, [])

    def save(self, path):
        """Write all records to ``path`` as JSON, creating the parent directory if needed."""
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(path, "w") as f:
            # default=list: objects json cannot encode are serialized via list(obj).
            json.dump(self.data, f, indent=4, default=list)

    def load(self, path):
        """Replace the current records with the JSON content of ``path``."""
        with open(path, "r") as f:
            # Only the top-level mapping becomes a defaultdict; nested JSON objects stay
            # plain dicts (the former object_hook converted every nested object as well).
            self.data = defaultdict(list, json.load(f))

    def should_test_method(self, user, idx, path):
        """Return whether record number ``idx`` of ``user`` is still missing from ``path``.

        When ``path`` exists it is loaded first, which replaces ``self.data``.
        """
        if not os.path.exists(path):
            return True
        self.load(path)
        return user not in self.data or idx >= len(self.data[user])
def plot_performance_curves(path, user, recorders, task, n_labeled_list):
    """Plot mean +/- std RMSE curves of every method and save them as an SVG.

    Each recorder is reloaded from ``{path}/{user}/{user}_{method}_performance.json``
    before plotting. The figure is written to ``results/figs/{task}_labeled_curves.svg``
    next to this file.

    Args:
        path: directory that contains the per-user recorder JSON files.
        user: user key, used both in the file names and inside the recorders.
        recorders: mapping of method name to ``Recorder``.
        task: scenario label used in the title and in the output file name.
        n_labeled_list: labeled-sample counts, shown as x tick labels.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.xticks(range(len(n_labeled_list)), n_labeled_list)
        for method, recorder in recorders.items():
            data_path = os.path.join(path, f"{user}/{user}_{method}_performance.json")
            recorder.load(data_path)
            scores_array = recorder.get_performance_data(user)

            mean_curve, std_curve = [], []
            for i in range(len(n_labeled_list)):
                # Entry i of every record is stacked, averaged over the records (axis 0),
                # and the remaining values give the mean and the spread of the curve.
                # NOTE(review): presumably one record per benchmark user and one value
                # per repetition; confirm against the code that fills the recorder.
                sub_scores_array = np.vstack([lst[i] for lst in scores_array])
                sub_scores_mean = np.squeeze(np.mean(sub_scores_array, axis=0))
                mean_curve.append(np.mean(sub_scores_mean))
                std_curve.append(np.std(sub_scores_mean))

            mean_curve = np.array(mean_curve)
            std_curve = np.array(std_curve)

            # Drop the scenario prefix (e.g. "homo_") so both scenarios share style/label keys.
            method_plot = (
                "_".join(method.split("_")[1:])
                if method not in ["user_model", "oracle_score", "select_score", "mean_score"]
                else method
            )
            style = styles.get(method_plot, {"color": "black", "linestyle": "-"})
            plt.plot(mean_curve, label=labels.get(method_plot), **style)

            plt.fill_between(
                range(len(mean_curve)), mean_curve - std_curve, mean_curve + std_curve, color=style["color"], alpha=0.2
            )

        plt.xlabel("Amount of Labeled User Data", fontsize=14)
        plt.ylabel("RMSE", fontsize=14)
        plt.title(f"Results on {task} Table Experimental Scenario", fontsize=16)
        plt.legend(fontsize=12)
        plt.tight_layout()

        root_path = os.path.dirname(os.path.abspath(__file__))
        fig_path = os.path.join(root_path, "results", "figs")
        os.makedirs(fig_path, exist_ok=True)
        plt.savefig(os.path.join(fig_path, f"{task}_labeled_curves.svg"), bbox_inches="tight", dpi=700)
    finally:
        # pyplot keeps every figure alive until it is closed; release it even on failure
        # so repeated calls do not accumulate open figures.
        plt.close(fig)
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
| @@ -0,0 +1,42 @@ | |||
| import fire | |||
| from config import ( | |||
| hetero_cross_feat_eng_benchmark_config, | |||
| hetero_cross_task_benchmark_config, | |||
| homo_table_benchmark_config, | |||
| ) | |||
| from hetero import HeterogeneousDatasetWorkflow | |||
| from homo import HomogeneousDatasetWorkflow | |||
| from learnware.logger import get_module_logger | |||
| logger = get_module_logger("base_table", level="INFO") | |||
class TableDatasetWorkflow:
    """Entry points for the table benchmark examples, exposed on the command line via ``fire``."""

    def unlabeled_homo_table_example(self, rebuild=True):
        """Run the homogeneous workflow's unlabeled example."""
        HomogeneousDatasetWorkflow(
            benchmark_config=homo_table_benchmark_config, name="easy", rebuild=rebuild
        ).unlabeled_homo_table_example()

    def labeled_homo_table_example(self, skip_test=False, rebuild=True):
        """Run the homogeneous workflow's labeled example; ``skip_test`` is forwarded to it."""
        HomogeneousDatasetWorkflow(
            benchmark_config=homo_table_benchmark_config, name="easy", rebuild=rebuild
        ).labeled_homo_table_example(skip_test=skip_test)

    def cross_feat_eng_hetero_table_example(self, rebuild=True, retrain=True):
        """Run the heterogeneous unlabeled example on the cross-feature-engineering benchmark config."""
        HeterogeneousDatasetWorkflow(
            benchmark_config=hetero_cross_feat_eng_benchmark_config, name="hetero", rebuild=rebuild, retrain=retrain
        ).unlabeled_hetero_table_example()

    def cross_task_hetero_table_example(self, skip_test=False, rebuild=True, retrain=True):
        """Run the heterogeneous labeled example on the cross-task benchmark config."""
        HeterogeneousDatasetWorkflow(
            benchmark_config=hetero_cross_task_benchmark_config, name="hetero", rebuild=rebuild, retrain=retrain
        ).labeled_hetero_table_example(skip_test=skip_test)


if __name__ == "__main__":
    fire.Fire(TableDatasetWorkflow)
| @@ -1,6 +1,5 @@ | |||
| from learnware.tests.benchmarks import BenchmarkConfig | |||
| text_benchmark_config = BenchmarkConfig( | |||
| name="20-Newsgroups", | |||
| user_num=10, | |||
| @@ -1,22 +1,23 @@ | |||
| import os | |||
| import fire | |||
| import time | |||
| import random | |||
| import pickle | |||
| import random | |||
| import tempfile | |||
| import numpy as np | |||
| import time | |||
| import fire | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| from config import text_benchmark_config | |||
| from sklearn.feature_extraction.text import TfidfVectorizer | |||
| from sklearn.metrics import accuracy_score | |||
| from sklearn.naive_bayes import MultinomialNB | |||
| from sklearn.feature_extraction.text import TfidfVectorizer | |||
| from learnware.client import LearnwareClient | |||
| from learnware.logger import get_module_logger | |||
| from learnware.market import BaseUserInfo, instantiate_learnware_market | |||
| from learnware.reuse import AveragingReuser, EnsemblePruningReuser, JobSelectorReuser | |||
| from learnware.specification import RKMETextSpecification | |||
| from learnware.tests.benchmarks import LearnwareBenchmark | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser | |||
| from config import text_benchmark_config | |||
| logger = get_module_logger("text_workflow", level="INFO") | |||
| @@ -2,10 +2,11 @@ import os | |||
| import pickle | |||
| import tempfile | |||
| import zipfile | |||
| import numpy as np | |||
| from dataclasses import dataclass | |||
| from typing import List, Optional, Tuple, Union | |||
| import numpy as np | |||
| from .config import BenchmarkConfig, benchmark_configs | |||
| from ..data import GetData | |||
| from ...config import C | |||
| @@ -1,5 +1,5 @@ | |||
| from dataclasses import dataclass | |||
| from typing import List, Optional, Dict | |||
| from typing import Dict, List, Optional | |||
| @dataclass | |||
| @@ -1,4 +1,5 @@ | |||
| import os | |||
| from setuptools import find_packages, setup | |||
| @@ -1,12 +1,12 @@ | |||
| import logging | |||
| import os | |||
| import unittest | |||
| import tempfile | |||
| import logging | |||
| import unittest | |||
| import learnware | |||
| from learnware.learnware import Learnware | |||
| from learnware.client import LearnwareClient | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.learnware import Learnware | |||
| from learnware.market import BaseUserInfo, instantiate_learnware_market | |||
| learnware.init(logging_level=logging.WARNING) | |||
| @@ -1,11 +1,11 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import generate_semantic_spec | |||
| from learnware.market import BaseUserInfo | |||
| from learnware.specification import generate_semantic_spec | |||
| class TestAllLearnware(unittest.TestCase): | |||
| @@ -1,6 +1,6 @@ | |||
| import os | |||
| import unittest | |||
| import tempfile | |||
| import unittest | |||
| from learnware.client import LearnwareClient | |||
| @@ -1,4 +1,5 @@ | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| @@ -1,5 +1,6 @@ | |||
| import os | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.client import LearnwareClient | |||
| @@ -1,7 +1,7 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| from learnware.client import LearnwareClient | |||
| from learnware.specification import generate_semantic_spec | |||
| @@ -1,12 +1,12 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification, HeteroMapTableSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.market.heterogeneous.organizer import HeteroMap | |||
| from learnware.specification import HeteroMapTableSpecification, RKMETableSpecification, generate_stat_spec | |||
| class TestTableRKME(unittest.TestCase): | |||
| @@ -1,12 +1,12 @@ | |||
| import os | |||
| import json | |||
| import torch | |||
| import unittest | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| from learnware.specification import RKMEImageSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.specification import RKMEImageSpecification, generate_stat_spec | |||
| class TestImageRKME(unittest.TestCase): | |||
| @@ -1,11 +1,11 @@ | |||
| import os | |||
| import json | |||
| import unittest | |||
| import os | |||
| import tempfile | |||
| import unittest | |||
| import numpy as np | |||
| from learnware.specification import RKMETableSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.specification import RKMETableSpecification, generate_stat_spec | |||
| class TestTableRKME(unittest.TestCase): | |||
| @@ -1,12 +1,11 @@ | |||
| import os | |||
| import json | |||
| import string | |||
| import os | |||
| import random | |||
| import unittest | |||
| import string | |||
| import tempfile | |||
| import unittest | |||
| from learnware.specification import RKMETextSpecification | |||
| from learnware.specification import generate_stat_spec | |||
| from learnware.specification import RKMETextSpecification, generate_stat_spec | |||
| class TestTextRKME(unittest.TestCase): | |||
| @@ -1,22 +1,22 @@ | |||
| import torch | |||
| import pickle | |||
| import unittest | |||
| import os | |||
| import logging | |||
| import os | |||
| import pickle | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| from sklearn.linear_model import Ridge | |||
| import torch | |||
| from hetero_config import input_description_list, input_shape_list, output_description_list, user_description_list | |||
| from sklearn.datasets import make_regression | |||
| from sklearn.linear_model import Ridge | |||
| from sklearn.metrics import mean_squared_error | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import BaseUserInfo, instantiate_learnware_market | |||
| from learnware.reuse import AveragingReuser, EnsemblePruningReuser, HeteroMapAlignLearnware | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec | |||
| from learnware.reuse import HeteroMapAlignLearnware, AveragingReuser, EnsemblePruningReuser | |||
| from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate | |||
| from hetero_config import input_shape_list, input_description_list, output_description_list, user_description_list | |||
| learnware.init(logging_level=logging.WARNING) | |||
| curr_root = os.path.dirname(os.path.abspath(__file__)) | |||
| @@ -1,18 +1,19 @@ | |||
| import unittest | |||
| import os | |||
| import logging | |||
| import tempfile | |||
| import os | |||
| import pickle | |||
| import tempfile | |||
| import unittest | |||
| import zipfile | |||
| import numpy as np | |||
| from sklearn import svm | |||
| from sklearn.datasets import load_digits | |||
| from sklearn.model_selection import train_test_split | |||
| import learnware | |||
| from learnware.market import instantiate_learnware_market, BaseUserInfo | |||
| from learnware.market import BaseUserInfo, instantiate_learnware_market | |||
| from learnware.reuse import AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser, JobSelectorReuser | |||
| from learnware.specification import RKMETableSpecification, generate_rkme_table_spec, generate_semantic_spec | |||
| from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser, FeatureAugmentReuser | |||
| from learnware.tests.templates import LearnwareTemplate, PickleModelTemplate, StatSpecTemplate | |||
| learnware.init(logging_level=logging.WARNING) | |||