From cc540f2a7c3397701380bf401db5d39510943a84 Mon Sep 17 00:00:00 2001 From: liuht Date: Wed, 15 Nov 2023 16:58:24 +0800 Subject: [PATCH 1/2] [FIX] fix hetero_map 'to(device)' bug --- .../organizer/hetero_map/__init__.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index 0453f97..dc3b90e 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -5,6 +5,7 @@ import pandas as pd import torch import torch.nn.functional as F from torch import Tensor, nn +from loguru import logger from .....specification import HeteroMapTableSpecification, RKMETableSpecification from .feature_extractor import CLSToken, FeatureProcessor, FeatureTokenizer @@ -73,7 +74,7 @@ class HeteroMap(nn.Module): activation : Union[str, Callable], optional Activation function for transformer layer, by default "relu" device : Union[str, torch.device], optional - Device to run the model on, by default "cuda:0" + Device to run the model on, by default "cpu" kwargs: Additional arguments to be passed to the feature tokenizer """ @@ -124,8 +125,26 @@ class HeteroMap(nn.Module): self.base_temperature = base_temperature self.num_partition = num_partition self.overlap_ratio = overlap_ratio - self.device = device self.to(device) + + def to(self, device: Union[str, torch.device]): + """Moves the model and all its submodules to the specified device + + Parameters + ---------- + device : Union[str, torch.device] + The target device to which the model and its components should be moved. + + Returns + ------- + HeteroMap + The instance of HeteroMap after moving to the specified device. + """ + super(HeteroMap, self).to(device) + if hasattr(self, 'feature_processor'): + self.feature_processor.device = device + self.device = device + return self @staticmethod def load(checkpoint: str = None): @@ -255,7 +274,9 @@ class HeteroMap(nn.Module): if isinstance(x, pd.DataFrame): inputs = self.feature_tokenizer(x) elif isinstance(x, torch.Tensor): + logger.info(f"extract features, input device:{x.device}") inputs = self.feature_tokenizer.forward(cols, x) + logger.info(f"extract features, output device:{inputs['x_num'].device}") else: raise ValueError(f"feature_tokenizer takes inputs with dict or pd.DataFrame, find {type(x)}.") From cc5c336183b3bea75f7c4773c5bd0210d1f79b8a Mon Sep 17 00:00:00 2001 From: liuht Date: Wed, 15 Nov 2023 17:03:38 +0800 Subject: [PATCH 2/2] [FIX] delete print --- learnware/market/heterogeneous/organizer/hetero_map/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py index dc3b90e..37b5d3e 100644 --- a/learnware/market/heterogeneous/organizer/hetero_map/__init__.py +++ b/learnware/market/heterogeneous/organizer/hetero_map/__init__.py @@ -274,9 +274,7 @@ class HeteroMap(nn.Module): if isinstance(x, pd.DataFrame): inputs = self.feature_tokenizer(x) elif isinstance(x, torch.Tensor): - logger.info(f"extract features, input device:{x.device}") inputs = self.feature_tokenizer.forward(cols, x) - logger.info(f"extract features, output device:{inputs['x_num'].device}") else: raise ValueError(f"feature_tokenizer takes inputs with dict or pd.DataFrame, find {type(x)}.")