Browse Source

[FIX] fix skip_test default value

tags/v0.3.2
Gene 2 years ago
parent
commit
0ac4c577f4
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      examples/dataset_image_workflow/workflow.py

+ 7
- 3
examples/dataset_image_workflow/workflow.py View File

@@ -77,7 +77,7 @@ class ImageDatasetWorkflow:

logger.info("Total Item: %d" % (len(self.image_market)))

def image_example(self, rebuild=False, skip_test=True):
def image_example(self, rebuild=False, skip_test=False):
np.random.seed(1)
random.seed(1)
self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000]
@@ -112,7 +112,9 @@ class ImageDatasetWorkflow:
test_dataset = TensorDataset(test_x, test_y)

user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False)
user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec})
user_info = BaseUserInfo(
semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}
)
logger.info("Searching Market for user: %d" % (i))

search_result = self.image_market.search_learnware(user_info)
@@ -195,7 +197,9 @@ class ImageDatasetWorkflow:
_, user_model_acc = evaluate(model, test_dataset, distribution=True)
user_model_score_list.append(user_model_acc)

reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification")
reuse_pruning = EnsemblePruningReuser(
learnware_list=mixture_learnware_list, mode="classification"
)
reuse_pruning.fit(x_train, y_train)
_, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False)
reuse_pruning_score_list.append(pruning_acc)


Loading…
Cancel
Save