From 0ac4c577f4bf58dafa721de8c30969ca9fcba2d9 Mon Sep 17 00:00:00 2001 From: Gene Date: Thu, 28 Dec 2023 17:20:33 +0800 Subject: [PATCH] [FIX] fix skip_test default value --- examples/dataset_image_workflow/workflow.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/dataset_image_workflow/workflow.py b/examples/dataset_image_workflow/workflow.py index 2bfbe4d..dbe293a 100644 --- a/examples/dataset_image_workflow/workflow.py +++ b/examples/dataset_image_workflow/workflow.py @@ -77,7 +77,7 @@ class ImageDatasetWorkflow: logger.info("Total Item: %d" % (len(self.image_market))) - def image_example(self, rebuild=False, skip_test=True): + def image_example(self, rebuild=False, skip_test=False): np.random.seed(1) random.seed(1) self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000] @@ -112,7 +112,9 @@ class ImageDatasetWorkflow: test_dataset = TensorDataset(test_x, test_y) user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False) - user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}) + user_info = BaseUserInfo( + semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec} + ) logger.info("Searching Market for user: %d" % (i)) search_result = self.image_market.search_learnware(user_info) @@ -195,7 +197,9 @@ class ImageDatasetWorkflow: _, user_model_acc = evaluate(model, test_dataset, distribution=True) user_model_score_list.append(user_model_acc) - reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") + reuse_pruning = EnsemblePruningReuser( + learnware_list=mixture_learnware_list, mode="classification" + ) reuse_pruning.fit(x_train, y_train) _, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False) reuse_pruning_score_list.append(pruning_acc)