From f76f1026ffe6f4ffe8fb90fa3ee83bba641995c3 Mon Sep 17 00:00:00 2001 From: bxdd Date: Wed, 1 Nov 2023 22:07:58 +0800 Subject: [PATCH] [FIX] fix rkme name bugs --- learnware/learnware/utils.py | 2 +- learnware/market/easy.py | 12 ++++++------ learnware/specification/base.py | 2 +- learnware/specification/regular/table/rkme.py | 4 +++- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/learnware/learnware/utils.py b/learnware/learnware/utils.py index b6dccd0..d5cdc14 100644 --- a/learnware/learnware/utils.py +++ b/learnware/learnware/utils.py @@ -47,4 +47,4 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification: if stat_spec_inst.load(stat_spec["file_name"]) is False: raise ValueError("Load statistic specification failed!") - return stat_spec["class_name"], stat_spec_inst + return stat_spec_inst.type, stat_spec_inst diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 1f5c084..957efda 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket): pass # check rkme dimension - stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") + stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") if stat_spec is not None: if stat_spec.get_z().shape[1:] != input_shape: logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") @@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket): """ learnware_num = len(learnware_list) RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] if type(intermediate_K) == np.ndarray: @@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket): """ num = intermediate_K.shape[0] - 1 RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] for i in range(intermediate_K.shape[0]): intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) @@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket): if len(mixture_list) <= 1: mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] mixture_weight = [1] - mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) + mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) else: if len(mixture_list) > max_search_num: mixture_list = mixture_list[:max_search_num] @@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket): user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) for learnware in learnware_list: - rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") + rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") rkme_dim = str(list(rkme.get_z().shape)[1:]) if rkme_dim == user_rkme_dim: filtered_learnware_list.append(learnware) @@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket): both lists are sorted by mmd dist """ RKME_list = [ - learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list + learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list ] mmd_dist_list = [] for RKME in RKME_list: diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 1c340fa..28212d6 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -83,7 +83,7 @@ class Specification: or use class name as default name """ for _v in args: - self.stat_spec[_v.__class__.__name__] = _v + self.stat_spec[_v.type] = _v for _k, _v in kwargs.items(): self.stat_spec[_k] = _v diff --git a/learnware/specification/regular/table/rkme.py b/learnware/specification/regular/table/rkme.py index 82c81a2..17aedc1 100644 --- a/learnware/specification/regular/table/rkme.py +++ b/learnware/specification/regular/table/rkme.py @@ -468,7 +468,9 @@ class RKMEStatSpecification(RKMETableSpecification): TODO: modify all learnware in database and remove this nickname """ - pass + def __init__(self, gamma: float = 0.1, cuda_idx: int = -1): + super(RKMEStatSpecification, self).__init__(gamma=gamma, cuda_idx=cuda_idx) + super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__) def setup_seed(seed):