diff --git a/learnware/market/base.py b/learnware/market/base.py index d81f1b9..4b1332c 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -99,7 +99,24 @@ class LearnwareMarket: def add_learnware( self, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs - ) -> Tuple[str, bool]: + ) -> Tuple[str, int]: + """Add a learnware into the market. + + Parameters + ---------- + zip_path : str + Filepath for learnware model, a zipped file. + semantic_spec : dict + semantic_spec for new learnware, in dictionary format. + checker_names : List[str], optional + List contains checker names, by default None + + Returns + ------- + Tuple[str, int] + - str indicating model_id + - int indicating the final learnware check_status + """ check_status = self.check_learnware(zip_path, semantic_spec, checker_names) return self.learnware_organizer.add_learnware( zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs @@ -112,9 +129,31 @@ class LearnwareMarket: return self.learnware_organizer.delete_learnware(id, **kwargs) def update_learnware( - self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, **kwargs - ) -> bool: - check_status = self.check_learnware(zip_path, semantic_spec, checker_names) + self, id: str, zip_path: str, semantic_spec: dict, checker_names: List[str] = None, check_status: int = None, **kwargs + ) -> int: + """Update learnware with zip_path and semantic_specification + + Parameters + ---------- + id : str + Learnware id + zip_path : str + Filepath for learnware model, a zipped file. + semantic_spec : dict + semantic_spec for new learnware, in dictionary format. + checker_names : List[str], optional + List contains checker names, by default None. + check_status : int, optional + A flag indicating whether the learnware is usable, by default None. + + Returns + ------- + int + The final learnware check_status. + """ + update_status = self.check_learnware(zip_path, semantic_spec, checker_names) + check_status = update_status if check_status is None or update_status == BaseChecker.INVALID_LEARNWARE else check_status + return self.learnware_organizer.update_learnware( id, zip_path=zip_path, semantic_spec=semantic_spec, check_status=check_status, **kwargs ) diff --git a/learnware/market/easy2/organizer.py b/learnware/market/easy2/organizer.py index 55780e3..3a78794 100644 --- a/learnware/market/easy2/organizer.py +++ b/learnware/market/easy2/organizer.py @@ -37,7 +37,6 @@ class EasyOrganizer(BaseOrganizer): bool A flag indicating whether the market is reload successfully. """ - self.market_store_path = os.path.join(conf.market_root_path, self.market_id) self.learnware_pool_path = os.path.join(self.market_store_path, "learnware_pool") self.learnware_zip_pool_path = os.path.join(self.learnware_pool_path, "zips") @@ -70,33 +69,33 @@ class EasyOrganizer(BaseOrganizer): ) = self.dbops.load_market() def add_learnware( - self, zip_path: str, semantic_spec: dict, id: str = None, check_status: int = None - ) -> Tuple[str, bool]: + self, zip_path: str, semantic_spec: dict, check_status: int + ) -> Tuple[str, int]: """Add a learnware into the market. - .. note:: - - Given a prediction of a certain time, all signals before this time will be prepared well. - - Parameters ---------- zip_path : str Filepath for learnware model, a zipped file. semantic_spec : dict semantic_spec for new learnware, in dictionary format. + check_status: int + A flag indicating whether the learnware is usable. Returns ------- Tuple[str, int] - str indicating model_id - - int indicating what the flag of learnware is added. - + - int indicating the final learnware check_status """ + if check_status == BaseChecker.INVALID_LEARNWARE: + logger.warning("Learnware is invalid!") + return None, BaseChecker.INVALID_LEARNWARE + semantic_spec = copy.deepcopy(semantic_spec) logger.info("Get new learnware from %s" % (zip_path)) - id = id if id is not None else "%08d" % (self.count) + id = "%08d" % (self.count) target_zip_dir = os.path.join(self.learnware_zip_pool_path, "%s.zip" % (id)) target_folder_dir = os.path.join(self.learnware_folder_pool_path, id) copyfile(zip_path, target_zip_dir) @@ -168,44 +167,43 @@ class EasyOrganizer(BaseOrganizer): return True def update_learnware(self, id: str, zip_path: str = None, semantic_spec: dict = None, check_status: int = None): - """update learnware with zip_path and semantic_specification - TODO: update should pass the semantic check too + """Update learnware with zip_path, semantic_specification and check_status Parameters ---------- id : str - _description_ + Learnware id zip_path : str, optional - _description_, by default None + Filepath for learnware model, a zipped file. semantic_spec : dict, optional - _description_, by default None + semantic_spec for new learnware, in dictionary format. check_status : int, optional - _description_, by default None + A flag indicating whether the learnware is usable. Returns ------- - _type_ - _description_ + int + The final learnware check_status. """ - assert ( - zip_path is None and semantic_spec is None - ), f"at least one of 'zip_path' and 'semantic_spec' should not be None when update learnware" - assert check_status != BaseChecker.INVALID_LEARNWARE, f"'check_status' can not be INVALID_LEARNWARE" - - if zip_path is None and check_status is not None: - logger.warning("check_status will be ignored when zip_path is None for learnware update") - + if check_status == BaseChecker.INVALID_LEARNWARE: + logger.warning("Learnware is invalid!") + return BaseChecker.INVALID_LEARNWARE + + if zip_path is None and semantic_spec is None and check_status is None: + logger.warning("At least one of 'zip_path', 'semantic_spec' and 'check_status' should not be None when update learnware") + return BaseChecker.INVALID_LEARNWARE + + # Update semantic_specification learnware_zippath = self.learnware_zip_list[id] if zip_path is None else zip_path semantic_spec = ( self.learnware_list[id].get_specification().get_semantic_spec() if semantic_spec is None else semantic_spec ) - self.dbops.update_learnware_semantic_specification(id, semantic_spec) - + + # Update zip path target_zip_dir = self.learnware_zip_list[id] target_folder_dir = self.learnware_folder_list[id] - - if check_status is None and zip_path is not None: + if zip_path is not None: with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir: with zipfile.ZipFile(zip_path, "r") as z_file: z_file.extractall(tempdir) @@ -219,21 +217,21 @@ class EasyOrganizer(BaseOrganizer): if new_learnware is None: return BaseChecker.INVALID_LEARNWARE - - learnwere_status = BaseChecker.NONUSABLE_LEARNWARE - else: - learnwere_status = self.use_flags[id] if zip_path is None else check_status - - copyfile(zip_path, target_zip_dir) - with zipfile.ZipFile(target_zip_dir, "r") as z_file: - z_file.extractall(target_folder_dir) - + + copyfile(zip_path, target_zip_dir) + with zipfile.ZipFile(target_zip_dir, "r") as z_file: + z_file.extractall(target_folder_dir) + + # Update check_status + self.use_flags[id] = self.use_flags[id] if check_status is None else check_status + self.dbops.update_learnware_use_flag(id, self.use_flags[id]) + + # Update learnware list self.learnware_list[id] = get_learnware_from_dirpath( id=id, semantic_spec=semantic_spec, learnware_dirpath=target_folder_dir ) - self.use_flags[id] = learnwere_status - self.dbops.update_learnware_use_flag(id, learnwere_status) - return learnwere_status + + return self.use_flags[id] def get_learnware_by_ids(self, ids: Union[str, List[str]]) -> Union[Learnware, List[Learnware]]: """Search learnware by id or list of ids.