From 720af10c0f8c4421c2a8b683e215e22aa241fb98 Mon Sep 17 00:00:00 2001 From: Gene Date: Mon, 20 Nov 2023 21:12:19 +0800 Subject: [PATCH] [FIX] add data_type check in is_hetero --- learnware/market/heterogeneous/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/learnware/market/heterogeneous/utils.py b/learnware/market/heterogeneous/utils.py index 1be7ba6..6732991 100644 --- a/learnware/market/heterogeneous/utils.py +++ b/learnware/market/heterogeneous/utils.py @@ -20,6 +20,11 @@ def is_hetero(stat_specs: dict, semantic_spec: dict) -> bool: table_stat_spec = stat_specs["RKMETableSpecification"] table_input_shape = table_stat_spec.get_z().shape[1] + semantic_data_type = semantic_spec["Data"]["Values"] + if len(semantic_data_type) > 0 and semantic_data_type != ["Table"]: + logger.warning("User doesn't provide correct data type, it must be Table.") + return False + semantic_task_type = semantic_spec["Task"]["Values"] if len(semantic_task_type) > 0 and semantic_task_type not in [["Classification"], ["Regression"]]: logger.warning("User doesn't provide correct task type, it must be either Classification or Regression.")