Browse Source

[MNT] Fix bugs in _search_by_rkme_spec_mixture

tags/v0.3.2
Gene 3 years ago
parent
commit
005729ac4c
2 changed files with 56 additions and 23 deletions
  1. +33
    -12
      examples/workflow_by_code/main.py
  2. +23
    -11
      learnware/market/easy.py

+ 33
- 12
examples/workflow_by_code/main.py View File

@@ -1,10 +1,12 @@
import fire
import os
import fire
import joblib
import zipfile
import numpy as np
import learnware

from sklearn import svm
from shutil import copyfile, rmtree

import learnware
from learnware.market import EasyMarket, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware
@@ -76,14 +78,23 @@ class LearnwareMarketWorkflow:
spec.save(os.path.join(dir_path, "svm.json"))

init_file = os.path.join(dir_path, "__init__.py")
os.system(f"cp example_init.py {init_file}")
copyfile("example_init.py", init_file) # cp example_init.py init_file

yaml_file = os.path.join(dir_path, "learnware.yaml")
os.system(f"cp example.yaml {yaml_file}")
copyfile("example.yaml", yaml_file) # cp example.yaml yaml_file

zip_file = dir_path + ".zip"
os.system(f"zip -q -r -j {zip_file} {dir_path}")
os.system(f"rm -r {dir_path}")
# zip -q -r -j zip_file dir_path
with zipfile.ZipFile(zip_file, "w") as zip_obj:
for foldername, subfolders, filenames in os.walk(dir_path):
for filename in filenames:
file_path = os.path.join(foldername, filename)
zip_info = zipfile.ZipInfo(filename)
zip_info.compress_type = zipfile.ZIP_STORED
with open(file_path, "rb") as file:
zip_obj.writestr(zip_info, file.read())
rmtree(dir_path) # rm -r dir_path

self.zip_path_list.append(zip_file)

@@ -120,8 +131,13 @@ class LearnwareMarketWorkflow:

idx, zip_path = 1, self.zip_path_list[1]
unzip_dir = os.path.join(test_folder, f"{idx}")
# unzip -o -q zip_path -d unzip_dir
if os.path.exists(unzip_dir):
rmtree(unzip_dir)
os.makedirs(unzip_dir, exist_ok=True)
os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_info = BaseUserInfo(id="user_0", semantic_spec=user_senmantic)
_, single_learnware_list, _ = easy_market.search_learnware(user_info)
@@ -131,11 +147,11 @@ class LearnwareMarketWorkflow:
for learnware in single_learnware_list:
print("Choose learnware:", learnware.id, learnware.get_specification().get_semantic_spec())

os.system(f"rm -r {test_folder}")
rmtree(test_folder) # rm -r test_folder

def test_stat_search(self, learnware_num=5):
self._init_learnware_market()
self.prepare_learnware_randomly(learnware_num)
self.test_upload_delete_learnware(learnware_num)

print(self.zip_path_list)
easy_market = EasyMarket()
@@ -145,8 +161,13 @@ class LearnwareMarketWorkflow:

for idx, zip_path in enumerate(self.zip_path_list):
unzip_dir = os.path.join(test_folder, f"{idx}")
# unzip -o -q zip_path -d unzip_dir
if os.path.exists(unzip_dir):
rmtree(unzip_dir)
os.makedirs(unzip_dir, exist_ok=True)
os.system(f"unzip -o -q {zip_path} -d {unzip_dir}")
with zipfile.ZipFile(zip_path, "r") as zip_obj:
zip_obj.extractall(path=unzip_dir)

user_spec = specification.rkme.RKMEStatSpecification()
user_spec.load(os.path.join(unzip_dir, "svm.json"))
@@ -161,7 +182,7 @@ class LearnwareMarketWorkflow:
mixture_id = " ".join([learnware.id for learnware in mixture_learnware_list])
print(f"mixture_learnware: {mixture_id}\n")

os.system(f"rm -r {test_folder}")
rmtree(test_folder) # rm -r test_folder


if __name__ == "__main__":


+ 23
- 11
learnware/market/easy.py View File

@@ -4,6 +4,7 @@ import zipfile
import torch
import numpy as np
import pandas as pd
from cvxopt import solvers, matrix
from typing import Tuple, Any, List, Union, Dict

from .base import BaseMarket, BaseUserInfo
@@ -190,10 +191,21 @@ class EasyMarket(BaseMarket):
K = torch.from_numpy(K).double().to(user_rkme.device)
C = torch.from_numpy(C).double().to(user_rkme.device)

# if nonnegative_beta:
# w = solve_qp(K, C).double().to(Phi_t.device)
# else:
weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C
# beta can be negative
# weight = torch.linalg.inv(K + torch.eye(K.shape[0]).to(user_rkme.device) * 1e-5) @ C
# beta must be nonnegative
n = K.shape[0]
P = matrix(K.cpu().numpy())
q = matrix(-C.cpu().numpy())
G = matrix(-np.eye(n))
h = matrix(np.zeros((n, 1)))
A = matrix(np.ones((1, n)))
b = matrix(np.ones((1, 1)))
solvers.options["show_progress"] = False
sol = solvers.qp(P, q, G, h, A, b)
weight = np.array(sol["x"])
weight = torch.from_numpy(weight).reshape(-1).double().to(user_rkme.device)

term1 = user_rkme.inner_prod(user_rkme)
term2 = weight.T @ C
@@ -238,7 +250,7 @@ class EasyMarket(BaseMarket):
return intermediate_K, intermediate_C

def _search_by_rkme_spec_mixture(
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int = 5, score_cutoff: float = 0.1
self, learnware_list: List[Learnware], user_rkme: RKMEStatSpecification, max_search_num: int = 5, score_cutoff: float = 0.05
) -> Tuple[List[float], List[Learnware]]:
"""Get learnwares with their mixture weight from the given learnware_list

@@ -269,7 +281,7 @@ class EasyMarket(BaseMarket):
flag_list = [0 for _ in range(learnware_num)]
mixture_list = []
intermediate_K, intermediate_C = np.zeros((1, 1)), np.zeros((1, 1))
for k in range(max_search_num):
idx_min, score_min = -1, -1
weight_min = None
@@ -292,14 +304,14 @@ class EasyMarket(BaseMarket):
if idx_min == -1 or score < score_min:
idx_min, score_min, weight_min = idx, score, weight

if score_min >= score_cutoff:
mixture_list[-1] = learnware_list[idx_min]
if score_min < score_cutoff:
break
else:
flag_list[idx_min] = 1
mixture_list[-1] = learnware_list[idx_min]
intermediate_K, intermediate_C = self._calculate_intermediate_K_and_C(
mixture_list, user_rkme, intermediate_K, intermediate_C
)
else:
break

return weight_min, mixture_list

@@ -394,7 +406,7 @@ class EasyMarket(BaseMarket):
learnware_list = [self.learnware_list[key] for key in self.learnware_list]
learnware_list = self._search_by_semantic_spec(learnware_list, user_info)
# learnware_list = list(set(learnware_list_tags + learnware_list_description))
if "RKMEStatSpecification" not in user_info.stat_info:
return None, learnware_list, None
elif len(learnware_list) == 0:


Loading…
Cancel
Save