From beff99a1cb849f75bf468aee307dd2b53e3641f4 Mon Sep 17 00:00:00 2001 From: bxdd Date: Mon, 13 Nov 2023 20:20:30 +0800 Subject: [PATCH] [MNT] modifty market base class --- learnware/market/base.py | 41 +++++++++++++------ .../heterogeneous/organizer/__init__.py | 14 ++++--- 2 files changed, 36 insertions(+), 19 deletions(-) diff --git a/learnware/market/base.py b/learnware/market/base.py index 3ed15b8..eb72814 100644 --- a/learnware/market/base.py +++ b/learnware/market/base.py @@ -68,8 +68,7 @@ class LearnwareMarket: ): self.market_id = market_id self.learnware_organizer = BaseOrganizer() if organizer is None else organizer - self.learnware_organizer.reset(market_id=market_id) - self.learnware_organizer.reload_market(rebuild=rebuild) + self.learnware_organizer.reset(market_id=market_id, reload_kwargs={"rebuild": rebuild}) self.learnware_searcher = BaseSearcher() if searcher is None else searcher self.learnware_searcher.reset(organizer=self.learnware_organizer) @@ -77,9 +76,20 @@ class LearnwareMarket: self.learnware_checker = {"BaseChecker": BaseChecker()} else: self.learnware_checker = {checker.__class__.__name__: checker for checker in checker_list} - for name, checker in self.learnware_checker.items(): + for checker in self.learnware_checker.values(): checker.reset(organizer=self.learnware_organizer) + def reset(self, organizer_kwargs=None, searcher_kwargs=None, checker_kwargs=None, **kwargs): + organizer_kwargs = {} if organizer_kwargs is None else organizer_kwargs + searcher_kwargs = {} if searcher_kwargs is None else searcher_kwargs + checker_kwargs = {} if checker_kwargs is None else checker_kwargs + self.learnware_organizer.reset(**organizer_kwargs) + self.learnware_searcher.reset(**searcher_kwargs) + for checker in self.learnware_checker.values(): + checker.reset(**checker_kwargs) + for _k, _v in kwargs.items(): + setattr(self, _k, _v) + def reload_market(self, **kwargs) -> bool: self.learnware_organizer.reload_market(**kwargs) @@ -254,11 +264,14 @@ class LearnwareMarket: class BaseOrganizer: - def __init__(self, market_id=None): - self.reset(market_id=market_id) + def __init__(self, market_id=None, **kwargs): + self.reset(market_id=market_id, **kwargs) - def reset(self, market_id=None, **kwargs): - self.market_id = market_id + def reset(self, market_id: str = None, reload_kwargs: dict = None): + if market_id is not None: + self.market_id = market_id + if reload_kwargs is not None: + self.reload_market(**reload_kwargs) def reload_market(self, rebuild=False, **kwargs) -> bool: """Reload the learnware organizer when server restared. @@ -428,11 +441,12 @@ class BaseOrganizer: class BaseSearcher: - def __init__(self, organizer: BaseOrganizer = None): - self.learnware_organizer = organizer + def __init__(self, organizer: BaseOrganizer = None, **kwargs): + self.reset(organizer=organizer, **kwargs) - def reset(self, organizer): - self.learnware_organizer = organizer + def reset(self, organizer: BaseOrganizer = None, **kwargs): + if organizer is not None: + self.learnware_organizer = organizer def __call__(self, user_info: BaseUserInfo, check_status: int = None): """Search learnwares based on user_info from learnwares with check_status @@ -456,8 +470,9 @@ class BaseChecker: def __init__(self, organizer: BaseOrganizer = None): self.learnware_organizer = organizer - def reset(self, organizer): - self.learnware_organizer = organizer + def reset(self, organizer=None): + if organizer is not None: + self.learnware_organizer = organizer def __call__(self, learnware: Learnware) -> Tuple[int, str]: """Check the utility of a learnware diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index 4a44112..2f03fba 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -15,7 +15,7 @@ logger = get_module_logger("hetero_map_table_organizer") class HeteroMapTableOrganizer(EasyOrganizer): def reload_market(self, rebuild=False, auto_update=False, auto_update_limit=100): - super().reload_market(rebuild=rebuild) + super(HeteroMapTableOrganizer, self).reload_market(rebuild=rebuild) self.auto_update = auto_update self.auto_update_limit = auto_update_limit self.count_down = auto_update_limit @@ -55,7 +55,9 @@ class HeteroMapTableOrganizer(EasyOrganizer): def add_learnware( self, zip_path: str, semantic_spec: dict, check_status: int, learnware_id: str = None ) -> Tuple[str, int]: - learnware_id, learnwere_status = super().add_learnware(zip_path, semantic_spec, check_status, learnware_id) + learnware_id, learnwere_status = super(HeteroMapTableOrganizer, self).add_learnware( + zip_path, semantic_spec, check_status, learnware_id + ) if learnwere_status == BaseChecker.USABLE_LEARWARE and len(self._get_hetero_learnware_ids(learnware_id)): self._update_learnware_by_ids(learnware_id) @@ -76,13 +78,13 @@ class HeteroMapTableOrganizer(EasyOrganizer): ) self.market_mapping = updated_market_mapping self._update_learnware_by_ids(training_learnware_ids) - + self.count_down = self.auto_update_limit return learnware_id, learnwere_status def delete_learnware(self, id: str) -> bool: - flag = super().delete_learnware(id) + flag = super(HeteroMapTableOrganizer, self).delete_learnware(id) if flag: hetero_spec_path = os.path.join(self.hetero_specs_path, f"{id}.json") try: @@ -92,13 +94,13 @@ class HeteroMapTableOrganizer(EasyOrganizer): return flag def update_learnware(self, id: str, zip_path: str = None, semantic_spec: dict = None, check_status: int = None): - final_status = super().update_learnware(id, zip_path, semantic_spec, check_status) + final_status = super(HeteroMapTableOrganizer, self).update_learnware(id, zip_path, semantic_spec, check_status) if final_status == BaseChecker.USABLE_LEARWARE and len(self._get_hetero_learnware_ids(id)): self._update_learnware_by_ids(id) return final_status def reload_learnware(self, learnware_id: str): - super().reload_learnware(learnware_id) + super(HeteroMapTableOrganizer, self).reload_learnware(learnware_id) try: hetero_spec_path = os.path.join(self.hetero_specs_path, f"{learnware_id}.json") if os.path.exists(hetero_spec_path):