|
|
|
@@ -77,7 +77,7 @@ class ImageDatasetWorkflow: |
|
|
|
|
|
|
|
logger.info("Total Item: %d" % (len(self.image_market))) |
|
|
|
|
|
|
|
def image_example(self, rebuild=False, skip_test=True): |
|
|
|
def image_example(self, rebuild=False, skip_test=False): |
|
|
|
np.random.seed(1) |
|
|
|
random.seed(1) |
|
|
|
self.n_labeled_list = [100, 200, 500, 1000, 2000, 4000] |
|
|
|
@@ -112,7 +112,9 @@ class ImageDatasetWorkflow: |
|
|
|
test_dataset = TensorDataset(test_x, test_y) |
|
|
|
|
|
|
|
user_stat_spec = generate_stat_spec(type="image", X=test_x, whitening=False) |
|
|
|
user_info = BaseUserInfo(semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec}) |
|
|
|
user_info = BaseUserInfo( |
|
|
|
semantic_spec=self.user_semantic, stat_info={user_stat_spec.type: user_stat_spec} |
|
|
|
) |
|
|
|
logger.info("Searching Market for user: %d" % (i)) |
|
|
|
|
|
|
|
search_result = self.image_market.search_learnware(user_info) |
|
|
|
@@ -195,7 +197,9 @@ class ImageDatasetWorkflow: |
|
|
|
_, user_model_acc = evaluate(model, test_dataset, distribution=True) |
|
|
|
user_model_score_list.append(user_model_acc) |
|
|
|
|
|
|
|
reuse_pruning = EnsemblePruningReuser(learnware_list=mixture_learnware_list, mode="classification") |
|
|
|
reuse_pruning = EnsemblePruningReuser( |
|
|
|
learnware_list=mixture_learnware_list, mode="classification" |
|
|
|
) |
|
|
|
reuse_pruning.fit(x_train, y_train) |
|
|
|
_, pruning_acc = evaluate(reuse_pruning, test_dataset, distribution=False) |
|
|
|
reuse_pruning_score_list.append(pruning_acc) |
|
|
|
|