From fda945f5cc76a6f765d3200f6fe110203a2f8fcb Mon Sep 17 00:00:00 2001 From: bxdd Date: Tue, 31 Oct 2023 16:38:38 +0800 Subject: [PATCH] [MNT, FIX] modify typehint for easysearch, and fix file not close warning --- learnware/market/easy2/organizer.py | 17 ++-------------- learnware/market/easy2/searcher.py | 20 +++++++++---------- learnware/specification/regular/__init__.py | 2 +- learnware/specification/regular/base.py | 2 ++ learnware/specification/regular/image/rkme.py | 11 +++------- .../specification/regular/table/__init__.py | 2 +- learnware/specification/regular/table/rkme.py | 7 ++++++- tests/test_specification/test_rkme.py | 1 - 8 files changed, 25 insertions(+), 37 deletions(-) diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 9b6bf8c..18f67eb 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -1,28 +1,15 @@ import os -import json import copy -import torch import zipfile -import traceback import tempfile -import numpy as np -import pandas as pd -from rapidfuzz import fuzz -from cvxopt import solvers, matrix from shutil import copyfile, rmtree -from typing import Tuple, Any, List, Union, Dict +from typing import Tuple, List, Union from .database_ops import DatabaseOperations -from ..base import LearnwareMarket, BaseUserInfo - - -from ... import utils +from ..base import BaseOrganizer, BaseChecker from ...config import C as conf from ...logger import get_module_logger from ...learnware import Learnware, get_learnware_from_dirpath -from ...specification import Specification - -from ..base import BaseOrganizer, BaseChecker from ...logger import get_module_logger logger = get_module_logger("easy_organizer") diff --git a/learnware/market/easy2/searcher.py b/learnware/market/easy2/searcher.py index dcf3335..aa741e3 100644 --- a/learnware/market/easy2/searcher.py +++ b/learnware/market/easy2/searcher.py @@ -2,12 +2,12 @@ import torch import numpy as np from rapidfuzz import fuzz from cvxopt import solvers, matrix -from typing import Tuple, List +from typing import Tuple, List, Union from .organizer import EasyOrganizer from ..base import BaseUserInfo, BaseSearcher from ...learnware import Learnware -from ...specification import RKMETableSpecification +from ...specification import RKMETableSpecification, RKMEImageSpecification from ...logger import get_module_logger logger = get_module_logger("easy_seacher") @@ -188,7 +188,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): return final_result -class EasyTableSearcher(BaseSearcher): +class EasyStatSearcher(BaseSearcher): def _convert_dist_to_score( self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 ) -> List[float]: @@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher): return sorted_score_list[:idx], learnware_list[:idx] def _filter_by_rkme_spec_dimension( - self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> List[Learnware]: """Filter learnwares whose rkme dimension different from user_rkme @@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMETableSpecification + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] User RKME statistical specification Returns @@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher): return mmd_dist, weight_min, mixture_list def _search_by_rkme_spec_single( - self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification + self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] ) -> Tuple[List[float], List[Learnware]]: """Calculate the distances between learnwares in the given learnware_list and user_rkme @@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher): ---------- learnware_list : List[Learnware] The list of learnwares whose mixture approximates the user's rkme - user_rkme : RKMETableSpecification + user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] user RKME statistical specification Returns @@ -599,12 +599,12 @@ class EasySearcher(BaseSearcher): def __init__(self, organizer: EasyOrganizer = None): super(EasySearcher, self).__init__(organizer) self.semantic_searcher = EasyFuzzSemanticSearcher(organizer) - self.table_searcher = EasyTableSearcher(organizer) + self.stat_searcher = EasyStatSearcher(organizer) def reset(self, organizer): self.learnware_oganizer = organizer self.semantic_searcher.reset(organizer) - self.table_searcher.reset(organizer) + self.stat_searcher.reset(organizer) def __call__( self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" @@ -632,6 +632,6 @@ class EasySearcher(BaseSearcher): if len(learnware_list) == 0: return [], [], 0.0, [] elif "RKMETableSpecification" in user_info.stat_info: - return self.table_searcher(learnware_list, user_info, max_search_num, search_method) + return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) else: return None, learnware_list, 0.0, None diff --git a/learnware/specification/regular/__init__.py b/learnware/specification/regular/__init__.py index 4373eb0..eeb4b3f 100644 --- a/learnware/specification/regular/__init__.py +++ b/learnware/specification/regular/__init__.py @@ -1,3 +1,3 @@ from .table import RKMETableSpecification, RKMEStatSpecification from .image import RKMEImageSpecification -from .base import RegularStatsSpecification \ No newline at end of file +from .base import RegularStatsSpecification diff --git a/learnware/specification/regular/base.py b/learnware/specification/regular/base.py index 48a7e1f..6916177 100644 --- a/learnware/specification/regular/base.py +++ b/learnware/specification/regular/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from ..base import BaseStatSpecification diff --git a/learnware/specification/regular/image/rkme.py b/learnware/specification/regular/image/rkme.py index e0454da..1f05382 100644 --- a/learnware/specification/regular/image/rkme.py +++ b/learnware/specification/regular/image/rkme.py @@ -122,9 +122,7 @@ class RKMEImageSpecification(BaseStatSpecification): X[i] = torch.where(is_nan, img_mean, img) if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: - X = Resize( - (RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None - )(X) + X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X) num_points = X.shape[0] X_shape = X.shape @@ -343,11 +341,8 @@ class RKMEImageSpecification(BaseStatSpecification): rkme_to_save["beta"] = rkme_to_save["beta"].tolist() rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" - json.dump( - rkme_to_save, - codecs.open(save_path, "w", encoding="utf-8"), - separators=(",", ":"), - ) + with codecs.open(save_path, "w", encoding="utf-8") as fout: + json.dump(rkme_to_save, fout, separators=(",", ":")) def load(self, filepath: str) -> bool: """Load a RKME Image specification file in JSON format from the specified path. diff --git a/learnware/specification/regular/table/__init__.py b/learnware/specification/regular/table/__init__.py index 3cc9bd0..19fa956 100644 --- a/learnware/specification/regular/table/__init__.py +++ b/learnware/specification/regular/table/__init__.py @@ -1 +1 @@ -from .rkme import RKMETableSpecification +from .rkme import RKMETableSpecification, RKMEStatSpecification diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index ab763d8..ba76f6b 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -26,7 +26,9 @@ from ....logger import get_module_logger logger = get_module_logger("rkme") if not _FAISS_INSTALLED: - logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") + logger.warning( + "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" + ) class RKMETableSpecification(RegularStatsSpecification): @@ -463,12 +465,15 @@ class RKMETableSpecification(RegularStatsSpecification): else: return False + class RKMEStatSpecification(RKMETableSpecification): """nickname for RKMETableSpecification, for compatibility currently. TODO: modify all learnware in database and remove this nickname """ + pass + def setup_seed(seed): """Fix a random seed for addressing reproducibility issues. diff --git a/tests/test_specification/test_rkme.py b/tests/test_specification/test_rkme.py index 613e40b..c77e654 100644 --- a/tests/test_specification/test_rkme.py +++ b/tests/test_specification/test_rkme.py @@ -11,7 +11,6 @@ from learnware.specification import generate_rkme_image_spec, generate_rkme_spec class TestRKME(unittest.TestCase): def test_rkme(self): - pass X = np.random.uniform(-10000, 10000, size=(5000, 200)) rkme = generate_rkme_spec(X) rkme.generate_stat_spec_from_data(X)