Browse Source

[FIX] Fix databse zippath bug

tags/v0.3.2
bxdd 3 years ago
parent
commit
de879c3da6
6 changed files with 59 additions and 14 deletions
  1. +7
    -4
      examples/example_market_db/example_db.py
  2. +23
    -1
      examples/workflow_by_code/main.py
  3. +16
    -4
      learnware/config.py
  4. +4
    -1
      learnware/learnware/__init__.py
  5. +5
    -2
      learnware/market/easy.py
  6. +4
    -2
      learnware/specification/rkme.py

+ 7
- 4
examples/example_market_db/example_db.py View File

@@ -40,7 +40,10 @@ semantic_specs = [

user_senmantic = {
"Data": {"Values": ["Tabular"], "Type": "Class"},
"Task": {"Values": ["Classification"], "Type": "Class",},
"Task": {
"Values": ["Classification"],
"Type": "Class",
},
"Device": {"Values": ["GPU"], "Type": "Tag"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "Description"},
@@ -95,7 +98,7 @@ def test_market():
semantic_spec["Name"]["Values"] = "learnware_%d" % (idx)
semantic_spec["Description"]["Values"] = "test_learnware_number_%d" % (idx)
easy_market.add_learnware(zip_path, semantic_spec)
return
print("Total Item:", len(easy_market))
curr_inds = easy_market._get_ids()
print("Available ids:", curr_inds)
@@ -165,8 +168,8 @@ def test_stat_search():


if __name__ == "__main__":
learnware_num = 5
learnware_num = 10
prepare_learnware(learnware_num)
test_market()
# test_stat_search()
test_stat_search()
test_search_semantics()

+ 23
- 1
examples/workflow_by_code/main.py View File

@@ -1,5 +1,15 @@
import os
import fire
import os
import joblib
import numpy as np
import learnware

from sklearn import svm
from learnware.market import EasyMarket, BaseUserInfo
from learnware.market import database_ops
from learnware.learnware import Learnware
import learnware.specification as specification
from learnware.utils import get_module_by_module_path


class LearnwareMarketWorkflow:
@@ -53,6 +63,18 @@ class LearnwareMarketWorkflow:
"Name": {"Values": "", "Type": "Name"},
}

def _init_learnware_market(self):
"""initialize learnware market"""
database_ops.clear_learnware_table()
learnware.init()

self.learnware_market = EasyMarket()

def _generate_learnware_randomly(self):
pass

# def _


if __name__ == "__main__":
fire.Fire(LearnwareMarketWorkflow)

+ 16
- 4
learnware/config.py View File

@@ -66,7 +66,10 @@ os.makedirs(LEARNWARE_FOLDER_POOL_PATH, exist_ok=True)
os.makedirs(DATABASE_PATH, exist_ok=True)

semantic_config = {
"Data": {"Values": ["Tabular", "Image", "Video", "Text", "Audio"], "Type": "Class",}, # Choose only one class
"Data": {
"Values": ["Tabular", "Image", "Video", "Text", "Audio"],
"Type": "Class",
}, # Choose only one class
"Task": {
"Values": [
"Classification",
@@ -79,7 +82,10 @@ semantic_config = {
],
"Type": "Class", # Choose only one class
},
"Device": {"Values": ["CPU", "GPU"], "Type": "Tag",}, # Choose one or more tags
"Device": {
"Values": ["CPU", "GPU"],
"Type": "Tag",
}, # Choose one or more tags
"Scenario": {
"Values": [
"Business",
@@ -99,8 +105,14 @@ semantic_config = {
],
"Type": "Tag", # Choose one or more tags
},
"Description": {"Values": None, "Type": "Description",},
"Name": {"Values": None, "Type": "Name",},
"Description": {
"Values": None,
"Type": "Description",
},
"Name": {
"Values": None,
"Type": "Name",
},
}

_DEFAULT_CONFIG = {


+ 4
- 1
learnware/learnware/__init__.py View File

@@ -29,7 +29,10 @@ def get_learnware_from_dirpath(id: str, semantic_spec: dict, learnware_dirpath:
The contructed learnware object, return None if build failed
"""
learnware_config = {
"model": {"class_name": "Model", "kwargs": {},},
"model": {
"class_name": "Model",
"kwargs": {},
},
"stat_specifications": [
{
"module_path": "learnware.specification",


+ 5
- 2
learnware/market/easy.py View File

@@ -119,10 +119,13 @@ class EasyMarket(BaseMarket):
self.learnware_folder_list[id] = target_folder_dir
self.count += 1
add_learnware_to_db(
id, semantic_spec=semantic_spec, zip_path=target_folder_dir, folder_path=target_folder_dir,
id,
semantic_spec=semantic_spec,
zip_path=target_zip_dir,
folder_path=target_folder_dir,
)
return id, True
def _convert_dist_to_score(self, dist_list: List[float]) -> List[float]:
"""Convert mmd dist list into min_max score list



+ 4
- 2
learnware/specification/rkme.py View File

@@ -255,7 +255,9 @@ class RKMEStatSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"
json.dump(
rkme_to_save, codecs.open(save_path, "w", encoding="utf-8"), separators=(",", ":"),
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),
separators=(",", ":"),
)

def load(self, filepath: str) -> bool:
@@ -343,7 +345,7 @@ def torch_rbf_kernel(x1, x2, gamma) -> torch.Tensor:
"""
x1 = x1.double()
x2 = x2.double()
X12norm = torch.sum(x1 ** 2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2 ** 2, 1, keepdim=True).T
X12norm = torch.sum(x1**2, 1, keepdim=True) - 2 * x1 @ x2.T + torch.sum(x2**2, 1, keepdim=True).T
return torch.exp(-X12norm * gamma)




Loading…
Cancel
Save