| @@ -1,28 +1,15 @@ | |||
| import os | |||
| import json | |||
| import copy | |||
| import torch | |||
| import zipfile | |||
| import traceback | |||
| import tempfile | |||
| import numpy as np | |||
| import pandas as pd | |||
| from rapidfuzz import fuzz | |||
| from cvxopt import solvers, matrix | |||
| from shutil import copyfile, rmtree | |||
| from typing import Tuple, Any, List, Union, Dict | |||
| from typing import Tuple, List, Union | |||
| from .database_ops import DatabaseOperations | |||
| from ..base import LearnwareMarket, BaseUserInfo | |||
| from ... import utils | |||
| from ..base import BaseOrganizer, BaseChecker | |||
| from ...config import C as conf | |||
| from ...logger import get_module_logger | |||
| from ...learnware import Learnware, get_learnware_from_dirpath | |||
| from ...specification import Specification | |||
| from ..base import BaseOrganizer, BaseChecker | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_organizer") | |||
| @@ -2,12 +2,12 @@ import torch | |||
| import numpy as np | |||
| from rapidfuzz import fuzz | |||
| from cvxopt import solvers, matrix | |||
| from typing import Tuple, List | |||
| from typing import Tuple, List, Union | |||
| from .organizer import EasyOrganizer | |||
| from ..base import BaseUserInfo, BaseSearcher | |||
| from ...learnware import Learnware | |||
| from ...specification import RKMETableSpecification | |||
| from ...specification import RKMETableSpecification, RKMEImageSpecification | |||
| from ...logger import get_module_logger | |||
| logger = get_module_logger("easy_seacher") | |||
| @@ -188,7 +188,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher): | |||
| return final_result | |||
| class EasyTableSearcher(BaseSearcher): | |||
| class EasyStatSearcher(BaseSearcher): | |||
| def _convert_dist_to_score( | |||
| self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92 | |||
| ) -> List[float]: | |||
| @@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher): | |||
| return sorted_score_list[:idx], learnware_list[:idx] | |||
| def _filter_by_rkme_spec_dimension( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||
| self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] | |||
| ) -> List[Learnware]: | |||
| """Filter learnwares whose rkme dimension different from user_rkme | |||
| @@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher): | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] | |||
| User RKME statistical specification | |||
| Returns | |||
| @@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher): | |||
| return mmd_dist, weight_min, mixture_list | |||
| def _search_by_rkme_spec_single( | |||
| self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification | |||
| self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification] | |||
| ) -> Tuple[List[float], List[Learnware]]: | |||
| """Calculate the distances between learnwares in the given learnware_list and user_rkme | |||
| @@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher): | |||
| ---------- | |||
| learnware_list : List[Learnware] | |||
| The list of learnwares whose mixture approximates the user's rkme | |||
| user_rkme : RKMETableSpecification | |||
| user_rkme : Union[RKMETableSpecification, RKMEImageSpecification] | |||
| user RKME statistical specification | |||
| Returns | |||
| @@ -599,12 +599,12 @@ class EasySearcher(BaseSearcher): | |||
| def __init__(self, organizer: EasyOrganizer = None): | |||
| super(EasySearcher, self).__init__(organizer) | |||
| self.semantic_searcher = EasyFuzzSemanticSearcher(organizer) | |||
| self.table_searcher = EasyTableSearcher(organizer) | |||
| self.stat_searcher = EasyStatSearcher(organizer) | |||
| def reset(self, organizer): | |||
| self.learnware_oganizer = organizer | |||
| self.semantic_searcher.reset(organizer) | |||
| self.table_searcher.reset(organizer) | |||
| self.stat_searcher.reset(organizer) | |||
| def __call__( | |||
| self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy" | |||
| @@ -632,6 +632,6 @@ class EasySearcher(BaseSearcher): | |||
| if len(learnware_list) == 0: | |||
| return [], [], 0.0, [] | |||
| elif "RKMETableSpecification" in user_info.stat_info: | |||
| return self.table_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| return self.stat_searcher(learnware_list, user_info, max_search_num, search_method) | |||
| else: | |||
| return None, learnware_list, 0.0, None | |||
| @@ -1,3 +1,3 @@ | |||
| from .table import RKMETableSpecification, RKMEStatSpecification | |||
| from .image import RKMEImageSpecification | |||
| from .base import RegularStatsSpecification | |||
| from .base import RegularStatsSpecification | |||
| @@ -1,3 +1,5 @@ | |||
| from __future__ import annotations | |||
| from ..base import BaseStatSpecification | |||
| @@ -122,9 +122,7 @@ class RKMEImageSpecification(BaseStatSpecification): | |||
| X[i] = torch.where(is_nan, img_mean, img) | |||
| if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH: | |||
| X = Resize( | |||
| (RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None | |||
| )(X) | |||
| X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X) | |||
| num_points = X.shape[0] | |||
| X_shape = X.shape | |||
| @@ -343,11 +341,8 @@ class RKMEImageSpecification(BaseStatSpecification): | |||
| rkme_to_save["beta"] = rkme_to_save["beta"].tolist() | |||
| rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu" | |||
| json.dump( | |||
| rkme_to_save, | |||
| codecs.open(save_path, "w", encoding="utf-8"), | |||
| separators=(",", ":"), | |||
| ) | |||
| with codecs.open(save_path, "w", encoding="utf-8") as fout: | |||
| json.dump(rkme_to_save, fout, separators=(",", ":")) | |||
| def load(self, filepath: str) -> bool: | |||
| """Load a RKME Image specification file in JSON format from the specified path. | |||
| @@ -1 +1 @@ | |||
| from .rkme import RKMETableSpecification | |||
| from .rkme import RKMETableSpecification, RKMEStatSpecification | |||
| @@ -26,7 +26,9 @@ from ....logger import get_module_logger | |||
| logger = get_module_logger("rkme") | |||
| if not _FAISS_INSTALLED: | |||
| logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first") | |||
| logger.warning( | |||
| "Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first" | |||
| ) | |||
| class RKMETableSpecification(RegularStatsSpecification): | |||
| @@ -463,12 +465,15 @@ class RKMETableSpecification(RegularStatsSpecification): | |||
| else: | |||
| return False | |||
| class RKMEStatSpecification(RKMETableSpecification): | |||
| """nickname for RKMETableSpecification, for compatibility currently. | |||
| TODO: modify all learnware in database and remove this nickname | |||
| """ | |||
| pass | |||
| def setup_seed(seed): | |||
| """Fix a random seed for addressing reproducibility issues. | |||
| @@ -11,7 +11,6 @@ from learnware.specification import generate_rkme_image_spec, generate_rkme_spec | |||
| class TestRKME(unittest.TestCase): | |||
| def test_rkme(self): | |||
| pass | |||
| X = np.random.uniform(-10000, 10000, size=(5000, 200)) | |||
| rkme = generate_rkme_spec(X) | |||
| rkme.generate_stat_spec_from_data(X) | |||