| @@ -24,23 +24,6 @@ class HeteroMap(nn.Module): | |||||
| is modified for feature extraction purposes only. | is modified for feature extraction purposes only. | ||||
| The class implements a neural network module for processing tabular data, specifically tuned for numerical features. | The class implements a neural network module for processing tabular data, specifically tuned for numerical features. | ||||
| Args: | |||||
| feature_tokenizer (FeatureTokenizer, optional): Tokenizer for feature representation. | |||||
| hidden_dim (int, optional): Dimension of hidden layer. | |||||
| num_layer (int, optional): Number of layers in the transformer encoder. | |||||
| num_attention_head (int, optional): Number of attention heads in the transformer. | |||||
| hidden_dropout_prob (float, optional): Dropout probability for hidden layers. | |||||
| ffn_dim (int, optional): Dimension of feedforward network. | |||||
| projection_dim (int, optional): Dimension for projection head. | |||||
| overlap_ratio (float, optional): Overlap ratio for tokenization. | |||||
| num_partition (int, optional): Number of partitions for collation. | |||||
| temperature (float, optional): Temperature parameter for contrastive learning. | |||||
| base_temperature (float, optional): Base temperature parameter. | |||||
| activation (str, optional): Activation function for transformer layers. | |||||
| device (str, optional): Device to run the model on. | |||||
| checkpoint (str, optional): Path to a pre-trained model checkpoint. | |||||
| **kwargs: Additional keyword arguments. | |||||
| """ | """ | ||||
| def __init__( | def __init__( | ||||
| @@ -58,10 +41,41 @@ class HeteroMap(nn.Module): | |||||
| base_temperature=10, | base_temperature=10, | ||||
| activation="relu", | activation="relu", | ||||
| device="cuda:0", | device="cuda:0", | ||||
| checkpoint=None, | |||||
| cache_dir=None, | cache_dir=None, | ||||
| **kwargs, | **kwargs, | ||||
| ): | ): | ||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| feature_tokenizer : FeatureTokenizer, optional | |||||
| Tokenizer for feature representation, by default None | |||||
| hidden_dim : int, optional | |||||
| Dimension of hidden layer, by default 128 | |||||
| num_layer : int, optional | |||||
| Number of layers in the transformer encoder, by default 2 | |||||
| num_attention_head : int, optional | |||||
| Number of attention heads in the transformer, by default 8 | |||||
| hidden_dropout_prob : float, optional | |||||
| Dropout probability for hidden layers, by default 0 | |||||
| ffn_dim : int, optional | |||||
| Dimension of feedforward network, by default 256 | |||||
| projection_dim : int, optional | |||||
| Dimension for projection head, by default 128 | |||||
| overlap_ratio : float, optional | |||||
| Overlap ratio for tokenization, by default 0.5 | |||||
| num_partition : int, optional | |||||
| Number of partitions for collation, by default 3 | |||||
| temperature : float, optional | |||||
| Temperature parameter for contrastive learning, by default 10 | |||||
| base_temperature : float, optional | |||||
| Base temperature parameter, by default 10 | |||||
| activation : str, optional | |||||
| Activation function for transformer layers, by default "relu" | |||||
| device : str, optional | |||||
| Device to run the model on, by default "cuda:0" | |||||
| cache_dir : str, optional | |||||
| The cache directory, by default None | |||||
| """ | |||||
| super(HeteroMap, self).__init__() | super(HeteroMap, self).__init__() | ||||
| self.model_args = { | self.model_args = { | ||||
| @@ -74,7 +88,7 @@ class HeteroMap(nn.Module): | |||||
| "ffn_dim": ffn_dim, | "ffn_dim": ffn_dim, | ||||
| "projection_dim": projection_dim, | "projection_dim": projection_dim, | ||||
| "activation": activation, | "activation": activation, | ||||
| "cache_dir": cache_dir | |||||
| "cache_dir": cache_dir, | |||||
| } | } | ||||
| self.model_args.update(kwargs) | self.model_args.update(kwargs) | ||||
| @@ -139,10 +153,7 @@ class HeteroMap(nn.Module): | |||||
| the directory path to save. | the directory path to save. | ||||
| """ | """ | ||||
| # save model weight state dict | # save model weight state dict | ||||
| model_info = { | |||||
| "model_state_dict": self.state_dict(), | |||||
| "model_args": self.model_args | |||||
| } | |||||
| model_info = {"model_state_dict": self.state_dict(), "model_args": self.model_args} | |||||
| torch.save(model_info, checkpoint) | torch.save(model_info, checkpoint) | ||||
| def forward(self, x, y=None): | def forward(self, x, y=None): | ||||
| @@ -361,14 +372,16 @@ class TransformerLayer(nn.Module): | |||||
| def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=None, **kwargs) -> Tensor: | def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=None, **kwargs) -> Tensor: | ||||
| """Pass the input through the encoder layer. | """Pass the input through the encoder layer. | ||||
| Args: | |||||
| src: the sequence to the encoder layer (required). | |||||
| src_mask: the mask for the src sequence (optional). | |||||
| src_key_padding_mask: the mask for the src keys per batch (optional). | |||||
| Shape: | |||||
| see the docs in Transformer class. | |||||
| Parameters | |||||
| ---------- | |||||
| src : Any | |||||
| The sequence to the encoder layer. | |||||
| src_mask : Any, optional | |||||
| The mask for the src sequence, by default None | |||||
| src_key_padding_mask : Any, optional | |||||
| The mask for the src keys per batch, by default None | |||||
| """ | """ | ||||
| # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf | # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf | ||||
| x = src | x = src | ||||
| if self.use_layer_norm: | if self.use_layer_norm: | ||||
| @@ -427,8 +440,11 @@ class TransformerMultiLayer(nn.Module): | |||||
| self.transformer_encoder.append(stacked_transformer) | self.transformer_encoder.append(stacked_transformer) | ||||
| def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor: | def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor: | ||||
| """args: | |||||
| embedding: bs, num_token, hidden_dim | |||||
| """ | |||||
| Parameters | |||||
| ---------- | |||||
| embedding : Any | |||||
| Input embeddings of shape (bs, num_token, hidden_dim). | |||||
| """ | """ | ||||
| outputs = embedding | outputs = embedding | ||||
| for i, mod in enumerate(self.transformer_encoder): | for i, mod in enumerate(self.transformer_encoder): | ||||