diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 0453f97..37b5d3e 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -5,6 +5,7 @@ import pandas as pd import torch import torch.nn.functional as F from torch import Tensor, nn +from loguru import logger from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer @@ -73,7 +74,7 @@ class HeteroMap(nn.Module): activation : Union[str, Callable], optional Activation function for transformer layer, by default "relu" device : Union[str, torch.device], optional - Device to run the model on, by default "cuda:0" + Device to run the model on, by default "cpu" kwargs: Additional arguments to be passed to the feature tokenizer """ @@ -124,8 +125,26 @@ class HeteroMap(nn.Module): self.base_temperature = base_temperature self.num_partition = num_partition self.overlap_ratio = overlap_ratio - self.device = device self.to(device) + + def to(self, device: Union[str, torch.device]): + """Moves the model and all its submodules to the specified device + + Parameters + ---------- + device : Union[str, torch.device] + The target device to which the model and its components should be moved. + + Returns + ------- + HeteroMap + The instance of HeteroMap after moving to the specified device. + """ + super(HeteroMap, self).to(device) + if hasattr(self, 'feature_processor'): + self.feature_processor.device = device + self.device = device + return self @staticmethod def load(checkpoint: str = None):