Browse Source

[FIX] fix rkme name bugs

tags/v0.3.2
bxdd 2 years ago
parent
commit
f76f1026ff
4 changed files with 11 additions and 9 deletions
  1. +1
    -1
      learnware/learnware/utils.py
  2. +6
    -6
      learnware/market/easy.py
  3. +1
    -1
      learnware/specification/base.py
  4. +3
    -1
      learnware/specification/regular/table/rkme.py

+ 1
- 1
learnware/learnware/utils.py View File

@@ -47,4 +47,4 @@ def get_stat_spec_from_config(stat_spec: dict) -> BaseStatSpecification:
if stat_spec_inst.load(stat_spec["file_name"]) is False:
raise ValueError("Load statistic specification failed!")

return stat_spec["class_name"], stat_spec_inst
return stat_spec_inst.type, stat_spec_inst

+ 6
- 6
learnware/market/easy.py View File

@@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket):
pass

# check rkme dimension
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification")
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification")
if stat_spec is not None:
if stat_spec.get_z().shape[1:] != input_shape:
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification")
@@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket):
"""
learnware_num = len(learnware_list)
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
]

if type(intermediate_K) == np.ndarray:
@@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket):
"""
num = intermediate_K.shape[0] - 1
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
]
for i in range(intermediate_K.shape[0]):
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i])
@@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket):
if len(mixture_list) <= 1:
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]]
mixture_weight = [1]
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification"))
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification"))
else:
if len(mixture_list) > max_search_num:
mixture_list = mixture_list[:max_search_num]
@@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket):
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:])

for learnware in learnware_list:
rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification")
rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification")
rkme_dim = str(list(rkme.get_z().shape)[1:])
if rkme_dim == user_rkme_dim:
filtered_learnware_list.append(learnware)
@@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket):
both lists are sorted by mmd dist
"""
RKME_list = [
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list
]
mmd_dist_list = []
for RKME in RKME_list:


+ 1
- 1
learnware/specification/base.py View File

@@ -83,7 +83,7 @@ class Specification:
or use class name as default name
"""
for _v in args:
self.stat_spec[_v.__class__.__name__] = _v
self.stat_spec[_v.type] = _v

for _k, _v in kwargs.items():
self.stat_spec[_k] = _v


+ 3
- 1
learnware/specification/regular/table/rkme.py View File

@@ -468,7 +468,9 @@ class RKMEStatSpecification(RKMETableSpecification):
TODO: modify all learnware in database and remove this nickname
"""

pass
def __init__(self, gamma: float = 0.1, cuda_idx: int = -1):
super(RKMEStatSpecification, self).__init__(gamma=gamma, cuda_idx=cuda_idx)
super(RKMETableSpecification, self).__init__(type=RKMETableSpecification.__name__)


def setup_seed(seed):


Loading…
Cancel
Save