| @@ -1,3 +1,3 @@ | |||
| import learnware.market.database_ops as db_ops | |||
| db_ops.init_empty_db() | |||
| db_ops.load_market_from_db() | |||
| @@ -1,18 +1,17 @@ | |||
| import logging | |||
| import logging.handlers | |||
| from logging import Logger, handlers | |||
| from .config import C | |||
| def get_module_logger(module_name: str, level:int = None): | |||
| def get_module_logger(module_name: str, level:int = None) -> Logger: | |||
| """Get a logger for a specific module. | |||
| Parameters | |||
| ---------- | |||
| module_name : str | |||
| Logic module name. | |||
| level : int, optional | |||
| logging level, by default None | |||
| Logging level, by default None | |||
| Returns | |||
| ------- | |||
| @@ -23,6 +22,9 @@ def get_module_logger(module_name: str, level:int = None): | |||
| level = C.logging_level | |||
| # Get logger. | |||
| console_handler = logging.StreamHandler() | |||
| console_handler.setLevel(level) | |||
| module_logger = logging.getLogger(module_name) | |||
| module_logger.setLevel(level) | |||
| module_logger.addHandler(console_handler) | |||
| return module_logger | |||
| @@ -2,44 +2,67 @@ import sqlite3 | |||
| import os | |||
| from ..logger import get_module_logger | |||
| from ..learnware import Learnware | |||
| from ..specification import RKMEStatSpecification, Specification | |||
| import json | |||
| ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) | |||
| DB_PATH = os.path.join(ROOT_PATH, "market.db") | |||
| LOGGER = get_module_logger("market", level="INFO") | |||
| LOGGER = get_module_logger("market") | |||
| def init_empty_db(func): | |||
| def add_learnware_to_db(): | |||
| pass | |||
| def wrapper(): | |||
| if not os.path.exists(DB_PATH): | |||
| conn = sqlite3.connect(DB_PATH) | |||
| LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | |||
| c = conn.cursor() | |||
| c.execute( | |||
| """CREATE TABLE LEARNWARE | |||
| (ID CHAR(10) PRIMARY KEY NOT NULL, | |||
| NAME TEXT NOT NULL, | |||
| SEMANTIC_SPEC TEXT NOT NULL, | |||
| MODEL_PATH TEXT NOT NULL, | |||
| STAT_SPEC_PATH TEXT NOT NULL);""" | |||
| ) | |||
| LOGGER.info("Database Built!") | |||
| conn.commit() | |||
| conn.close() | |||
| func() | |||
| return wrapper | |||
| def delete_learnware_from_db(): | |||
| @init_empty_db | |||
| def add_learnware_to_db(): | |||
| pass | |||
| @init_empty_db | |||
| def delete_learnware_from_db(id:str): | |||
| pass | |||
| def init_empty_db(): | |||
| conn = sqlite3.connect(DB_PATH) | |||
| LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | |||
| c = conn.cursor() | |||
| c.execute( | |||
| """CREATE TABLE LEARNWARE | |||
| (ID CHAR(10) PRIMARY KEY NOT NULL, | |||
| NAME TEXT NOT NULL, | |||
| SEMANTIC_SPEC TEXT NOT NULL, | |||
| MODEL_PATH TEXT NOT NULL, | |||
| STAT_SPEC_PATH TEXT NOT NULL);""" | |||
| ) | |||
| LOGGER.info("Database Built!") | |||
| conn.commit() | |||
| conn.close() | |||
| @init_empty_db | |||
| def load_market_from_db(): | |||
| if not os.path.exists(DB_PATH): | |||
| init_empty_db() | |||
| conn = sqlite3.connect(DB_PATH) | |||
| LOGGER.info("Reload from Database") | |||
| c = conn.cursor() | |||
| cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") | |||
| learnware_list = {} | |||
| max_count = 0 | |||
| for item in cursor: | |||
| id, name, semantic_spec, model_path, stat_spec_path = item | |||
| semantic_spec_dict = json.loads(semantic_spec) | |||
| stat_spec_path_dict = json.loads(stat_spec_path) | |||
| stat_spec_dict = {} | |||
| for stat_spec_name in stat_spec_path_dict: | |||
| new_stat_spec = RKMEStatSpecification() | |||
| new_stat_spec.load(stat_spec_dict[stat_spec_name]) | |||
| stat_spec_dict[stat_spec_name] = new_stat_spec | |||
| model_dict = {'model_path':model_path, 'class_name':'BaseModel'} | |||
| specification = Specification(semantic_spec=semantic_spec_dict, stat_spec=stat_spec_dict) | |||
| new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) | |||
| learnware_list[id] = new_learnware | |||
| max_count = max(max_count, int(id)) | |||
| conn.commit() | |||
| conn.close() | |||
| LOGGER.info("Market Reloaded from DB.") | |||
| return learnware_list, max_count | |||
| @@ -7,6 +7,7 @@ from typing import Tuple, Any, List, Union, Dict | |||
| from .base import BaseMarket, BaseUserInfo | |||
| from ..learnware import Learnware | |||
| from ..specification import RKMEStatSpecification, Specification | |||
| from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db | |||
| class EasyMarket(BaseMarket): | |||
| @@ -64,8 +65,9 @@ class EasyMarket(BaseMarket): | |||
| }, | |||
| } | |||
| def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool: | |||
| raise NotImplementedError("reload market is Not Implemented") | |||
| def reload_market(self) -> bool: | |||
| self.learnware_list, self.count = load_market_from_db() | |||
| def add_learnware( | |||
| self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str | |||
| @@ -108,8 +110,9 @@ class EasyMarket(BaseMarket): | |||
| id = "%08d" % (self.count) | |||
| rkme_stat_spec = RKMEStatSpecification() | |||
| rkme_stat_spec.load(stat_spec_path) | |||
| specification = Specification(semantic_spec=semantic_spec) | |||
| specification.update_stat_spec("RKME", rkme_stat_spec) | |||
| stat_spec = {'RKME':rkme_stat_spec} | |||
| specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) | |||
| # specification.update_stat_spec("RKME", rkme_stat_spec) | |||
| model_dict = {"model_path": model_path, "class_name": "BaseModel"} | |||
| new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) | |||
| self.learnware_list[id] = new_learnware | |||
| @@ -16,9 +16,9 @@ class BaseStatSpecification: | |||
| class Specification: | |||
| def __init__(self, semantic_spec: dict = None): | |||
| def __init__(self, semantic_spec: dict = None, stat_spec: dict = {}): | |||
| self.semantic_spec = semantic_spec | |||
| self.stat_spec = {} # stat_spec should be dict | |||
| self.stat_spec = stat_spec | |||
| def get_stat_spec(self): | |||
| return self.stat_spec | |||
| @@ -17,7 +17,6 @@ from .base import BaseStatSpecification | |||
| # mkl.get_max_threads() | |||
| class RKMEStatSpecification(BaseStatSpecification): | |||
| """Reduced-set Kernel Mean Embedding (RKME) Specification""" | |||