| @@ -1,3 +1,3 @@ | |||||
| import learnware.market.database_ops as db_ops | import learnware.market.database_ops as db_ops | ||||
| db_ops.init_empty_db() | |||||
| db_ops.load_market_from_db() | |||||
| @@ -1,18 +1,17 @@ | |||||
| import logging | import logging | ||||
| import logging.handlers | |||||
| from logging import Logger, handlers | |||||
| from .config import C | from .config import C | ||||
| def get_module_logger(module_name: str, level:int = None): | |||||
| def get_module_logger(module_name: str, level:int = None) -> Logger: | |||||
| """Get a logger for a specific module. | """Get a logger for a specific module. | ||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| module_name : str | module_name : str | ||||
| Logic module name. | Logic module name. | ||||
| level : int, optional | level : int, optional | ||||
| logging level, by default None | |||||
| Logging level, by default None | |||||
| Returns | Returns | ||||
| ------- | ------- | ||||
| @@ -23,6 +22,9 @@ def get_module_logger(module_name: str, level:int = None): | |||||
| level = C.logging_level | level = C.logging_level | ||||
| # Get logger. | # Get logger. | ||||
| console_handler = logging.StreamHandler() | |||||
| console_handler.setLevel(level) | |||||
| module_logger = logging.getLogger(module_name) | module_logger = logging.getLogger(module_name) | ||||
| module_logger.setLevel(level) | module_logger.setLevel(level) | ||||
| module_logger.addHandler(console_handler) | |||||
| return module_logger | return module_logger | ||||
| @@ -2,44 +2,67 @@ import sqlite3 | |||||
| import os | import os | ||||
| from ..logger import get_module_logger | from ..logger import get_module_logger | ||||
| from ..learnware import Learnware | from ..learnware import Learnware | ||||
| from ..specification import RKMEStatSpecification, Specification | |||||
| import json | |||||
| ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) | ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) | ||||
| DB_PATH = os.path.join(ROOT_PATH, "market.db") | DB_PATH = os.path.join(ROOT_PATH, "market.db") | ||||
| LOGGER = get_module_logger("market", level="INFO") | |||||
| LOGGER = get_module_logger("market") | |||||
| def init_empty_db(func): | |||||
| def add_learnware_to_db(): | |||||
| pass | |||||
| def wrapper(): | |||||
| if not os.path.exists(DB_PATH): | |||||
| conn = sqlite3.connect(DB_PATH) | |||||
| LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | |||||
| c = conn.cursor() | |||||
| c.execute( | |||||
| """CREATE TABLE LEARNWARE | |||||
| (ID CHAR(10) PRIMARY KEY NOT NULL, | |||||
| NAME TEXT NOT NULL, | |||||
| SEMANTIC_SPEC TEXT NOT NULL, | |||||
| MODEL_PATH TEXT NOT NULL, | |||||
| STAT_SPEC_PATH TEXT NOT NULL);""" | |||||
| ) | |||||
| LOGGER.info("Database Built!") | |||||
| conn.commit() | |||||
| conn.close() | |||||
| func() | |||||
| return wrapper | |||||
| def delete_learnware_from_db(): | |||||
| @init_empty_db | |||||
| def add_learnware_to_db(): | |||||
| pass | pass | ||||
| @init_empty_db | |||||
| def delete_learnware_from_db(id:str): | |||||
| pass | |||||
| def init_empty_db(): | |||||
| conn = sqlite3.connect(DB_PATH) | |||||
| LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | |||||
| c = conn.cursor() | |||||
| c.execute( | |||||
| """CREATE TABLE LEARNWARE | |||||
| (ID CHAR(10) PRIMARY KEY NOT NULL, | |||||
| NAME TEXT NOT NULL, | |||||
| SEMANTIC_SPEC TEXT NOT NULL, | |||||
| MODEL_PATH TEXT NOT NULL, | |||||
| STAT_SPEC_PATH TEXT NOT NULL);""" | |||||
| ) | |||||
| LOGGER.info("Database Built!") | |||||
| conn.commit() | |||||
| conn.close() | |||||
| @init_empty_db | |||||
| def load_market_from_db(): | def load_market_from_db(): | ||||
| if not os.path.exists(DB_PATH): | |||||
| init_empty_db() | |||||
| conn = sqlite3.connect(DB_PATH) | conn = sqlite3.connect(DB_PATH) | ||||
| LOGGER.info("Reload from Database") | |||||
| c = conn.cursor() | c = conn.cursor() | ||||
| cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") | cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") | ||||
| learnware_list = {} | |||||
| max_count = 0 | |||||
| for item in cursor: | for item in cursor: | ||||
| id, name, semantic_spec, model_path, stat_spec_path = item | id, name, semantic_spec, model_path, stat_spec_path = item | ||||
| semantic_spec_dict = json.loads(semantic_spec) | |||||
| stat_spec_path_dict = json.loads(stat_spec_path) | |||||
| stat_spec_dict = {} | |||||
| for stat_spec_name in stat_spec_path_dict: | |||||
| new_stat_spec = RKMEStatSpecification() | |||||
| new_stat_spec.load(stat_spec_dict[stat_spec_name]) | |||||
| stat_spec_dict[stat_spec_name] = new_stat_spec | |||||
| model_dict = {'model_path':model_path, 'class_name':'BaseModel'} | |||||
| specification = Specification(semantic_spec=semantic_spec_dict, stat_spec=stat_spec_dict) | |||||
| new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) | |||||
| learnware_list[id] = new_learnware | |||||
| max_count = max(max_count, int(id)) | |||||
| conn.commit() | |||||
| conn.close() | |||||
| LOGGER.info("Market Reloaded from DB.") | LOGGER.info("Market Reloaded from DB.") | ||||
| return learnware_list, max_count | |||||
| @@ -7,6 +7,7 @@ from typing import Tuple, Any, List, Union, Dict | |||||
| from .base import BaseMarket, BaseUserInfo | from .base import BaseMarket, BaseUserInfo | ||||
| from ..learnware import Learnware | from ..learnware import Learnware | ||||
| from ..specification import RKMEStatSpecification, Specification | from ..specification import RKMEStatSpecification, Specification | ||||
| from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db | |||||
| class EasyMarket(BaseMarket): | class EasyMarket(BaseMarket): | ||||
| @@ -64,8 +65,9 @@ class EasyMarket(BaseMarket): | |||||
| }, | }, | ||||
| } | } | ||||
| def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool: | |||||
| raise NotImplementedError("reload market is Not Implemented") | |||||
| def reload_market(self) -> bool: | |||||
| self.learnware_list, self.count = load_market_from_db() | |||||
| def add_learnware( | def add_learnware( | ||||
| self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str | self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str | ||||
| @@ -108,8 +110,9 @@ class EasyMarket(BaseMarket): | |||||
| id = "%08d" % (self.count) | id = "%08d" % (self.count) | ||||
| rkme_stat_spec = RKMEStatSpecification() | rkme_stat_spec = RKMEStatSpecification() | ||||
| rkme_stat_spec.load(stat_spec_path) | rkme_stat_spec.load(stat_spec_path) | ||||
| specification = Specification(semantic_spec=semantic_spec) | |||||
| specification.update_stat_spec("RKME", rkme_stat_spec) | |||||
| stat_spec = {'RKME':rkme_stat_spec} | |||||
| specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec) | |||||
| # specification.update_stat_spec("RKME", rkme_stat_spec) | |||||
| model_dict = {"model_path": model_path, "class_name": "BaseModel"} | model_dict = {"model_path": model_path, "class_name": "BaseModel"} | ||||
| new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) | new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) | ||||
| self.learnware_list[id] = new_learnware | self.learnware_list[id] = new_learnware | ||||
| @@ -16,9 +16,9 @@ class BaseStatSpecification: | |||||
| class Specification: | class Specification: | ||||
| def __init__(self, semantic_spec: dict = None): | |||||
| def __init__(self, semantic_spec: dict = None, stat_spec: dict = {}): | |||||
| self.semantic_spec = semantic_spec | self.semantic_spec = semantic_spec | ||||
| self.stat_spec = {} # stat_spec should be dict | |||||
| self.stat_spec = stat_spec | |||||
| def get_stat_spec(self): | def get_stat_spec(self): | ||||
| return self.stat_spec | return self.stat_spec | ||||
| @@ -17,7 +17,6 @@ from .base import BaseStatSpecification | |||||
| # mkl.get_max_threads() | # mkl.get_max_threads() | ||||
| class RKMEStatSpecification(BaseStatSpecification): | class RKMEStatSpecification(BaseStatSpecification): | ||||
| """Reduced-set Kernel Mean Embedding (RKME) Specification""" | """Reduced-set Kernel Mean Embedding (RKME) Specification""" | ||||