From 46189ce36e14f71c6b2d85f4ec45959990250a3b Mon Sep 17 00:00:00 2001 From: chenzx Date: Thu, 6 Apr 2023 11:04:37 +0800 Subject: [PATCH 1/3] [ENH] Implement load from db and init db --- examples/example_market_db/example_db.py | 2 +- learnware/market/database_ops.py | 64 ++++++++++++++++-------- learnware/market/easy.py | 11 ++-- learnware/specification/base.py | 4 +- learnware/specification/rkme.py | 1 - 5 files changed, 54 insertions(+), 28 deletions(-) diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 9a34d81..886518b 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -1,3 +1,3 @@ import learnware.market.database_ops as db_ops -db_ops.init_empty_db() +db_ops.load_market_from_db() diff --git a/learnware/market/database_ops.py b/learnware/market/database_ops.py index 8e9b2d9..e88108a 100644 --- a/learnware/market/database_ops.py +++ b/learnware/market/database_ops.py @@ -2,37 +2,44 @@ import sqlite3 import os from ..logger import get_module_logger from ..learnware import Learnware +from ..specification import RKMEStatSpecification, Specification +import json ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) DB_PATH = os.path.join(ROOT_PATH, "market.db") LOGGER = get_module_logger("market", level="INFO") +def init_empty_db(func): + def wrapper(): + if not os.path.exists(DB_PATH): + conn = sqlite3.connect(DB_PATH) + LOGGER.info("Initializing Database in %s..." % (DB_PATH)) + c = conn.cursor() + c.execute( + """CREATE TABLE LEARNWARE + (ID CHAR(10) PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL, + SEMANTIC_SPEC TEXT NOT NULL, + MODEL_PATH TEXT NOT NULL, + STAT_SPEC_PATH TEXT NOT NULL);""" + ) + LOGGER.info("Database Built!") + conn.commit() + conn.close() + else: + pass + return wrapper + +@init_empty_db def add_learnware_to_db(): pass - -def delete_learnware_from_db(): +@init_empty_db +def delete_learnware_from_db(id:str): pass - -def init_empty_db(): - conn = sqlite3.connect(DB_PATH) - LOGGER.info("Initializing Database in %s..." % (DB_PATH)) - c = conn.cursor() - c.execute( - """CREATE TABLE LEARNWARE - (ID CHAR(10) PRIMARY KEY NOT NULL, - NAME TEXT NOT NULL, - SEMANTIC_SPEC TEXT NOT NULL, - MODEL_PATH TEXT NOT NULL, - STAT_SPEC_PATH TEXT NOT NULL);""" - ) - LOGGER.info("Database Built!") - conn.commit() - conn.close() - - +@init_empty_db def load_market_from_db(): if not os.path.exists(DB_PATH): init_empty_db() @@ -40,6 +47,23 @@ def load_market_from_db(): c = conn.cursor() cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") + learnware_list = {} + max_count = 0 for item in cursor: id, name, semantic_spec, model_path, stat_spec_path = item + semantic_spec_dict = json.loads(semantic_spec) + stat_spec_path_dict = json.loads(stat_spec_path) + stat_spec_dict = {} + for stat_spec_name in stat_spec_path_dict: + new_stat_spec = RKMEStatSpecification() + new_stat_spec.load(stat_spec_dict[stat_spec_name]) + stat_spec_dict[stat_spec_name] = new_stat_spec + model_dict = {'model_path':model_path, 'class_name':'BaseModel'} + specification = Specification(semantic_spec=semantic_spec_dict, stat_spec=stat_spec_dict) + new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) + learnware_list[id] = new_learnware + max_count = max(max_count, int(id)) + conn.commit() + conn.close() LOGGER.info("Market Reloaded from DB.") + return learnware_list, max_count diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 2a7ec90..b667031 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -7,6 +7,7 @@ from typing import Tuple, Any, List, Union, Dict from .base import BaseMarket, BaseUserInfo from ..learnware import Learnware from ..specification import RKMEStatSpecification, Specification +from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db class EasyMarket(BaseMarket): @@ -64,8 +65,9 @@ class EasyMarket(BaseMarket): }, } - def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool: - raise NotImplementedError("reload market is Not Implemented") + def reload_market(self) -> bool: + self.learnware_list, self.count = load_market_from_db() + def add_learnware( self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str @@ -108,8 +110,9 @@ class EasyMarket(BaseMarket): id = "%08d" % (self.count) rkme_stat_spec = RKMEStatSpecification() rkme_stat_spec.load(stat_spec_path) - specification = Specification(semantic_spec=semantic_spec) - specification.update_stat_spec("RKME", rkme_stat_spec) + stat_spec = {'RKME':rkme_stat_spec} + specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) + # specification.update_stat_spec("RKME", rkme_stat_spec) model_dict = {"model_path": model_path, "class_name": "BaseModel"} new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) self.learnware_list[id] = new_learnware diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 98af3c9..12b61c6 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -16,9 +16,9 @@ class BaseStatSpecification: class Specification: - def __init__(self, semantic_spec: dict = None): + def __init__(self, semantic_spec: dict = None, stat_spec: dict = {}): self.semantic_spec = semantic_spec - self.stat_spec = {} # stat_spec should be dict + self.stat_spec = stat_spec def get_stat_spec(self): return self.stat_spec diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index e52975f..54a15c6 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -17,7 +17,6 @@ from .base import BaseStatSpecification # mkl.get_max_threads() - class RKMEStatSpecification(BaseStatSpecification): """Reduced-set Kernel Mean Embedding (RKME) Specification""" From f81ac9f4d47d978bf1209cbd4fd7826d3dbd3582 Mon Sep 17 00:00:00 2001 From: chenzx Date: Thu, 6 Apr 2023 11:16:41 +0800 Subject: [PATCH 2/3] [FIX] update logger --- learnware/logger.py | 5 ++++- learnware/market/database_ops.py | 9 ++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/learnware/logger.py b/learnware/logger.py index e838dc9..ddf0534 100644 --- a/learnware/logger.py +++ b/learnware/logger.py @@ -4,7 +4,7 @@ import logging.handlers from .config import C -def get_module_logger(module_name, level=None): +def get_module_logger(module_name, level=logging.INFO): """ Get a logger for a specific module. :param module_name: str @@ -20,6 +20,9 @@ def get_module_logger(module_name, level=None): level = C.logging_level # Get logger. + console_handler = logging.StreamHandler() + console_handler.setLevel('INFO') module_logger = logging.getLogger(module_name) module_logger.setLevel(level) + module_logger.addHandler(console_handler) return module_logger diff --git a/learnware/market/database_ops.py b/learnware/market/database_ops.py index e88108a..9ecd173 100644 --- a/learnware/market/database_ops.py +++ b/learnware/market/database_ops.py @@ -7,7 +7,7 @@ import json ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) DB_PATH = os.path.join(ROOT_PATH, "market.db") -LOGGER = get_module_logger("market", level="INFO") +LOGGER = get_module_logger("market") def init_empty_db(func): @@ -27,8 +27,8 @@ def init_empty_db(func): LOGGER.info("Database Built!") conn.commit() conn.close() - else: - pass + func() + return wrapper @init_empty_db @@ -41,9 +41,8 @@ def delete_learnware_from_db(id:str): @init_empty_db def load_market_from_db(): - if not os.path.exists(DB_PATH): - init_empty_db() conn = sqlite3.connect(DB_PATH) + LOGGER.info("Reload from Database") c = conn.cursor() cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") From b787387dd4bf7e559920a8bf2d4dd1ee28b1a932 Mon Sep 17 00:00:00 2001 From: chenzx Date: Thu, 6 Apr 2023 11:17:37 +0800 Subject: [PATCH 3/3] [FIX] update logger --- learnware/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/learnware/logger.py b/learnware/logger.py index ddf0534..8345077 100644 --- a/learnware/logger.py +++ b/learnware/logger.py @@ -21,7 +21,7 @@ def get_module_logger(module_name, level=logging.INFO): # Get logger. console_handler = logging.StreamHandler() - console_handler.setLevel('INFO') + console_handler.setLevel(level) module_logger = logging.getLogger(module_name) module_logger.setLevel(level) module_logger.addHandler(console_handler)