|
|
|
@@ -8,6 +8,8 @@ import torch.nn.init as nn_init |
|
|
|
from torch import Tensor, nn |
|
|
|
from transformers import BertTokenizerFast |
|
|
|
|
|
|
|
from .....config import C as conf |
|
|
|
|
|
|
|
|
|
|
|
class WordEmbedding(nn.Module): |
|
|
|
""" |
|
|
|
@@ -65,12 +67,13 @@ class FeatureTokenizer: |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
disable_tokenizer_parallel=True, |
|
|
|
cache_dir=None, |
|
|
|
**kwargs, |
|
|
|
): |
|
|
|
"""args: |
|
|
|
disable_tokenizer_parallel: true if use extractor for collator function in torch.DataLoader |
|
|
|
""" |
|
|
|
cache_dir = conf["cache_path"] |
|
|
|
os.makedirs(cache_dir, exist_ok=True) |
|
|
|
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", cache_dir=cache_dir) |
|
|
|
self.tokenizer.__dict__["model_max_length"] = 512 |
|
|
|
if disable_tokenizer_parallel: # disable tokenizer parallel |
|
|
|
@@ -95,10 +98,7 @@ class FeatureTokenizer: |
|
|
|
'num_col_input_ids': tensor contains numerical column tokenized ids, |
|
|
|
} |
|
|
|
""" |
|
|
|
encoded_inputs = { |
|
|
|
"x_num": None, |
|
|
|
"num_col_input_ids": None |
|
|
|
} |
|
|
|
encoded_inputs = {"x_num": None, "num_col_input_ids": None} |
|
|
|
num_cols = x.columns.tolist() if not shuffle else np.random.shuffle(x.columns.tolist()) |
|
|
|
x_num = x[num_cols].fillna(0) |
|
|
|
|
|
|
|
|