Browse Source

[MNT] add cache dir for FeatureTokenizer

tags/v0.3.2
Gene 2 years ago
parent
commit
5c695ef45d
4 changed files with 12 additions and 14 deletions
  1. +3
    -0
      learnware/config.py
  2. +2
    -2
      learnware/market/heterogeneous/organizer/__init__.py
  3. +2
    -7
      learnware/market/heterogeneous/organizer/hetero_map/__init__.py
  4. +5
    -5
      learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py

+ 3
- 0
learnware/config.py View File

@@ -63,11 +63,13 @@ LEARNWARE_FOLDER_POOL_PATH = os.path.join(LEARNWARE_POOL_PATH, "learnwares")

# The "database", "stdout" and "cache" sub-directories of ROOT_DIRPATH.
DATABASE_PATH = os.path.join(ROOT_DIRPATH, "database")
STDOUT_PATH = os.path.join(ROOT_DIRPATH, "stdout")
CACHE_PATH = os.path.join(ROOT_DIRPATH, "cache")

# TODO: Delete them later
# NOTE(review): "them" presumably refers to the makedirs calls below — confirm.
# exist_ok=True means an already-existing directory does not raise an error.
os.makedirs(ROOT_DIRPATH, exist_ok=True)
os.makedirs(DATABASE_PATH, exist_ok=True)
os.makedirs(STDOUT_PATH, exist_ok=True)
os.makedirs(CACHE_PATH, exist_ok=True)

semantic_config = {
"Data": {
@@ -123,6 +125,7 @@ _DEFAULT_CONFIG = {
"root_path": ROOT_DIRPATH,
"package_path": PACKAGE_DIRPATH,
"stdout_path": STDOUT_PATH,
"cache_path": CACHE_PATH,
"logging_level": logging.INFO,
"logging_outfile": None,
"semantic_specs": semantic_config,


+ 2
- 2
learnware/market/heterogeneous/organizer/__init__.py View File

@@ -21,7 +21,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
os.makedirs(hetero_folder_path, exist_ok=True)
self.market_mapping_path = os.path.join(hetero_folder_path, "model.bin")
self.hetero_specs_path = os.path.join(hetero_folder_path, "hetero_specifications")
self.training_args = {"cache_dir": hetero_folder_path}
self.training_args = {}
os.makedirs(self.hetero_specs_path, exist_ok=True)

if os.path.exists(self.market_mapping_path):
@@ -42,7 +42,7 @@ class HeteroMapTableOrganizer(EasyOrganizer):
self._update_learnware_by_ids(self.get_learnware_ids(check_status=BaseChecker.USABLE_LEARWARE))
else:
logger.warning(f"No market mapping to reload!")
self.market_mapping = HeteroMap(cache_dir=hetero_folder_path)
self.market_mapping = HeteroMap()

def reset(self, market_id, rebuild=False, auto_update=False, auto_update_limit=100, **training_args):
super(HeteroMapTableOrganizer, self).reset(market_id, rebuild)


+ 2
- 7
learnware/market/heterogeneous/organizer/hetero_map/__init__.py View File

@@ -59,7 +59,6 @@ class HeteroMap(nn.Module):
activation="relu",
device="cuda:0",
checkpoint=None,
cache_dir=None,
**kwargs,
):
super(HeteroMap, self).__init__()
@@ -74,12 +73,11 @@ class HeteroMap(nn.Module):
"ffn_dim": ffn_dim,
"projection_dim": projection_dim,
"activation": activation,
"cache_dir": cache_dir
}
self.model_args.update(kwargs)

if feature_tokenizer is None:
feature_tokenizer = FeatureTokenizer(cache_dir=cache_dir, **kwargs)
feature_tokenizer = FeatureTokenizer(**kwargs)

self.feature_tokenizer = feature_tokenizer

@@ -139,10 +137,7 @@ class HeteroMap(nn.Module):
the directory path to save.
"""
# save model weight state dict
model_info = {
"model_state_dict": self.state_dict(),
"model_args": self.model_args
}
model_info = {"model_state_dict": self.state_dict(), "model_args": self.model_args}
torch.save(model_info, checkpoint)

def forward(self, x, y=None):


+ 5
- 5
learnware/market/heterogeneous/organizer/hetero_map/feature_extractor.py View File

@@ -8,6 +8,8 @@ import torch.nn.init as nn_init
from torch import Tensor, nn
from transformers import BertTokenizerFast

from .....config import C as conf


class WordEmbedding(nn.Module):
"""
@@ -65,12 +67,13 @@ class FeatureTokenizer:
def __init__(
self,
disable_tokenizer_parallel=True,
cache_dir=None,
**kwargs,
):
"""args:
disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader
"""
cache_dir = conf["cache_path"]
os.makedirs(cache_dir, exist_ok=True)
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir)
self.tokenizer.__dict__["model_max_length"] = 512
if disable_tokenizer_parallel: # disable tokenizer parallel
@@ -95,10 +98,7 @@ class FeatureTokenizer:
'num_col_input_ids': tensor contains numerical column tokenized ids,
}
"""
encoded_inputs = {
"x_num": None,
"num_col_input_ids": None
}
encoded_inputs = {"x_num": None, "num_col_input_ids": None}
num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist())
x_num = x[num_cols].fillna(0)



Loading…
Cancel
Save