Browse Source

[MNT, FIX] modify typehint for easysearch, and fix file not close warning

tags/v0.3.2
bxdd 2 years ago
parent
commit
fda945f5cc
8 changed files with 25 additions and 37 deletions
  1. +2
    -15
      learnware/market/easy2/organizer.py
  2. +10
    -10
      learnware/market/easy2/searcher.py
  3. +1
    -1
      learnware/specification/regular/__init__.py
  4. +2
    -0
      learnware/specification/regular/base.py
  5. +3
    -8
      learnware/specification/regular/image/rkme.py
  6. +1
    -1
      learnware/specification/regular/table/__init__.py
  7. +6
    -1
      learnware/specification/regular/table/rkme.py
  8. +0
    -1
      tests/test_specification/test_rkme.py

+ 2
- 15
learnware/market/easy2/organizer.py View File

@@ -1,28 +1,15 @@
import os
import json
import copy
import torch
import zipfile
import traceback
import tempfile
import numpy as np
import pandas as pd
from rapidfuzz import fuzz
from cvxopt import solvers, matrix
from shutil import copyfile, rmtree
from typing import Tuple, Any, List, Union, Dict
from typing import Tuple, List, Union

from .database_ops import DatabaseOperations
from ..base import LearnwareMarket, BaseUserInfo


from ... import utils
from ..base import BaseOrganizer, BaseChecker
from ...config import C as conf
from ...logger import get_module_logger
from ...learnware import Learnware, get_learnware_from_dirpath
from ...specification import Specification

from ..base import BaseOrganizer, BaseChecker
from ...logger import get_module_logger

logger = get_module_logger("easy_organizer")


+ 10
- 10
learnware/market/easy2/searcher.py View File

@@ -2,12 +2,12 @@ import torch
import numpy as np
from rapidfuzz import fuzz
from cvxopt import solvers, matrix
from typing import Tuple, List
from typing import Tuple, List, Union

from .organizer import EasyOrganizer
from ..base import BaseUserInfo, BaseSearcher
from ...learnware import Learnware
from ...specification import RKMETableSpecification
from ...specification import RKMETableSpecification, RKMEImageSpecification
from ...logger import get_module_logger

logger = get_module_logger("easy_seacher")
@@ -188,7 +188,7 @@ class EasyFuzzSemanticSearcher(BaseSearcher):
return final_result


class EasyTableSearcher(BaseSearcher):
class EasyStatSearcher(BaseSearcher):
def _convert_dist_to_score(
self, dist_list: List[float], dist_epsilon: float = 0.01, min_score: float = 0.92
) -> List[float]:
@@ -419,7 +419,7 @@ class EasyTableSearcher(BaseSearcher):
return sorted_score_list[:idx], learnware_list[:idx]

def _filter_by_rkme_spec_dimension(
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification]
) -> List[Learnware]:
"""Filter learnwares whose rkme dimension different from user_rkme

@@ -427,7 +427,7 @@ class EasyTableSearcher(BaseSearcher):
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMETableSpecification
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification]
User RKME statistical specification

Returns
@@ -519,7 +519,7 @@ class EasyTableSearcher(BaseSearcher):
return mmd_dist, weight_min, mixture_list

def _search_by_rkme_spec_single(
self, learnware_list: List[Learnware], user_rkme: RKMETableSpecification
self, learnware_list: List[Learnware], user_rkme: Union[RKMETableSpecification, RKMEImageSpecification]
) -> Tuple[List[float], List[Learnware]]:
"""Calculate the distances between learnwares in the given learnware_list and user_rkme

@@ -527,7 +527,7 @@ class EasyTableSearcher(BaseSearcher):
----------
learnware_list : List[Learnware]
The list of learnwares whose mixture approximates the user's rkme
user_rkme : RKMETableSpecification
user_rkme : Union[RKMETableSpecification, RKMEImageSpecification]
user RKME statistical specification

Returns
@@ -599,12 +599,12 @@ class EasySearcher(BaseSearcher):
def __init__(self, organizer: EasyOrganizer = None):
super(EasySearcher, self).__init__(organizer)
self.semantic_searcher = EasyFuzzSemanticSearcher(organizer)
self.table_searcher = EasyTableSearcher(organizer)
self.stat_searcher = EasyStatSearcher(organizer)

