From c44b76afd569ab2f0ea999cb5551299ca8fe75dc Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 16 Nov 2023 19:57:48 +0800 Subject: [PATCH] [FIX | MNT] fix bugs for hetero organizer, and modify tests --- learnware/market/heterogeneous/organizer/__init__.py | 6 +++++- .../example_learnwares/example_learnware_1/learnware.yaml | 8 -------- .../example_learnware_1/requirements.txt | 1 - .../{example_learnware_0 => }/learnware.yaml | 0 .../{example_learnware_0/__init__.py => model0.py} | 6 ------ .../{example_learnware_1/__init__.py => model1.py} | 6 ------ .../{example_learnware_0 => }/requirements.txt | 0 tests/test_hetero_market/test_hetero.py | 8 ++++---- 8 files changed, 9 insertions(+), 26 deletions(-) delete mode 100644 tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml delete mode 100644 tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt rename tests/test_hetero_market/example_learnwares/{example_learnware_0 => }/learnware.yaml (100%) rename tests/test_hetero_market/example_learnwares/{example_learnware_0/__init__.py => model0.py} (78%) rename tests/test_hetero_market/example_learnwares/{example_learnware_1/__init__.py => model1.py} (78%) rename tests/test_hetero_market/example_learnwares/{example_learnware_0 => }/requirements.txt (100%) diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index 566faa6..113b8c3 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -206,6 +206,9 @@ class HeteroMapTableOrganizer(EasyOrganizer): str: id of target learware List[str]: A list of ids of target learnwares """ + if isinstance(ids, str): + ids = [ids] + for idx in ids: try: spec = self.learnware_list[idx].get_specification() @@ -218,7 +221,8 @@ class HeteroMapTableOrganizer(EasyOrganizer): hetero_spec.save(save_path) except Exception as err: - logger.warning(f"Learnware {idx} generate HeteroMapTableSpecification failed! Due to {err}") + traceback.print_exc() + logger.warning(f"Learnware {idx} generate HeteroMapTableSpecification failed!") def _get_hetero_learnware_ids(self, ids: Union[str, List[str]]) -> List[str]: """Get learnware ids that supports heterogeneous market training and search. diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml b/tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml deleted file mode 100644 index 4a37a37..0000000 --- a/tests/test_hetero_market/example_learnwares/example_learnware_1/learnware.yaml +++ /dev/null @@ -1,8 +0,0 @@ -model: - class_name: MyModel - kwargs: {} -stat_specifications: - - module_path: learnware.specification - class_name: RKMETableSpecification - file_name: stat.json - kwargs: {} \ No newline at end of file diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt b/tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt deleted file mode 100644 index 1da1c5f..0000000 --- a/tests/test_hetero_market/example_learnwares/example_learnware_1/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -learnware == 0.1.0.999 \ No newline at end of file diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_0/learnware.yaml b/tests/test_hetero_market/example_learnwares/learnware.yaml similarity index 100% rename from tests/test_hetero_market/example_learnwares/example_learnware_0/learnware.yaml rename to tests/test_hetero_market/example_learnwares/learnware.yaml diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_0/__init__.py b/tests/test_hetero_market/example_learnwares/model0.py similarity index 78% rename from tests/test_hetero_market/example_learnwares/example_learnware_0/__init__.py rename to tests/test_hetero_market/example_learnwares/model0.py index ea21917..45f64b7 100644 --- a/tests/test_hetero_market/example_learnwares/example_learnware_0/__init__.py +++ b/tests/test_hetero_market/example_learnwares/model0.py @@ -12,11 +12,5 @@ class MyModel(BaseModel): model = joblib.load(model_path) self.model = model - def fit(self, X: np.ndarray, y: np.ndarray): - pass - def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - - def finetune(self, X: np.ndarray, y: np.ndarray): - pass diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_1/__init__.py b/tests/test_hetero_market/example_learnwares/model1.py similarity index 78% rename from tests/test_hetero_market/example_learnwares/example_learnware_1/__init__.py rename to tests/test_hetero_market/example_learnwares/model1.py index 11fb9e0..aca46b3 100644 --- a/tests/test_hetero_market/example_learnwares/example_learnware_1/__init__.py +++ b/tests/test_hetero_market/example_learnwares/model1.py @@ -12,11 +12,5 @@ class MyModel(BaseModel): model = joblib.load(model_path) self.model = model - def fit(self, X: np.ndarray, y: np.ndarray): - pass - def predict(self, X: np.ndarray) -> np.ndarray: return self.model.predict(X) - - def finetune(self, X: np.ndarray, y: np.ndarray): - pass diff --git a/tests/test_hetero_market/example_learnwares/example_learnware_0/requirements.txt b/tests/test_hetero_market/example_learnwares/requirements.txt similarity index 100% rename from tests/test_hetero_market/example_learnwares/example_learnware_0/requirements.txt rename to tests/test_hetero_market/example_learnwares/requirements.txt diff --git a/tests/test_hetero_market/test_hetero.py b/tests/test_hetero_market/test_hetero.py index 58c285e..0755699 100644 --- a/tests/test_hetero_market/test_hetero.py +++ b/tests/test_hetero_market/test_hetero.py @@ -75,7 +75,7 @@ class TestMarket(unittest.TestCase): example_learnware_idx = i % 2 input_dim = input_shape_list[example_learnware_idx] - example_learnware_name = "example_learnwares/example_learnware_%d" % (example_learnware_idx) + learnware_example_dir = "example_learnwares" X, y = make_regression(n_samples=5000, n_informative=15, n_features=input_dim, noise=0.1, random_state=42) @@ -89,16 +89,16 @@ class TestMarket(unittest.TestCase): init_file = os.path.join(dir_path, "__init__.py") copyfile( - os.path.join(curr_root, example_learnware_name, "__init__.py"), init_file + os.path.join(curr_root, learnware_example_dir, f"model{example_learnware_idx}.py"), init_file ) # cp example_init.py init_file yaml_file = os.path.join(dir_path, "learnware.yaml") copyfile( - os.path.join(curr_root, example_learnware_name, "learnware.yaml"), yaml_file + os.path.join(curr_root, learnware_example_dir, "learnware.yaml"), yaml_file ) # cp example.yaml yaml_file env_file = os.path.join(dir_path, "requirements.txt") - copyfile(os.path.join(curr_root, example_learnware_name, "requirements.txt"), env_file) + copyfile(os.path.join(curr_root, learnware_example_dir, "requirements.txt"), env_file) zip_file = dir_path + ".zip" # zip -q -r -j zip_file dir_path