From 195fcde3d53c0c78a4c78efef0576278445fbbd5 Mon Sep 17 00:00:00 2001 From: bxdd Date: Wed, 19 Apr 2023 19:35:56 +0800 Subject: [PATCH] [ENH] Split market with database operation --- learnware/learnware/base.py | 1 + learnware/market/database_ops.py | 4 ++-- learnware/market/easy.py | 15 ++++++++++----- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/learnware/learnware/base.py b/learnware/learnware/base.py index e2142d2..5395e72 100644 --- a/learnware/learnware/base.py +++ b/learnware/learnware/base.py @@ -46,6 +46,7 @@ class Learnware: elif isinstance(self.model, dict): model_module = get_module_by_module_path(self.model["module_path"]) self.model = getattr(model_module, self.model["class_name"])(**self.model.get("kwargs", {})) + print(self.model) else: raise TypeError(f"Model must be BaseModel or dict, not {type(self.model)}") diff --git a/learnware/market/database_ops.py b/learnware/market/database_ops.py index 67c8304..0dc5b7f 100644 --- a/learnware/market/database_ops.py +++ b/learnware/market/database_ops.py @@ -11,8 +11,8 @@ logger = get_module_logger("database_ops") def init_empty_db(func): - def wrapper(*args, **kwargs): - conn = sqlite3.connect(os.path.join(C.database_path, "market.db")) + def wrapper(market_id, *args, **kwargs): + conn = sqlite3.connect(os.path.join(C.database_path, f"market_{market_id}.db")) cur = conn.cursor() listOfTables = cur.execute( """SELECT name FROM sqlite_master WHERE type='table' AND name='LEARNWARE'; """ diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 0133fa1..3c520fe 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -23,17 +23,21 @@ class EasyMarket(BaseMarket): NOPREDICTION_LEARNWARE = 0 PREDICTION_LEARWARE = 1 - def __init__(self, market_id: str = None, rebuild: bool = False): + def __init__(self, market_id: str = "default", rebuild: bool = False): """Initialize Learnware Market. Automatically reload from db if available. Build an empty db otherwise. Parameters ---------- + market_id : str, optional, by default 'default' + The unique market id for market database + rebuild : bool, optional Clear current database if set to True, by default False !!! Do NOT set to True unless highly necessary !!! """ + self.market_id = market_id self.learnware_list = {} # id: Learnware self.learnware_zip_list = {} self.learnware_folder_list = {} @@ -45,13 +49,13 @@ class EasyMarket(BaseMarket): def reload_market(self, rebuild: bool = False) -> bool: if rebuild: logger.warning("Warning! You are trying to clear current database!") - clear_learnware_table() + clear_learnware_table(market_id=self.market_id) rmtree(C.learnware_pool_path) os.makedirs(C.learnware_pool_path, exist_ok=True) os.makedirs(C.learnware_zip_pool_path, exist_ok=True) os.makedirs(C.learnware_folder_pool_path, exist_ok=True) - self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db() + self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db(market_id=self.market_id) @classmethod def check_learnware(cls, learnware: Learnware) -> int: @@ -176,7 +180,8 @@ class EasyMarket(BaseMarket): self.learnware_folder_list[id] = target_folder_dir self.count += 1 add_learnware_to_db( - id, + market_id=self.market_id, + id=id, semantic_spec=semantic_spec, zip_path=target_zip_dir, folder_path=target_folder_dir, @@ -655,7 +660,7 @@ class EasyMarket(BaseMarket): self.learnware_list.pop(id) self.learnware_zip_list.pop(id) self.learnware_folder_list.pop(id) - delete_learnware_from_db(id) + delete_learnware_from_db(market_id=self.market_id, id=id) return True