Browse Source

[MNT] rm extra test_workflow

tags/v0.3.2
Gene 2 years ago
parent
commit
aba490b31d
5 changed files with 0 additions and 288 deletions
  1. +0
    -10
      tests/test_workflow/learnware_example/README.md
  2. +0
    -27
      tests/test_workflow/learnware_example/environment.yaml
  3. +0
    -8
      tests/test_workflow/learnware_example/example.yaml
  4. +0
    -20
      tests/test_workflow/learnware_example/example_init.py
  5. +0
    -223
      tests/test_workflow/test_workflow.py

+ 0
- 10
tests/test_workflow/learnware_example/README.md View File

@@ -1,10 +0,0 @@
## How to Generate Environment Yaml

* Export the current conda environment to a config file (the machine-specific `prefix:` line is filtered out):
```shell
conda env export | grep -v "^prefix: " > environment.yml
```
* Recreate the environment from that config file:
```
conda env create -f environment.yml
```

+ 0
- 27
tests/test_workflow/learnware_example/environment.yaml View File

@@ -1,27 +0,0 @@
name: learnware_example_env
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2023.01.10=h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.2=h6a678d5_6
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1t=h7f8727e_0
- pip=23.0.1=py38h06a4308_0
- python=3.8.16=h7a1cb2a_3
- readline=8.2=h5eee18b_0
- setuptools=66.0.0=py38h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.38.4=py38h06a4308_0
- xz=5.2.10=h5eee18b_1
- zlib=1.2.13=h5eee18b_0
- pip:
  - joblib==1.2.0
  - learnware==0.0.1.99
  - numpy==1.19.5

+ 0
- 8
tests/test_workflow/learnware_example/example.yaml View File

@@ -1,8 +0,0 @@
model:
  class_name: SVM
  kwargs: {}
stat_specifications:
  - module_path: learnware.specification
    class_name: RKMETableSpecification
    file_name: svm.json
    kwargs: {}

+ 0
- 20
tests/test_workflow/learnware_example/example_init.py View File

@@ -1,20 +0,0 @@
import os
import joblib
import numpy as np
from learnware.model import BaseModel


class SVM(BaseModel):
    """Learnware model backed by an estimator unpickled from ``svm.pkl``.

    The pickle is expected to sit in the same directory as this module.
    The model is declared with a 64-element input and a 10-element output.
    """

    def __init__(self):
        super().__init__(input_shape=(64,), output_shape=(10,))
        # Resolve the pickle relative to this file, not the working directory.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        pickle_path = os.path.join(module_dir, "svm.pkl")
        self.model = joblib.load(pickle_path)

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Do nothing; the estimator is loaded already trained."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return ``predict_proba(X)`` of the loaded estimator."""
        probabilities = self.model.predict_proba(X)
        return probabilities

    def finetune(self, X: np.ndarray, y: np.ndarray):
        """Do nothing; fine-tuning is not implemented for this example."""

+ 0
- 223
tests/test_workflow/test_workflow.py View File

@@ -1,223 +0,0 @@
import sys
import unittest
import os
import copy
import joblib
import zipfile
import numpy as np
from sklearn import svm
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from shutil import copyfile, rmtree

import learnware
from learnware.market import EasyMarket, BaseUserInfo
from learnware.reuse import JobSelectorReuser, AveragingReuser, EnsemblePruningReuser
from learnware.specification import RKMETableSpecification, generate_rkme_spec

# Absolute directory of this module; example assets and scratch folders used
# by the tests are resolved relative to it.
_this_file = os.path.abspath(__file__)
curr_root = os.path.dirname(_this_file)

# Baseline semantic specification. Tests deep-copy it before filling in the
# "Name" and "Description" values, or pass it unchanged as the user's query.
user_semantic = {
    "Data": {"Values": ["Image"], "Type": "Class"},
    "Task": {"Values": ["Classification"], "Type": "Class"},
    "Library": {"Values": ["Scikit-learn"], "Type": "Class"},
    "Scenario": {"Values": ["Education"], "Type": "Tag"},
    "Description": {"Values": "", "Type": "String"},
    "Name": {"Values": "", "Type": "String"},
}


