Browse Source

[ENH] Modify logger

tags/v0.3.2
bxdd 3 years ago
parent
commit
2b03049536
6 changed files with 62 additions and 35 deletions
  1. +1
    -1
      examples/example_market_db/example_db.py
  2. +6
    -4
      learnware/logger.py
  3. +46
    -23
      learnware/market/database_ops.py
  4. +7
    -4
      learnware/market/easy.py
  5. +2
    -2
      learnware/specification/base.py
  6. +0
    -1
      learnware/specification/rkme.py

+ 1
- 1
examples/example_market_db/example_db.py View File

@@ -1,3 +1,3 @@
import learnware.market.database_ops as db_ops import learnware.market.database_ops as db_ops


db_ops.init_empty_db()
db_ops.load_market_from_db()

+ 6
- 4
learnware/logger.py View File

@@ -1,18 +1,17 @@
import logging import logging
import logging.handlers
from logging import Logger, handlers


from .config import C from .config import C




def get_module_logger(module_name: str, level:int = None):
def get_module_logger(module_name: str, level:int = None) -> Logger:
"""Get a logger for a specific module. """Get a logger for a specific module.

Parameters Parameters
---------- ----------
module_name : str module_name : str
Logic module name. Logic module name.
level : int, optional level : int, optional
logging level, by default None
Logging level, by default None


Returns Returns
------- -------
@@ -23,6 +22,9 @@ def get_module_logger(module_name: str, level:int = None):
level = C.logging_level level = C.logging_level


# Get logger. # Get logger.
console_handler = logging.StreamHandler()
console_handler.setLevel(level)
module_logger = logging.getLogger(module_name) module_logger = logging.getLogger(module_name)
module_logger.setLevel(level) module_logger.setLevel(level)
module_logger.addHandler(console_handler)
return module_logger return module_logger

+ 46
- 23
learnware/market/database_ops.py View File

@@ -2,44 +2,67 @@ import sqlite3
import os import os
from ..logger import get_module_logger from ..logger import get_module_logger
from ..learnware import Learnware from ..learnware import Learnware
from ..specification import RKMEStatSpecification, Specification
import json


ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
DB_PATH = os.path.join(ROOT_PATH, "market.db") DB_PATH = os.path.join(ROOT_PATH, "market.db")
LOGGER = get_module_logger("market", level="INFO")
LOGGER = get_module_logger("market")


def init_empty_db(func):


def add_learnware_to_db():
pass
def wrapper():
if not os.path.exists(DB_PATH):
conn = sqlite3.connect(DB_PATH)
LOGGER.info("Initializing Database in %s..." % (DB_PATH))
c = conn.cursor()
c.execute(
"""CREATE TABLE LEARNWARE
(ID CHAR(10) PRIMARY KEY NOT NULL,
NAME TEXT NOT NULL,
SEMANTIC_SPEC TEXT NOT NULL,
MODEL_PATH TEXT NOT NULL,
STAT_SPEC_PATH TEXT NOT NULL);"""
)
LOGGER.info("Database Built!")
conn.commit()
conn.close()
func()


return wrapper


def delete_learnware_from_db():
@init_empty_db
def add_learnware_to_db():
pass pass


@init_empty_db
def delete_learnware_from_db(id:str):
pass


def init_empty_db():
conn = sqlite3.connect(DB_PATH)
LOGGER.info("Initializing Database in %s..." % (DB_PATH))
c = conn.cursor()
c.execute(
"""CREATE TABLE LEARNWARE
(ID CHAR(10) PRIMARY KEY NOT NULL,
NAME TEXT NOT NULL,
SEMANTIC_SPEC TEXT NOT NULL,
MODEL_PATH TEXT NOT NULL,
STAT_SPEC_PATH TEXT NOT NULL);"""
)
LOGGER.info("Database Built!")
conn.commit()
conn.close()


@init_empty_db
def load_market_from_db(): def load_market_from_db():
if not os.path.exists(DB_PATH):
init_empty_db()
conn = sqlite3.connect(DB_PATH) conn = sqlite3.connect(DB_PATH)
LOGGER.info("Reload from Database")
c = conn.cursor() c = conn.cursor()
cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE")


