diff --git a/examples/example_market_db/example_db.py b/examples/example_market_db/example_db.py index 9a34d81..886518b 100644 --- a/examples/example_market_db/example_db.py +++ b/examples/example_market_db/example_db.py @@ -1,3 +1,3 @@ import learnware.market.database_ops as db_ops -db_ops.init_empty_db() +db_ops.load_market_from_db() diff --git a/learnware/logger.py b/learnware/logger.py index b4db329..08db9d0 100644 --- a/learnware/logger.py +++ b/learnware/logger.py @@ -1,18 +1,17 @@ import logging -import logging.handlers +from logging import Logger, handlers from .config import C -def get_module_logger(module_name: str, level:int = None): +def get_module_logger(module_name: str, level:int = None) -> Logger: """Get a logger for a specific module. - Parameters ---------- module_name : str Logic module name. level : int, optional - logging level, by default None + Logging level, by default None Returns ------- @@ -23,6 +22,9 @@ def get_module_logger(module_name: str, level:int = None): level = C.logging_level # Get logger. + console_handler = logging.StreamHandler() + console_handler.setLevel(level) module_logger = logging.getLogger(module_name) module_logger.setLevel(level) + module_logger.addHandler(console_handler) return module_logger diff --git a/learnware/market/database_ops.py b/learnware/market/database_ops.py index 8e9b2d9..9ecd173 100644 --- a/learnware/market/database_ops.py +++ b/learnware/market/database_ops.py @@ -2,44 +2,67 @@ import sqlite3 import os from ..logger import get_module_logger from ..learnware import Learnware +from ..specification import RKMEStatSpecification, Specification +import json ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) DB_PATH = os.path.join(ROOT_PATH, "market.db") -LOGGER = get_module_logger("market", level="INFO") +LOGGER = get_module_logger("market") +def init_empty_db(func): -def add_learnware_to_db(): - pass + def wrapper(): + if not os.path.exists(DB_PATH): + conn = sqlite3.connect(DB_PATH) + LOGGER.info("Initializing Database in %s..." % (DB_PATH)) + c = conn.cursor() + c.execute( + """CREATE TABLE LEARNWARE + (ID CHAR(10) PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL, + SEMANTIC_SPEC TEXT NOT NULL, + MODEL_PATH TEXT NOT NULL, + STAT_SPEC_PATH TEXT NOT NULL);""" + ) + LOGGER.info("Database Built!") + conn.commit() + conn.close() + func() + return wrapper -def delete_learnware_from_db(): +@init_empty_db +def add_learnware_to_db(): pass +@init_empty_db +def delete_learnware_from_db(id:str): + pass -def init_empty_db(): - conn = sqlite3.connect(DB_PATH) - LOGGER.info("Initializing Database in %s..." % (DB_PATH)) - c = conn.cursor() - c.execute( - """CREATE TABLE LEARNWARE - (ID CHAR(10) PRIMARY KEY NOT NULL, - NAME TEXT NOT NULL, - SEMANTIC_SPEC TEXT NOT NULL, - MODEL_PATH TEXT NOT NULL, - STAT_SPEC_PATH TEXT NOT NULL);""" - ) - LOGGER.info("Database Built!") - conn.commit() - conn.close() - - +@init_empty_db def load_market_from_db(): - if not os.path.exists(DB_PATH): - init_empty_db() conn = sqlite3.connect(DB_PATH) + LOGGER.info("Reload from Database") c = conn.cursor() cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") + learnware_list = {} + max_count = 0 for item in cursor: id, name, semantic_spec, model_path, stat_spec_path = item + semantic_spec_dict = json.loads(semantic_spec) + stat_spec_path_dict = json.loads(stat_spec_path) + stat_spec_dict = {} + for stat_spec_name in stat_spec_path_dict: + new_stat_spec = RKMEStatSpecification() + new_stat_spec.load(stat_spec_dict[stat_spec_name]) + stat_spec_dict[stat_spec_name] = new_stat_spec + model_dict = {'model_path':model_path, 'class_name':'BaseModel'} + specification = Specification(semantic_spec=semantic_spec_dict, stat_spec=stat_spec_dict) + new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) + learnware_list[id] = new_learnware + max_count = max(max_count, int(id)) + conn.commit() + conn.close() LOGGER.info("Market Reloaded from DB.") + return learnware_list, max_count diff --git a/learnware/market/easy.py b/learnware/market/easy.py index 2a7ec90..b667031 100644 --- a/learnware/market/easy.py +++ b/learnware/market/easy.py @@ -7,6 +7,7 @@ from typing import Tuple, Any, List, Union, Dict from .base import BaseMarket, BaseUserInfo from ..learnware import Learnware from ..specification import RKMEStatSpecification, Specification +from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db class EasyMarket(BaseMarket): @@ -64,8 +65,9 @@ class EasyMarket(BaseMarket): }, } - def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool: - raise NotImplementedError("reload market is Not Implemented") + def reload_market(self) -> bool: + self.learnware_list, self.count = load_market_from_db() + def add_learnware( self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str @@ -108,8 +110,9 @@ class EasyMarket(BaseMarket): id = "%08d" % (self.count) rkme_stat_spec = RKMEStatSpecification() rkme_stat_spec.load(stat_spec_path) - specification = Specification(semantic_spec=semantic_spec) - specification.update_stat_spec("RKME", rkme_stat_spec) + stat_spec = {'RKME':rkme_stat_spec} + specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) + # specification.update_stat_spec("RKME", rkme_stat_spec) model_dict = {"model_path": model_path, "class_name": "BaseModel"} new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) self.learnware_list[id] = new_learnware diff --git a/learnware/specification/base.py b/learnware/specification/base.py index 98af3c9..12b61c6 100644 --- a/learnware/specification/base.py +++ b/learnware/specification/base.py @@ -16,9 +16,9 @@ class BaseStatSpecification: class Specification: - def __init__(self, semantic_spec: dict = None): + def __init__(self, semantic_spec: dict = None, stat_spec: dict = {}): self.semantic_spec = semantic_spec - self.stat_spec = {} # stat_spec should be dict + self.stat_spec = stat_spec def get_stat_spec(self): return self.stat_spec diff --git a/learnware/specification/rkme.py b/learnware/specification/rkme.py index e52975f..54a15c6 100644 --- a/learnware/specification/rkme.py +++ b/learnware/specification/rkme.py @@ -17,7 +17,6 @@ from .base import BaseStatSpecification # mkl.get_max_threads() - class RKMEStatSpecification(BaseStatSpecification): """Reduced-set Kernel Mean Embedding (RKME) Specification"""