diff --git a/examples/dataset_pfs_workflow/example_init.py b/examples/dataset_pfs_workflow/example_init.py index 88b788a..77bad5e 100644 --- a/examples/dataset_pfs_workflow/example_init.py +++ b/examples/dataset_pfs_workflow/example_init.py @@ -6,7 +6,7 @@ from learnware.model import BaseModel class Model(BaseModel): def __init__(self): - super(Model, self).__init__(input_shape=(31,), output_shape=()) + super(Model, self).__init__(input_shape=(31,), output_shape=(1,)) dir_path = os.path.dirname(os.path.abspath(__file__)) self.model = joblib.load(os.path.join(dir_path, "model.out")) diff --git a/examples/dataset_pfs_workflow/main.py b/examples/dataset_pfs_workflow/main.py index 9ce02f7..1f8c348 100644 --- a/examples/dataset_pfs_workflow/main.py +++ b/examples/dataset_pfs_workflow/main.py @@ -44,7 +44,7 @@ user_semantic = { "Library": {"Values": ["Scikit-learn"], "Type": "Class"}, "Scenario": {"Values": ["Business"], "Type": "Tag"}, "Description": {"Values": "", "Type": "String"}, - "Name": {"Values": "learnware_1", "Type": "String"}, + "Name": {"Values": "", "Type": "String"}, "Input": input_description, "Output": output_description, } @@ -55,7 +55,7 @@ class PFSDatasetWorkflow: pfs = Dataloader() pfs.regenerate_data() - algo_list = ["ridge", "lgb"] + algo_list = ["ridge"] # , "lgb" for algo in algo_list: pfs.set_algo(algo) pfs.retrain_models() @@ -88,7 +88,7 @@ class PFSDatasetWorkflow: pfs = Dataloader() idx_list = pfs.get_idx_list() - algo_list = ["lgb"] # ["ridge", "lgb"] + algo_list = ["ridge"] # ["ridge", "lgb"] curr_root = os.path.dirname(os.path.abspath(__file__)) curr_root = os.path.join(curr_root, "learnware_pool") diff --git a/learnware/market/easy/checker.py b/learnware/market/easy/checker.py index a988157..d0d50cf 100644 --- a/learnware/market/easy/checker.py +++ b/learnware/market/easy/checker.py @@ -89,7 +89,6 @@ class EasyStatChecker(BaseChecker): traceback.print_exc() logger.warning(f"The learnware [{learnware.id}] is instantiated failed! Due to {e}.") return self.INVALID_LEARNWARE - try: learnware_model = learnware.get_model() # Check input shape @@ -117,10 +116,9 @@ class EasyStatChecker(BaseChecker): elif spec_type == "RKMEImageSpecification": inputs = np.random.randint(0, 255, size=(10, *input_shape)) else: - raise ValueError(f"not supported spec type for spec_type = {spec_type}") + raise ValueError(f"not supported spec type for spec_type = {spec_type}") # Check output - outputs = learnware.predict(inputs) try: outputs = learnware.predict(inputs) except Exception: diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 8f147b2..2b00c61 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -129,7 +129,7 @@ class RKMETableSpecification(RegularStatsSpecification): return # Initialize Z by clustering, utiliing kmeans or faiss to speed up the process. - self._init_z_by_kmeans(X, K) + self._init_z_by_faiss(X, K) self._update_beta(X, nonnegative_beta) # Alternating optimize Z and beta