Browse Source

[ENH] Split market with database operation

tags/v0.3.2
bxdd 3 years ago
parent
commit
195fcde3d5
3 changed files with 13 additions and 7 deletions
  1. +1
    -0
      learnware/learnware/base.py
  2. +2
    -2
      learnware/market/database_ops.py
  3. +10
    -5
      learnware/market/easy.py

+ 1
- 0
learnware/learnware/base.py View File

@@ -46,6 +46,7 @@ class Learnware:
elif isinstance(self.model, dict):
model_module = get_module_by_module_path(self.model["module_path"])
self.model = getattr(model_module, self.model["class_name"])(**self.model.get("kwargs", {}))
print(self.model)
else:
raise TypeError(f"Model must be BaseModel or dict, not {type(self.model)}")



+ 2
- 2
learnware/market/database_ops.py View File

@@ -11,8 +11,8 @@ logger = get_module_logger("database_ops")


def init_empty_db(func):
def wrapper(*args, **kwargs):
conn = sqlite3.connect(os.path.join(C.database_path, "market.db"))
def wrapper(market_id, *args, **kwargs):
conn = sqlite3.connect(os.path.join(C.database_path, f"market_{market_id}.db"))
cur = conn.cursor()
listOfTables = cur.execute(
"""SELECT name FROM sqlite_master WHERE type='table' AND name='LEARNWARE'; """


+ 10
- 5
learnware/market/easy.py View File

@@ -23,17 +23,21 @@ class EasyMarket(BaseMarket):
NOPREDICTION_LEARNWARE = 0
PREDICTION_LEARWARE = 1

def __init__(self, market_id: str = None, rebuild: bool = False):
def __init__(self, market_id: str = "default", rebuild: bool = False):
"""Initialize Learnware Market.
Automatically reload from db if available.
Build an empty db otherwise.

Parameters
----------
market_id : str, optional, by default 'default'
The unique market id for market database
rebuild : bool, optional
Clear current database if set to True, by default False
!!! Do NOT set to True unless highly necessary !!!
"""
self.market_id = market_id
self.learnware_list = {} # id: Learnware
self.learnware_zip_list = {}
self.learnware_folder_list = {}
@@ -45,13 +49,13 @@ class EasyMarket(BaseMarket):
def reload_market(self, rebuild: bool = False) -> bool:
if rebuild:
logger.warning("Warning! You are trying to clear current database!")
clear_learnware_table()
clear_learnware_table(market_id=self.market_id)
rmtree(C.learnware_pool_path)

os.makedirs(C.learnware_pool_path, exist_ok=True)
os.makedirs(C.learnware_zip_pool_path, exist_ok=True)
os.makedirs(C.learnware_folder_pool_path, exist_ok=True)
self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db()
self.learnware_list, self.learnware_zip_list, self.learnware_folder_list, self.count = load_market_from_db(market_id=self.market_id)

@classmethod
def check_learnware(cls, learnware: Learnware) -> int:
@@ -176,7 +180,8 @@ class EasyMarket(BaseMarket):
self.learnware_folder_list[id] = target_folder_dir
self.count += 1
add_learnware_to_db(
id,
market_id=self.market_id,
id=id,
semantic_spec=semantic_spec,
zip_path=target_zip_dir,
folder_path=target_folder_dir,
@@ -655,7 +660,7 @@ class EasyMarket(BaseMarket):
self.learnware_list.pop(id)
self.learnware_zip_list.pop(id)
self.learnware_folder_list.pop(id)
delete_learnware_from_db(id)
delete_learnware_from_db(market_id=self.market_id, id=id)

return True



Loading…
Cancel
Save