Browse Source

Merge branch 'dev' of ssh://git.nju.edu.cn/learnware/learnware-market into dev

tags/v0.3.2
liuht 3 years ago
parent
commit
2ee9f93f30
6 changed files with 32 additions and 7 deletions
  1. +3
    -3
      examples/example_pfs/main.py
  2. +2
    -2
      examples/example_pfs/pfs/pfs_cross_transfer.py
  3. +1
    -0
      learnware/config.py
  4. +2
    -2
      learnware/learnware/reuse.py
  5. +20
    -0
      learnware/market/easy.py
  6. +4
    -0
      learnware/specification/utils.py

+ 3
- 3
examples/example_pfs/main.py View File

@@ -142,8 +142,8 @@ class PFSDatasetWorkflow:
rmtree(dir_path)

def test(self, regenerate_flag=False):
# self.prepare_learnware(regenerate_flag)
# self._init_learnware_market()
self.prepare_learnware(regenerate_flag)
self._init_learnware_market()

easy_market = EasyMarket()
print("Total Item:", len(easy_market))
@@ -173,7 +173,7 @@ class PFSDatasetWorkflow:
reuse_baseline = ReuseBaseline(learnware_list=mixture_learnware_list)
reuse_predict = reuse_baseline.predict(user_data=test_x)
reuse_score = pfs.score(test_y, reuse_predict)
print(f"mixture reuse score: {reuse_score}\n")
print(f"mixture reuse score: {reuse_score}")


if __name__ == "__main__":


+ 2
- 2
examples/example_pfs/pfs/pfs_cross_transfer.py View File

@@ -67,7 +67,7 @@ def get_split_errs(algo):
for tmp in range(len(proportion_list)):
model = lgb.LGBMModel(
boosting_type="gbdt",
num_leaves=2**7 - 1,
num_leaves=2 ** 7 - 1,
learning_rate=0.01,
objective="rmse",
metric="rmse",
@@ -119,7 +119,7 @@ def get_errors(algo):
if algo == "lgb":
model = lgb.LGBMModel(
boosting_type="gbdt",
num_leaves=2**7 - 1,
num_leaves=2 ** 7 - 1,
learning_rate=0.01,
objective="rmse",
metric="rmse",


+ 1
- 0
learnware/config.py View File

@@ -128,6 +128,7 @@ _DEFAULT_CONFIG = {
"module_file": "__init__.py",
},
"database_path": DATABASE_PATH,
"max_reduced_set_size": 1000000,
}

C = Config(_DEFAULT_CONFIG)

+ 2
- 2
learnware/learnware/reuse.py View File

@@ -191,7 +191,7 @@ class ReuseBaseline:
seed=0,
)
train_y = train_y.astype(np.int)
model.fit(train_x, train_y, eval_set=[(val_x, val_y)], early_stopping_rounds=300)
model.fit(train_x, train_y, eval_set=[(val_x, val_y)], verbose=-1, early_stopping_rounds=300)
pred_y = model.predict(org_train_x)
score = accuracy_score(pred_y, org_train_y)

@@ -208,6 +208,6 @@ class ReuseBaseline:
booster="gbtree",
seed=0,
)
model.fit(org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], early_stopping_rounds=300)
model.fit(org_train_x, org_train_y, eval_set=[(org_train_x, org_train_y)], verbose=-1, early_stopping_rounds=300)

return model

+ 20
- 0
learnware/market/easy.py View File

@@ -91,6 +91,26 @@ class EasyMarket(BaseMarket):
logger.warning("Zip Path NOT Found! Fail to add learnware.")
return None, False

try:
if len(semantic_spec["Data"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Data.")
return None, False
if len(semantic_spec["Task"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Task.")
return None, False
if len(semantic_spec["Device"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please choose Device.")
return None, False
if len(semantic_spec["Name"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please provide Name.")
return None, False
if len(semantic_spec["Description"]["Values"]) == 0 and len(semantic_spec["Scenario"]["Values"]) == 0:
logger.warning("Illegal semantic specification, please provide Scenario or Description.")
return None, False
except:
logger.warning("Illegal semantic specification, some keys are missing.")
return None, False

logger.info("Get new learnware from %s" % (zip_path))
id = "%08d" % (self.count)
target_zip_dir = os.path.join(C.learnware_zip_pool_path, "%s.zip" % (id))


+ 4
- 0
learnware/specification/utils.py View File

@@ -2,6 +2,7 @@ import numpy as np

from .base import BaseStatSpecification
from .rkme import RKMEStatSpecification
from ..config import C


def generate_rkme_spec(
@@ -45,6 +46,9 @@ def generate_rkme_spec(
A RKMEStatSpecification object
"""
X = np.ascontiguousarray(X).astype(np.float32)
max_reduced_set_size = C.max_reduced_set_size
if K * X[0].size > max_reduced_set_size:
K = max(1, max_reduced_set_size // X[0].size)
rkme_spec = RKMEStatSpecification(gamma=gamma, cuda_idx=cuda_idx)
rkme_spec.generate_stat_spec_from_data(X, K, step_size, steps, nonnegative_beta, reduce)
return rkme_spec


Loading…
Cancel
Save