| @@ -5,6 +5,7 @@ import pandas as pd | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch import Tensor, nn | from torch import Tensor, nn | ||||
| from loguru import logger | |||||
| from .....specification import HeteroMapTableSpecification, RKMETableSpecification | from .....specification import HeteroMapTableSpecification, RKMETableSpecification | ||||
| from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer | ||||
| @@ -73,7 +74,7 @@ class HeteroMap(nn.Module): | |||||
| activation : Union[str, Callable], optional | activation : Union[str, Callable], optional | ||||
| Activation function for transformer layer, by default "relu" | Activation function for transformer layer, by default "relu" | ||||
| device : Union[str, torch.device], optional | device : Union[str, torch.device], optional | ||||
| Device to run the model on, by default "cuda:0" | |||||
| Device to run the model on, by default "cpu" | |||||
| kwargs: | kwargs: | ||||
| Additional arguments to be passed to the feature tokenizer | Additional arguments to be passed to the feature tokenizer | ||||
| """ | """ | ||||
| @@ -124,8 +125,26 @@ class HeteroMap(nn.Module): | |||||
| self.base_temperature = base_temperature | self.base_temperature = base_temperature | ||||
| self.num_partition = num_partition | self.num_partition = num_partition | ||||
| self.overlap_ratio = overlap_ratio | self.overlap_ratio = overlap_ratio | ||||
| self.device = device | |||||
| self.to(device) | self.to(device) | ||||
| def to(self, device: Union[str, torch.device]): | |||||
| """Moves the model and all its submodules to the specified device | |||||
| Parameters | |||||
| ---------- | |||||
| device : Union[str, torch.device] | |||||
| The target device to which the model and its components should be moved. | |||||
| Returns | |||||
| ------- | |||||
| HeteroMap | |||||
| The instance of HeteroMap after moving to the specified device. | |||||
| """ | |||||
| super(HeteroMap, self).to(device) | |||||
| if hasattr(self, 'feature_processor'): | |||||
| self.feature_processor.device = device | |||||
| self.device = device | |||||
| return self | |||||
| @staticmethod | @staticmethod | ||||
| def load(checkpoint: str = None): | def load(checkpoint: str = None): | ||||