| @@ -11,12 +11,13 @@ LOGGER = get_module_logger("market") | |||||
| def init_empty_db(func): | def init_empty_db(func): | ||||
| def wrapper(): | |||||
| if not os.path.exists(DB_PATH): | |||||
| conn = sqlite3.connect(DB_PATH) | |||||
| def wrapper(*args, **kwargs): | |||||
| conn = sqlite3.connect(DB_PATH) | |||||
| cur = conn.cursor() | |||||
| listOfTables = cur.execute( """SELECT name FROM sqlite_master WHERE type='table' AND name='LEARNWARE'; """).fetchall() | |||||
| if listOfTables == []: | |||||
| LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | LOGGER.info("Initializing Database in %s..." % (DB_PATH)) | ||||
| c = conn.cursor() | |||||
| c.execute( | |||||
| cur.execute( | |||||
| """CREATE TABLE LEARNWARE | """CREATE TABLE LEARNWARE | ||||
| (ID CHAR(10) PRIMARY KEY NOT NULL, | (ID CHAR(10) PRIMARY KEY NOT NULL, | ||||
| NAME TEXT NOT NULL, | NAME TEXT NOT NULL, | ||||
| @@ -25,26 +26,30 @@ def init_empty_db(func): | |||||
| STAT_SPEC_PATH TEXT NOT NULL);""" | STAT_SPEC_PATH TEXT NOT NULL);""" | ||||
| ) | ) | ||||
| LOGGER.info("Database Built!") | LOGGER.info("Database Built!") | ||||
| conn.commit() | |||||
| conn.close() | |||||
| func() | |||||
| kwargs['cur'] = cur | |||||
| kwargs['conn'] = conn | |||||
| func(*args, **kwargs) | |||||
| conn.commit() | |||||
| conn.close() | |||||
| return wrapper | return wrapper | ||||
| @init_empty_db | @init_empty_db | ||||
| def add_learnware_to_db(): | |||||
| def add_learnware_to_db(id:str, name:str, model_path:str, semantic_spec:dict): | |||||
| pass | pass | ||||
| @init_empty_db | @init_empty_db | ||||
| def delete_learnware_from_db(id:str): | |||||
| pass | |||||
| def delete_learnware_from_db(id:str, cur, conn): | |||||
| cur.execute("DELETE from LEARNWARE where ID=%;") | |||||
| conn.commit() | |||||
| LOGGER.info("%d item has been deleted from table 'LEARNWARE'"%(conn.total_changes)) | |||||
| @init_empty_db | @init_empty_db | ||||
| def load_market_from_db(): | |||||
| conn = sqlite3.connect(DB_PATH) | |||||
| def load_market_from_db(cur, conn): | |||||
| # conn = sqlite3.connect(DB_PATH) | |||||
| LOGGER.info("Reload from Database") | LOGGER.info("Reload from Database") | ||||
| c = conn.cursor() | |||||
| cursor = c.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") | |||||
| # c = conn.cursor() | |||||
| cursor = cur.execute("SELECT id, name, semantic_spec, model_path, stat_spec_path from LEARNWARE") | |||||
| learnware_list = {} | learnware_list = {} | ||||
| max_count = 0 | max_count = 0 | ||||
| @@ -62,7 +67,5 @@ def load_market_from_db(): | |||||
| new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) | new_learnware = Learnware(id=id, name=name, model=model_dict, specification=specification) | ||||
| learnware_list[id] = new_learnware | learnware_list[id] = new_learnware | ||||
| max_count = max(max_count, int(id)) | max_count = max(max_count, int(id)) | ||||
| conn.commit() | |||||
| conn.close() | |||||
| LOGGER.info("Market Reloaded from DB.") | LOGGER.info("Market Reloaded from DB.") | ||||
| return learnware_list, max_count | |||||
| return learnware_list, max_count+1 | |||||