learnware_list = {}
max_count = 0
for item in cursor: for item in cursor:
id, name, semantic_spec, model_path, stat_spec_path = item id, name, semantic_spec, model_path, stat_spec_path = item
semantic_spec_dict = json.loads(semantic_spec)
stat_spec_path_dict = json.loads(stat_spec_path)
stat_spec_dict = {}
for stat_spec_name in stat_spec_path_dict:
new_stat_spec = RKMEStatSpecification()
new_stat_spec.load(stat_spec_dict[stat_spec_name])
stat_spec_dict[stat_spec_name] = new_stat_spec
model_dict = {'model_path':model_path, 'class_name':'BaseModel'}
specification = Specification(semantic_spec=semantic_spec_dict, stat_spec=stat_spec_dict)
new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification)
learnware_list[id] = new_learnware
max_count = max(max_count, int(id))
conn.commit()
conn.close()
LOGGER.info("Market Reloaded from DB.") LOGGER.info("Market Reloaded from DB.")
return learnware_list, max_count

+ 7
- 4
learnware/market/easy.py View File

@@ -7,6 +7,7 @@ from typing import Tuple, Any, List, Union, Dict
from .base import BaseMarket, BaseUserInfo from .base import BaseMarket, BaseUserInfo
from ..learnware import Learnware from ..learnware import Learnware
from ..specification import RKMEStatSpecification, Specification from ..specification import RKMEStatSpecification, Specification
from .database_ops import load_market_from_db, add_learnware_to_db, delete_learnware_from_db




class EasyMarket(BaseMarket): class EasyMarket(BaseMarket):
@@ -64,8 +65,9 @@ class EasyMarket(BaseMarket):
}, },
} }


def reload_market(self, market_path: str, semantic_spec_list_path: str) -> bool:
raise NotImplementedError("reload market is Not Implemented")
def reload_market(self) -> bool:
self.learnware_list, self.count = load_market_from_db()



def add_learnware( def add_learnware(
self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str self, learnware_name: str, model_path: str, stat_spec_path: str, semantic_spec: dict, desc: str
@@ -108,8 +110,9 @@ class EasyMarket(BaseMarket):
id = "%08d" % (self.count) id = "%08d" % (self.count)
rkme_stat_spec = RKMEStatSpecification() rkme_stat_spec = RKMEStatSpecification()
rkme_stat_spec.load(stat_spec_path) rkme_stat_spec.load(stat_spec_path)
specification = Specification(semantic_spec=semantic_spec)
specification.update_stat_spec("RKME", rkme_stat_spec)
stat_spec = {'RKME':rkme_stat_spec}
specification = Specification(semantic_spec=semantic_spec, stat_spec=stat_spec)
# specification.update_stat_spec("RKME", rkme_stat_spec)
model_dict = {"model_path": model_path, "class_name": "BaseModel"} model_dict = {"model_path": model_path, "class_name": "BaseModel"}
new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification) new_learnware = Learnware(id=id, name=learnware_name, model=model_dict, specification=specification)
self.learnware_list[id] = new_learnware self.learnware_list[id] = new_learnware


+ 2
- 2
learnware/specification/base.py View File

@@ -16,9 +16,9 @@ class BaseStatSpecification:




class Specification: class Specification:
def __init__(self, semantic_spec: dict = None):
def __init__(self, semantic_spec: dict = None, stat_spec: dict = {}):
self.semantic_spec = semantic_spec self.semantic_spec = semantic_spec
self.stat_spec = {} # stat_spec should be dict
self.stat_spec = stat_spec


def get_stat_spec(self): def get_stat_spec(self):
return self.stat_spec return self.stat_spec


+ 0
- 1
learnware/specification/rkme.py View File

@@ -17,7 +17,6 @@ from .base import BaseStatSpecification


# mkl.get_max_threads() # mkl.get_max_threads()



class RKMEStatSpecification(BaseStatSpecification): class RKMEStatSpecification(BaseStatSpecification):
"""Reduced-set Kernel Mean Embedding (RKME) Specification""" """Reduced-set Kernel Mean Embedding (RKME) Specification"""




Loading…
Cancel
Save