From 5c695ef45dd238aa9d819a4ce1d1dcf60bfb82fa Mon Sep 17 00:00:00 2001 From: Gene Date: Tue, 14 Nov 2023 15:23:58 +0800 Subject: [PATCH] [MNT] add cache dir for FeatureTokenizer --- learnware/config.py | 3 +++ learnware/market/heterogeneous/organizer/__init__.py | 4 ++-- .../heterogeneous/organizer/hetero_map/__init__.py | 9 ++------- .../organizer/hetero_map/feature_extractor.py | 10 +++++----- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/learnware/config.py b/learnware/config.py index 45c04b1..84c839b 100644 --- a/learnware/config.py +++ b/learnware/config.py @@ -63,11 +63,13 @@ LEARNWARE_FOLDER_POOL_PATH = os.path.join(LEARNWARE_POOL_PATH, "learnwares") DATABASE_PATH = os.path.join(ROOT_DIRPATH, "database") STDOUT_PATH = os.path.join(ROOT_DIRPATH, "stdout") +CACHE_PATH = os.path.join(ROOT_DIRPATH, "cache") # TODO: Delete them later os.makedirs(ROOT_DIRPATH, exist_ok=True) os.makedirs(DATABASE_PATH, exist_ok=True) os.makedirs(STDOUT_PATH, exist_ok=True) +os.makedirs(CACHE_PATH, exist_ok=True) semantic_config = { "Data": { @@ -123,6 +125,7 @@ _DEFAULT_CONFIG = { "root_path": ROOT_DIRPATH, "package_path": PACKAGE_DIRPATH, "stdout_path": STDOUT_PATH, + "cache_path": CACHE_PATH, "logging_level": logging.INFO, "logging_outfile": None, "semantic_specs": semantic_config, diff --git a/learnware/market/heterogeneous/organizer/__init__.py b/learnware/market/heterogeneous/organizer/__init__.py index 78c738c..07604b8 100644 --- a/learnware/market/heterogeneous/organizer/__init__.py +++ b/learnware/market/heterogeneous/organizer/__init__.py @@ -21,7 +21,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): os.makedirs(hetero_folder_path, exist_ok=True) self.market_mapping_path = os.path.join(hetero_folder_path, "model.bin") self.hetero_specs_path = os.path.join(hetero_folder_path, "hetero_specifications") - self.training_args = {"cache_dir": hetero_folder_path} + self.training_args = {} os.makedirs(self.hetero_specs_path, exist_ok=True) if os.path.exists(self.market_mapping_path): @@ -42,7 +42,7 @@ class HeteroMapTableOrganizer(EasyOrganizer): self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE)) else: logger.warning(f"No market mapping to reload!") - self.market_mapping = HeteroMap(cache_dir=hetero_folder_path) + self.market_mapping = HeteroMap() def reset(self, market_id, rebuild=False, auto_update=False, auto_update_limit=100, **training_args): super(HeteroMapTableOrganizer, self).reset(market_id, rebuild) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 3c63580..653514c 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -59,7 +59,6 @@ class HeteroMap(nn.Module): activation="relu", device="cuda:0", checkpoint=None, - cache_dir=None, **kwargs, ): super(HeteroMap, self).__init__() @@ -74,12 +73,11 @@ class HeteroMap(nn.Module): "ffn_dim": ffn_dim, "projection_dim": projection_dim, "activation": activation, - "cache_dir": cache_dir } self.model_args.update(kwargs) if feature_tokenizer is None: - feature_tokenizer = FeatureTokenizer(cache_dir=cache_dir, **kwargs) + feature_tokenizer = FeatureTokenizer(**kwargs) self.feature_tokenizer = feature_tokenizer @@ -139,10 +137,7 @@ class HeteroMap(nn.Module): the directory path to save. """ # save model weight state dict - model_info = { - "model_state_dict": self.state_dict(), - "model_args": self.model_args - } + model_info = {"model_state_dict": self.state_dict(), "model_args": self.model_args} torch.save(model_info, checkpoint) def forward(self, x, y=None): diff --git a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py index 89105c0..4c7bcef 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py @@ -8,6 +8,8 @@ import torch.nn.init as nn_init from torch import Tensor, nn from transformers import BertTokenizerFast +from .....config import C as conf + class WordEmbedding(nn.Module): """ @@ -65,12 +67,13 @@ class FeatureTokenizer: def __init__( self, disable_tokenizer_parallel=True, - cache_dir=None, **kwargs, ): """args: disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader """ + cache_dir = conf["cache_path"] + os.makedirs(cache_dir, exist_ok=True) self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) self.tokenizer.__dict__["model_max_length"] = 512 if disable_tokenizer_parallel: # disable tokenizer parallel @@ -95,10 +98,7 @@ class FeatureTokenizer: 'num_col_input_ids': tensor contains numerical column tokenized ids, } """ - encoded_inputs = { - "x_num": None, - "num_col_input_ids": None - } + encoded_inputs = {"x_num": None, "num_col_input_ids": None} num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist()) x_num = x[num_cols].fillna(0)