Browse Source

[MNT] modify add_learnware and update_learnware

tags/v0.3.2
Gene 2 years ago
parent
commit
715049c19b
2 changed files with 84 additions and 47 deletions
  1. +43
    -4
      learnware/market/base.py
  2. +41
    -43
      learnware/market/easy2/organizer.py

+ 43
- 4
learnware/market/base.py View File

@@ -99,7 +99,24 @@ class LearnwareMarket:

def add_learnware(
self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
) -> Tuple[str, bool]:
) -> Tuple[str, int]:
"""Add a learnware into the market.

Parameters
----------
zip_path : str
Filepath for learnware model, a zipped file.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
checker_names : List[str], optional
List contains checker names, by default None

Returns
-------
Tuple[str, int]
- str indicating model_id
- int indicating the final learnware check_status
"""
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
return self.learnware_organizer.add_learnware(
zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
@@ -112,9 +129,31 @@ class LearnwareMarket:
return self.learnware_organizer.delete_learnware(id, **kwargs)

def update_learnware(
self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs
) -> bool:
check_status = self.check_learnware(zip_path, semantic_spec, checker_names)
self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, check_status: int = None, **kwargs
) -> int:
"""Update learnware with zip_path and semantic_specification

Parameters
----------
id : str
Learnware id
zip_path : str
Filepath for learnware model, a zipped file.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
checker_names : List[str], optional
List contains checker names, by default None.
check_status : int, optional
A flag indicating whether the learnware is usable, by default None.

Returns
-------
int
The final learnware check_status.
"""
update_status = self.check_learnware(zip_path, semantic_spec, checker_names)
check_status = update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status

return self.learnware_organizer.update_learnware(
id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs
)


+ 41
- 43
learnware/market/easy2/organizer.py View File

@@ -37,7 +37,6 @@ class EasyOrganizer(BaseOrganizer):
bool
A flag indicating whether the market is reload successfully.
"""

self.market_store_path = os.path.join(conf.market_root_path, self.market_id)
self.learnware_pool_path = os.path.join(self.market_store_path, "learnware_pool")
self.learnware_zip_pool_path = os.path.join(self.learnware_pool_path, "zips")
@@ -70,33 +69,33 @@ class EasyOrganizer(BaseOrganizer):
) = self.dbops.load_market()

def add_learnware(
self, zip_path: str, semantic_spec: dict, id: str = None, check_status: int = None
) -> Tuple[str, bool]:
self, zip_path: str, semantic_spec: dict, check_status: int
) -> Tuple[str, int]:
"""Add a learnware into the market.

.. note::

Given a prediction of a certain time, all signals before this time will be prepared well.


Parameters
----------
zip_path : str
Filepath for learnware model, a zipped file.
semantic_spec : dict
semantic_spec for new learnware, in dictionary format.
check_status: int
A flag indicating whether the learnware is usable.

Returns
-------
Tuple[str, int]
- str indicating model_id
- int indicating what the flag of learnware is added.

- int indicating the final learnware check_status
"""
if check_status == BaseChecker.INVALID_LEARNWARE:
logger.warning("Learnware is invalid!")
return None, BaseChecker.INVALID_LEARNWARE
semantic_spec = copy.deepcopy(semantic_spec)
logger.info("Get new learnware from %s" % (zip_path))

id = id if id is not None else "%08d" % (self.count)
id = "%08d" % (self.count)
target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id))
target_folder_dir = os.path.join(self.learnware_folder_pool_path, id)
copyfile(zip_path, target_zip_dir)
@@ -168,44 +167,43 @@ class EasyOrganizer(BaseOrganizer):
return True

def update_learnware(self, id: str, zip_path: str = None, semantic_spec: dict = None, check_status: int = None):
"""update learnware with zip_path and semantic_specification
TODO: update should pass the semantic check too
"""Update learnware with zip_path, semantic_specification and check_status

Parameters
----------
id : str
_description_
Learnware id
zip_path : str, optional
_description_, by default None
Filepath for learnware model, a zipped file.
semantic_spec : dict, optional
_description_, by default None
semantic_spec for new learnware, in dictionary format.
check_status : int, optional
_description_, by default None
A flag indicating whether the learnware is usable.

Returns
-------
_type_
_description_
int
The final learnware check_status.
"""
assert (
zip_path is None and semantic_spec is None
), f"at least one of 'zip_path' and 'semantic_spec' should not be None when update learnware"
assert check_status != BaseChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE"

if zip_path is None and check_status is not None:
logger.warning("check_status will be ignored when zip_path is None for learnware update")

if check_status == BaseChecker.INVALID_LEARNWARE:
logger.warning("Learnware is invalid!")
return BaseChecker.INVALID_LEARNWARE
if zip_path is None and semantic_spec is None and check_status is None:
logger.warning("At least one of 'zip_path', 'semantic_spec' and 'check_status' should not be None when update learnware")
return BaseChecker.INVALID_LEARNWARE

# Update semantic_specification
learnware_zippath = self.learnware_zip_list[id] if zip_path is None else zip_path
semantic_spec = (
self.learnware_list[id].get_specification().get_semantic_spec() if semantic_spec is None else semantic_spec
)

self.dbops.update_learnware_semantic_specification(id, semantic_spec)

# Update zip path
target_zip_dir = self.learnware_zip_list[id]
target_folder_dir = self.learnware_folder_list[id]

if check_status is None and zip_path is not None:
if zip_path is not None:
with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
with zipfile.ZipFile(zip_path, "r") as z_file:
z_file.extractall(tempdir)
@@ -219,21 +217,21 @@ class EasyOrganizer(BaseOrganizer):

if new_learnware is None:
return BaseChecker.INVALID_LEARNWARE

learnwere_status = BaseChecker.NONUSABLE_LEARNWARE
else:
learnwere_status = self.use_flags[id] if zip_path is None else check_status

copyfile(zip_path, target_zip_dir)
with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)

copyfile(zip_path, target_zip_dir)
with zipfile.ZipFile(target_zip_dir, "r") as z_file:
z_file.extractall(target_folder_dir)
# Update check_status
self.use_flags[id] = self.use_flags[id] if check_status is None else check_status
self.dbops.update_learnware_use_flag(id, self.use_flags[id])
# Update learnware list
self.learnware_list[id] = get_learnware_from_dirpath(
id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir
)
self.use_flags[id] = learnwere_status
self.dbops.update_learnware_use_flag(id, learnwere_status)
return learnwere_status

return self.use_flags[id]

def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]:
"""Search learnware by id or list of ids.


Loading…
Cancel
Save