Browse Source

[FIX] try to fix pfs

tags/v0.3.2
nju-xy 2 years ago
parent
commit
356375d525
4 changed files with 6 additions and 8 deletions
  1. +1
    -1
      examples/dataset_pfs_workflow/example_init.py
  2. +3
    -3
      examples/dataset_pfs_workflow/main.py
  3. +1
    -3
      learnware/market/easy/checker.py
  4. +1
    -1
      learnware/specification/regular/table/rkme.py

+ 1
- 1
examples/dataset_pfs_workflow/example_init.py View File

@@ -6,7 +6,7 @@ from learnware.model import BaseModel

class Model(BaseModel):
def __init__(self):
super(Model, self).__init__(input_shape=(31,), output_shape=())
super(Model, self).__init__(input_shape=(31,), output_shape=(1,))
dir_path = os.path.dirname(os.path.abspath(__file__))
self.model = joblib.load(os.path.join(dir_path, "model.out"))



+ 3
- 3
examples/dataset_pfs_workflow/main.py View File

@@ -44,7 +44,7 @@ user_semantic = {
"Library": {"Values": ["Scikit-learn"], "Type": "Class"},
"Scenario": {"Values": ["Business"], "Type": "Tag"},
"Description": {"Values": "", "Type": "String"},
"Name": {"Values": "learnware_1", "Type": "String"},
"Name": {"Values": "", "Type": "String"},
"Input": input_description,
"Output": output_description,
}
@@ -55,7 +55,7 @@ class PFSDatasetWorkflow:
pfs = Dataloader()
pfs.regenerate_data()

algo_list = ["ridge", "lgb"]
algo_list = ["ridge"] # , "lgb"
for algo in algo_list:
pfs.set_algo(algo)
pfs.retrain_models()
@@ -88,7 +88,7 @@ class PFSDatasetWorkflow:

pfs = Dataloader()
idx_list = pfs.get_idx_list()
algo_list = ["lgb"] # ["ridge", "lgb"]
algo_list = ["ridge"] # ["ridge", "lgb"]

curr_root = os.path.dirname(os.path.abspath(__file__))
curr_root = os.path.join(curr_root, "learnware_pool")


+ 1
- 3
learnware/market/easy/checker.py View File

@@ -89,7 +89,6 @@ class EasyStatChecker(BaseChecker):
traceback.print_exc()
logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.")
return self.INVALID_LEARNWARE

try:
learnware_model = learnware.get_model()
# Check input shape
@@ -117,10 +116,9 @@ class EasyStatChecker(BaseChecker):
elif spec_type == "RKMEImageSpecification":
inputs = np.random.randint(0, 255, size=(10, *input_shape))
else:
raise ValueError(f"not supported spec type for spec_type = {spec_type}")
raise ValueError(f"not supported spec type for spec_type = {spec_type}")

# Check output
outputs = learnware.predict(inputs)
try:
outputs = learnware.predict(inputs)
except Exception:


+ 1
- 1
learnware/specification/regular/table/rkme.py View File

@@ -129,7 +129,7 @@ class RKMETableSpecification(RegularStatsSpecification):
return

# Initialize Z by clustering, utiliing kmeans or faiss to speed up the process.
self._init_z_by_kmeans(X, K)
self._init_z_by_faiss(X, K)
self._update_beta(X, nonnegative_beta)

# Alternating optimize Z and beta


Loading…
Cancel
Save