class TestAllWorkflow(unittest.TestCase):
    """End-to-end walk through the learnware workflow on the scikit-learn digits data.

    The methods build zipped SVM learnwares, add them to an ``EasyMarket``,
    search the market by semantic and statistical specification, and reuse
    the search results for prediction. Results are printed rather than
    asserted, so a test fails only when one of the calls raises.

    NOTE(review): the ``test_*`` methods call one another, and
    ``test_upload_delete_learnware`` returns the market. ``unittest`` has
    deprecated non-``None`` return values from test methods since Python
    3.11; the return value is kept because the sibling tests consume it.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Seed NumPy's global RNG and call ``learnware.init()`` once for the class."""
        np.random.seed(2023)
        learnware.init()

    def _init_learnware_market(self):
        """Create and return the ``EasyMarket`` used by the tests.

        ``rebuild=True`` presumably discards previously stored market state --
        TODO confirm against ``EasyMarket``.
        """
        return EasyMarket(market_id="sklearn_digits", rebuild=True)

    @staticmethod
    def _zip_flat(dir_path, zip_file):
        """Store every file found under ``dir_path`` in ``zip_file`` by bare file name.

        Shell equivalent: ``zip -q -r -j zip_file dir_path``.
        """
        with zipfile.ZipFile(zip_file, "w") as zip_obj:
            for foldername, _subfolders, filenames in os.walk(dir_path):
                for filename in filenames:
                    # The entry name carries no directory part, and the data
                    # is stored uncompressed.
                    zip_info = zipfile.ZipInfo(filename)
                    zip_info.compress_type = zipfile.ZIP_STORED
                    with open(os.path.join(foldername, filename), "rb") as file:
                        zip_obj.writestr(zip_info, file.read())

    def test_prepare_learnware_randomly(self, learnware_num=5):
        """Build ``learnware_num`` zipped SVM learnwares under ``learnware_pool``.

        Each archive holds the pickled classifier (``svm.pkl``), its RKME
        specification (``svm.json``) and copies of the example
        ``__init__.py``, ``learnware.yaml`` and ``environment.yaml``. The
        archive paths are stored in ``self.zip_path_list``.

        Args:
            learnware_num: number of learnware archives to create.
        """
        self.zip_path_list = []
        X, y = load_digits(return_X_y=True)

        # (template relative to curr_root, file name inside the learnware)
        templates = (
            ("learnware_example/example_init.py", "__init__.py"),
            ("learnware_example/example.yaml", "learnware.yaml"),
            ("learnware_example/environment.yaml", "environment.yaml"),
        )

        for i in range(learnware_num):
            dir_path = os.path.join(curr_root, "learnware_pool", f"svm_{i}")
            os.makedirs(dir_path, exist_ok=True)
            zip_file = dir_path + ".zip"

            try:
                print(f"Preparing Learnware: {i}")

                # Each learnware is trained on its own random 70% split.
                data_X, _, data_y, _ = train_test_split(X, y, test_size=0.3, shuffle=True)
                clf = svm.SVC(kernel="linear", probability=True)
                clf.fit(data_X, data_y)
                joblib.dump(clf, os.path.join(dir_path, "svm.pkl"))

                # NOTE(review): cuda_idx=0 presumably selects the first CUDA
                # device -- confirm behaviour on CPU-only machines.
                spec = generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
                spec.save(os.path.join(dir_path, "svm.json"))

                for template, target in templates:
                    copyfile(os.path.join(curr_root, template), os.path.join(dir_path, target))

                self._zip_flat(dir_path, zip_file)
            finally:
                # Remove the staging directory even when a step above fails.
                rmtree(dir_path)

            self.zip_path_list.append(zip_file)

    def test_upload_delete_learnware(self, learnware_num=5, delete=False):
        """Add freshly prepared learnwares to a new market and optionally delete them.

        Args:
            learnware_num: number of learnwares to prepare and add.
            delete: when true, delete every id reported by the market
                after the upload.

        Returns:
            The ``EasyMarket`` instance that was populated.
        """
        easy_market = self._init_learnware_market()
        self.test_prepare_learnware_randomly(learnware_num)

        print("Total Item:", len(easy_market))

        for idx, zip_path in enumerate(self.zip_path_list):
            # Deep copy so the shared module-level template is never mutated.
            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = f"learnware_{idx}"
            semantic_spec["Description"]["Values"] = f"test_learnware_number_{idx}"
            semantic_spec["Output"] = {
                "Dimension": 10,
                "Description": {f"{i}": "The probability for each digit for 0 to 9." for i in range(10)},
            }
            easy_market.add_learnware(zip_path, semantic_spec)

        print("Total Item:", len(easy_market))
        curr_inds = easy_market._get_ids()
        print("Available ids After Uploading Learnwares:", curr_inds)

        if delete:
            for learnware_id in curr_inds:
                easy_market.delete_learnware(learnware_id)
            curr_inds = easy_market._get_ids()
            print("Available ids After Deleting Learnwares:", curr_inds)

        return easy_market

    def test_search_semantics(self, learnware_num=5):
        """Search the market with a semantic specification only and print the hits.

        The query uses the name and description given to the last uploaded
        learnware (index ``learnware_num - 1``).
        """
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_semantics")

        # unzip -o -q zip_path -d unzip_dir
        if os.path.exists(test_folder):
            rmtree(test_folder)
        os.makedirs(test_folder, exist_ok=True)

        try:
            with zipfile.ZipFile(self.zip_path_list[0], "r") as zip_obj:
                zip_obj.extractall(path=test_folder)

            semantic_spec = copy.deepcopy(user_semantic)
            semantic_spec["Name"]["Values"] = f"learnware_{learnware_num - 1}"
            semantic_spec["Description"]["Values"] = f"test_learnware_number_{learnware_num - 1}"

            user_info = BaseUserInfo(semantic_spec=semantic_spec)
            _, single_learnware_list, _, _ = easy_market.search_learnware(user_info)

            print("User info:", user_info.get_semantic_spec())
            print("Search result:")
            # "found" instead of "learnware" so the imported module is not shadowed.
            for found in single_learnware_list:
                print("Choose learnware:", found.id, found.get_specification().get_semantic_spec())
        finally:
            # Remove the scratch folder even when the search raises.
            rmtree(test_folder)

    def test_stat_search(self, learnware_num=5):
        """Search the market once per archive using that archive's own RKME spec.

        Each archive is unpacked into ``test_stat/<idx>``, its ``svm.json``
        is loaded as the user's statistical specification, and the single
        and mixture search results are printed.
        """
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        test_folder = os.path.join(curr_root, "test_stat")

        try:
            for idx, zip_path in enumerate(self.zip_path_list):
                unzip_dir = os.path.join(test_folder, f"{idx}")

                # unzip -o -q zip_path -d unzip_dir
                if os.path.exists(unzip_dir):
                    rmtree(unzip_dir)
                os.makedirs(unzip_dir, exist_ok=True)
                with zipfile.ZipFile(zip_path, "r") as zip_obj:
                    zip_obj.extractall(path=unzip_dir)

                user_spec = RKMETableSpecification()
                user_spec.load(os.path.join(unzip_dir, "svm.json"))
                user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": user_spec})
                (
                    sorted_score_list,
                    single_learnware_list,
                    mixture_score,
                    mixture_learnware_list,
                ) = easy_market.search_learnware(user_info)

                print(f"search result of user{idx}:")
                for score, found in zip(sorted_score_list, single_learnware_list):
                    print(f"score: {score}, learnware_id: {found.id}")
                print(f"mixture_score: {mixture_score}\n")
                mixture_id = " ".join([found.id for found in mixture_learnware_list])
                print(f"mixture_learnware: {mixture_id}\n")
        finally:
            # The folder exists only if at least one archive was unpacked;
            # ignore_errors keeps cleanup from masking an earlier failure.
            rmtree(test_folder, ignore_errors=True)

    def test_learnware_reuse(self, learnware_num=5):
        """Reuse searched learnwares with three reusers and print their accuracy.

        The digits data is split 70/30; the 30% part plays the user's data
        and drives the statistical search. The mixture result is then reused
        through ``JobSelectorReuser``, ``AveragingReuser`` and
        ``EnsemblePruningReuser``.
        """
        easy_market = self.test_upload_delete_learnware(learnware_num, delete=False)
        print("Total Item:", len(easy_market))

        X, y = load_digits(return_X_y=True)
        train_X, data_X, train_y, data_y = train_test_split(X, y, test_size=0.3, shuffle=True)

        stat_spec = generate_rkme_spec(X=data_X, gamma=0.1, cuda_idx=0)
        user_info = BaseUserInfo(semantic_spec=user_semantic, stat_info={"RKMETableSpecification": stat_spec})

        _, _, _, mixture_learnware_list = easy_market.search_learnware(user_info)

        # Job selector reuser over the searched learnwares.
        job_selector = JobSelectorReuser(learnware_list=mixture_learnware_list)
        job_selector_predict_y = job_selector.predict(user_data=data_X)

        # Averaging ensemble reuser.
        averaging = AveragingReuser(learnware_list=mixture_learnware_list, mode="vote_by_prob")
        ensemble_predict_y = averaging.predict(user_data=data_X)

        # Ensemble pruning reuser, fitted on the last 200 training samples.
        pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification")
        pruning.fit(train_X[-200:], train_y[-200:])
        ensemble_pruning_predict_y = pruning.predict(user_data=data_X)

        # The first two predictions are reduced with argmax over axis 1; the
        # pruning prediction is compared with the labels directly.
        print("Job Selector Acc:", np.sum(np.argmax(job_selector_predict_y, axis=1) == data_y) / len(data_y))
        print("Averaging Reuser Acc:", np.sum(np.argmax(ensemble_predict_y, axis=1) == data_y) / len(data_y))
        print("Ensemble Pruning Reuser Acc:", np.sum(ensemble_pruning_predict_y == data_y) / len(data_y))


def suite():
    """Return a ``unittest.TestSuite`` holding the five workflow tests in a fixed order."""
    ordered_cases = (
        "test_prepare_learnware_randomly",
        "test_upload_delete_learnware",
        "test_search_semantics",
        "test_stat_search",
        "test_learnware_reuse",
    )
    _suite = unittest.TestSuite()
    _suite.addTests(TestAllWorkflow(case_name) for case_name in ordered_cases)
    return _suite


if __name__ == "__main__":
    # Run the ordered workflow suite with the plain-text runner when this
    # file is executed as a script.
    unittest.TextTestRunner().run(suite())

Loading…
Cancel
Save