diff --git a/examples/workflow_by_code/main.py b/examples/workflow_by_code/main.py index 22b6f0d..e835436 100644 --- a/examples/workflow_by_code/main.py +++ b/examples/workflow_by_code/main.py @@ -1,10 +1,12 @@ -import fire import os +import fire import joblib +import zipfile import numpy as np -import learnware - from sklearn import svm +from shutil import copyfile, rmtree + +import learnware from learnware.market import EasyMarket, BaseUserInfo from learnware.market import database_ops from learnware.learnware import Learnware @@ -76,14 +78,23 @@ class LearnwareMarketWorkflow: spec.save(os.path.join(dir_path, "svm.json")) init_file = os.path.join(dir_path, "__init__.py") - os.system(f"cp example_init.py {init_file}") + copyfile("example_init.py", init_file) # cp example_init.py init_file yaml_file = os.path.join(dir_path, "learnware.yaml") - os.system(f"cp example.yaml {yaml_file}") + copyfile("example.yaml", yaml_file) # cp example.yaml yaml_file zip_file = dir_path + ".zip" - os.system(f"zip -q -r -j {zip_file} {dir_path}") - os.system(f"rm -r {dir_path}") + # zip -q -r -j zip_file dir_path + with zipfile.ZipFile(zip_file, "w") as zip_obj: + for foldername, subfolders, filenames in os.walk(dir_path): + for filename in filenames: + file_path = os.path.join(foldername, filename) + zip_info = zipfile.ZipInfo(filename) + zip_info.compress_type = zipfile.ZIP_STORED + with open(file_path, "rb") as file: + zip_obj.writestr(zip_info, file.read()) + + rmtree(dir_path) # rm -r dir_path self.zip_path_list.append(zip_file) @@ -120,8 +131,13 @@ class LearnwareMarketWorkflow: idx, zip_path = 1, self.zip_path_list[1] unzip_dir = os.path.join(test_folder, f"{idx}") + + # unzip -o -q zip_path -d unzip_dir + if os.path.exists(unzip_dir): + rmtree(unzip_dir) os.makedirs(unzip_dir, exist_ok=True) - os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") + with zipfile.ZipFile(zip_path, "r") as zip_obj: + zip_obj.extractall(path=unzip_dir) user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic) _, single_learnware_list, _ = easy_market.search_learnware(user_info) @@ -131,11 +147,11 @@ class LearnwareMarketWorkflow: for learnware in single_learnware_list: print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec()) - os.system(f"rm -r {test_folder}") + rmtree(test_folder) # rm -r test_folder def test_stat_search(self, learnware_num=5): self._init_learnware_market() - self.prepare_learnware_randomly(learnware_num) + self.test_upload_delete_learnware(learnware_num) print(self.zip_path_list) easy_market = EasyMarket() @@ -145,8 +161,13 @@ class LearnwareMarketWorkflow: for idx, zip_path in enumerate(self.zip_path_list): unzip_dir = os.path.join(test_folder, f"{idx}") + + # unzip -o -q zip_path -d unzip_dir + if os.path.exists(unzip_dir): + rmtree(unzip_dir) os.makedirs(unzip_dir, exist_ok=True) - os.system(f"unzip -o -q {zip_path} -d {unzip_dir}") + with zipfile.ZipFile(zip_path, "r") as zip_obj: + zip_obj.extractall(path=unzip_dir) user_spec = specification.rkme.RKMEStatSpecification() user_spec.load(os.path.join(unzip_dir, "svm.json")) @@ -161,7 +182,7 @@ class LearnwareMarketWorkflow: mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list]) print(f"mixture_learnware: {mixture_id}\n") - os.system(f"rm -r {test_folder}") + rmtree(test_folder) # rm -r test_folder if __name__ == "__main__": diff --git a/learnware/market/easy.py b/learnware/market/easy.py index d4e0e08..d8ed01c 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -4,6 +4,7 @@ import zipfile import torch import numpy as np import pandas as pd +from cvxopt import solvers, matrix from typing import Tuple, Any, List, Union, Dict from .base import BaseMarket, BaseUserInfo @@ -190,10 +191,21 @@ class EasyMarket(BaseMarket): K = torch.from_numpy(K).double().to(user_rkme.device) C = torch.from_numpy(C).double().to(user_rkme.device) - # if nonnegative_beta: - # w = solve_qp(K, C).double().to(Phi_t.device) - # else: - weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C + # beta can be negative + # weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C + + # beta must be nonnegative + n = K.shape[0] + P = matrix(K.cpu().numpy()) + q = matrix(-C.cpu().numpy()) + G = matrix(-np.eye(n)) + h = matrix(np.zeros((n, 1))) + A = matrix(np.ones((1, n))) + b = matrix(np.ones((1, 1))) + solvers.options["show_progress"] = False + sol = solvers.qp(P, q, G, h, A, b) + weight = np.array(sol["x"]) + weight = torch.from_numpy(weight).reshape(-1).double().to(user_rkme.device) term1 = user_rkme.inner_prod(user_rkme) term2 = weight.T @ C @@ -238,7 +250,7 @@ class EasyMarket(BaseMarket): return intermediate_K, intermediate_C def _search_by_rkme_spec_mixture( - self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int = 5, score_cutoff: float = 0.1 + self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int = 5, score_cutoff: float = 0.05 ) -> Tuple[List[float], List[Learnware]]: """Get learnwares with their mixture weight from the given learnware_list @@ -269,7 +281,7 @@ class EasyMarket(BaseMarket): flag_list = [0 for _ in range(learnware_num)] mixture_list = [] intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1)) - + for k in range(max_search_num): idx_min, score_min = -1, -1 weight_min = None @@ -292,14 +304,14 @@ class EasyMarket(BaseMarket): if idx_min == -1 or score < score_min: idx_min, score_min, weight_min = idx, score, weight - if score_min >= score_cutoff: + mixture_list[-1] = learnware_list[idx_min] + if score_min < score_cutoff: + break + else: flag_list[idx_min] = 1 - mixture_list[-1] = learnware_list[idx_min] intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C( mixture_list, user_rkme, intermediate_K, intermediate_C ) - else: - break return weight_min, mixture_list @@ -394,7 +406,7 @@ class EasyMarket(BaseMarket): learnware_list = [self.learnware_list[key] for key in self.learnware_list] learnware_list = self._search_by_semantic_spec(learnware_list, user_info) # learnware_list = list(set(learnware_list_tags + learnware_list_description)) - + if "RKMEStatSpecification" not in user_info.stat_info: return None, learnware_list, None elif len(learnware_list) == 0: