|
|
|
@@ -5,6 +5,7 @@ import pandas as pd |
|
|
|
import torch |
|
|
|
import torch.nn.functional as F |
|
|
|
from torch import Tensor, nn |
|
|
|
from loguru import logger |
|
|
|
|
|
|
|
from .....specification import HeteroMapTableSpecification, RKMETableSpecification |
|
|
|
from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer |
|
|
|
@@ -73,7 +74,7 @@ class HeteroMap(nn.Module): |
|
|
|
activation : Union[str, Callable], optional |
|
|
|
Activation function for transformer layer, by default "relu" |
|
|
|
device : Union[str, torch.device], optional |
|
|
|
Device to run the model on, by default "cuda:0" |
|
|
|
Device to run the model on, by default "cpu" |
|
|
|
kwargs: |
|
|
|
Additional arguments to be passed to the feature tokenizer |
|
|
|
""" |
|
|
|
@@ -124,8 +125,26 @@ class HeteroMap(nn.Module): |
|
|
|
self.base_temperature = base_temperature |
|
|
|
self.num_partition = num_partition |
|
|
|
self.overlap_ratio = overlap_ratio |
|
|
|
self.device = device |
|
|
|
self.to(device) |
|
|
|
|
|
|
|
def to(self, device: Union[str, torch.device]): |
|
|
|
"""Moves the model and all its submodules to the specified device |
|
|
|
|
|
|
|
Parameters |
|
|
|
---------- |
|
|
|
device : Union[str, torch.device] |
|
|
|
The target device to which the model and its components should be moved. |
|
|
|
|
|
|
|
Returns |
|
|
|
------- |
|
|
|
HeteroMap |
|
|
|
The instance of HeteroMap after moving to the specified device. |
|
|
|
""" |
|
|
|
super(HeteroMap, self).to(device) |
|
|
|
if hasattr(self, 'feature_processor'): |
|
|
|
self.feature_processor.device = device |
|
|
|
self.device = device |
|
|
|
return self |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def load(checkpoint: str = None): |
|
|
|
|