|
|
|
@@ -116,7 +116,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
pass |
|
|
|
|
|
|
|
# check rkme dimension |
|
|
|
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMEStatSpecification") |
|
|
|
stat_spec = learnware.get_specification().get_stat_spec_by_name("RKMETableSpecification") |
|
|
|
if stat_spec is not None: |
|
|
|
if stat_spec.get_z().shape[1:] != input_shape: |
|
|
|
logger.warning(f"The learnware [{learnware.id}] input dimension mismatch with stat specification") |
|
|
|
@@ -321,7 +321,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
""" |
|
|
|
learnware_num = len(learnware_list) |
|
|
|
RKME_list = [ |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list |
|
|
|
] |
|
|
|
|
|
|
|
if type(intermediate_K) == np.ndarray: |
|
|
|
@@ -390,7 +390,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
""" |
|
|
|
num = intermediate_K.shape[0] - 1 |
|
|
|
RKME_list = [ |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list |
|
|
|
] |
|
|
|
for i in range(intermediate_K.shape[0]): |
|
|
|
intermediate_K[num, i] = RKME_list[-1].inner_prod(RKME_list[i]) |
|
|
|
@@ -446,7 +446,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
if len(mixture_list) <= 1: |
|
|
|
mixture_list = [learnware_list[sort_by_weight_idx_list[0]]] |
|
|
|
mixture_weight = [1] |
|
|
|
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMEStatSpecification")) |
|
|
|
mmd_dist = user_rkme.dist(mixture_list[0].specification.get_stat_spec_by_name("RKMETableSpecification")) |
|
|
|
else: |
|
|
|
if len(mixture_list) > max_search_num: |
|
|
|
mixture_list = mixture_list[:max_search_num] |
|
|
|
@@ -508,7 +508,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
user_rkme_dim = str(list(user_rkme.get_z().shape)[1:]) |
|
|
|
|
|
|
|
for learnware in learnware_list: |
|
|
|
rkme = learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") |
|
|
|
rkme = learnware.specification.get_stat_spec_by_name("RKMETableSpecification") |
|
|
|
rkme_dim = str(list(rkme.get_z().shape)[1:]) |
|
|
|
if rkme_dim == user_rkme_dim: |
|
|
|
filtered_learnware_list.append(learnware) |
|
|
|
@@ -607,7 +607,7 @@ class EasyMarket(LearnwareMarket): |
|
|
|
both lists are sorted by mmd dist |
|
|
|
""" |
|
|
|
RKME_list = [ |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMEStatSpecification") for learnware in learnware_list |
|
|
|
learnware.specification.get_stat_spec_by_name("RKMETableSpecification") for learnware in learnware_list |
|
|
|
] |
|
|
|
mmd_dist_list = [] |
|
|
|
for RKME in RKME_list: |
|
|
|
|