def reset(self, organizer):
self.learnware_oganizer = organizer
self.semantic_searcher.reset(organizer)
self.table_searcher.reset(organizer)
self.stat_searcher.reset(organizer)

def __call__(
self, user_info: BaseUserInfo, max_search_num: int = 5, search_method: str = "greedy"
@@ -632,6 +632,6 @@ class EasySearcher(BaseSearcher):
if len(learnware_list) == 0:
return [], [], 0.0, []
elif "RKMETableSpecification" in user_info.stat_info:
return self.table_searcher(learnware_list, user_info, max_search_num, search_method)
return self.stat_searcher(learnware_list, user_info, max_search_num, search_method)
else:
return None, learnware_list, 0.0, None

+ 1
- 1
learnware/specification/regular/__init__.py View File

@@ -1,3 +1,3 @@
from .table import RKMETableSpecification, RKMEStatSpecification
from .image import RKMEImageSpecification
from .base import RegularStatsSpecification
from .base import RegularStatsSpecification

+ 2
- 0
learnware/specification/regular/base.py View File

@@ -1,3 +1,5 @@
from __future__ import annotations

from ..base import BaseStatSpecification




+ 3
- 8
learnware/specification/regular/image/rkme.py View File

@@ -122,9 +122,7 @@ class RKMEImageSpecification(BaseStatSpecification):
X[i] = torch.where(is_nan, img_mean, img)

if X.shape[2] != RKMEImageSpecification.IMAGE_WIDTH or X.shape[3] != RKMEImageSpecification.IMAGE_WIDTH:
X = Resize(
(RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None
)(X)
X = Resize((RKMEImageSpecification.IMAGE_WIDTH, RKMEImageSpecification.IMAGE_WIDTH), antialias=None)(X)

num_points = X.shape[0]
X_shape = X.shape
@@ -343,11 +341,8 @@ class RKMEImageSpecification(BaseStatSpecification):
rkme_to_save["beta"] = rkme_to_save["beta"].tolist()
rkme_to_save["device"] = "gpu" if rkme_to_save["cuda_idx"] != -1 else "cpu"

json.dump(
rkme_to_save,
codecs.open(save_path, "w", encoding="utf-8"),
separators=(",", ":"),
)
with codecs.open(save_path, "w", encoding="utf-8") as fout:
json.dump(rkme_to_save, fout, separators=(",", ":"))

def load(self, filepath: str) -> bool:
"""Load a RKME Image specification file in JSON format from the specified path.


+ 1
- 1
learnware/specification/regular/table/__init__.py View File

@@ -1 +1 @@
from .rkme import RKMETableSpecification
from .rkme import RKMETableSpecification, RKMEStatSpecification

+ 6
- 1
learnware/specification/regular/table/rkme.py View File

@@ -26,7 +26,9 @@ from ....logger import get_module_logger
logger = get_module_logger("rkme")

if not _FAISS_INSTALLED:
logger.warning("Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first")
logger.warning(
"Required faiss version >= 1.7.1 is not detected! Please run 'conda install -c pytorch faiss-cpu' first"
)


class RKMETableSpecification(RegularStatsSpecification):
@@ -463,12 +465,15 @@ class RKMETableSpecification(RegularStatsSpecification):
else:
return False


class RKMEStatSpecification(RKMETableSpecification):
"""nickname for RKMETableSpecification, for compatibility currently.
TODO: modify all learnware in database and remove this nickname
"""

pass


def setup_seed(seed):
"""Fix a random seed for addressing reproducibility issues.



+ 0
- 1
tests/test_specification/test_rkme.py View File

@@ -11,7 +11,6 @@ from learnware.specification import generate_rkme_image_spec, generate_rkme_spec

class TestRKME(unittest.TestCase):
def test_rkme(self):
pass
X = np.random.uniform(-10000, 10000, size=(5000, 200))
rkme = generate_rkme_spec(X)
rkme.generate_stat_spec_from_data(X)


Loading…
Cancel
Save