diff --git a/autogl/module/hpo/__init__.py b/autogl/module/hpo/__init__.py index e8fe41a..0a1b6c6 100644 --- a/autogl/module/hpo/__init__.py +++ b/autogl/module/hpo/__init__.py @@ -1,6 +1,7 @@ import importlib import os from .base import BaseHPOptimizer +from .auto_module import AutoModule HPO_DICT = {} @@ -52,6 +53,7 @@ def build_hpo_from_name(name: str) -> BaseHPOptimizer: __all__ = [ + "AutoModule", "BaseHPOptimizer", "AnnealAdvisorHPO", "AutoNE", diff --git a/autogl/module/model/_utils/auto_module.py b/autogl/module/hpo/auto_module.py similarity index 63% rename from autogl/module/model/_utils/auto_module.py rename to autogl/module/hpo/auto_module.py index 0499b9f..da35faa 100644 --- a/autogl/module/model/_utils/auto_module.py +++ b/autogl/module/hpo/auto_module.py @@ -26,50 +26,39 @@ class AutoModule: return self.__device @device.setter - def device(self, device: _typing.Union[torch.device, str, int, None]): - if ( - isinstance(device, str) or isinstance(device, int) or - isinstance(device, torch.device) - ): - self.__device = torch.device(device) + def device(self, __device: _typing.Union[torch.device, str, int, None]): + if type(__device) == torch.device or ( + type(__device) == str and __device.lower() != "auto" + ) or type(__device) == int: + self.__device: torch.device = torch.device(__device) else: - self.__device = torch.device("cpu") + self.__device: torch.device = torch.device( + "cuda" + if torch.cuda.is_available() and torch.cuda.device_count() > 0 + else "cpu" + ) def __init__( self, initialize: bool, device: _typing.Union[torch.device, str, int, None] = ..., *args, **kwargs ): - self.__hyper_parameter: _typing.Mapping[str, _typing.Any] = {} + self.__hyper_parameters: _typing.Mapping[str, _typing.Any] = {} self.__hyper_parameter_space: _typing.Iterable[_typing.Mapping[str, _typing.Any]] = [] - if ( - isinstance(device, str) or isinstance(device, int) or - isinstance(device, torch.device) - ): - self.__device: torch.device = torch.device(device) - 
else: - self.__device: torch.device = torch.device("cpu") + self.device = device self.__args: _typing.Tuple[_typing.Any, ...] = args self.__kwargs: _typing.Mapping[str, _typing.Any] = kwargs self.__initialized: bool = False if initialize: self.initialize() - @property - def hyper_parameter(self) -> _typing.Mapping[str, _typing.Any]: - return self.__hyper_parameter - - @hyper_parameter.setter - def hyper_parameter(self, hp: _typing.Mapping[str, _typing.Any]): - self.__hyper_parameter = hp - @property def hyper_parameters(self) -> _typing.Mapping[str, _typing.Any]: - return self.__hyper_parameter + return self.__hyper_parameters @hyper_parameters.setter def hyper_parameters(self, hp: _typing.Mapping[str, _typing.Any]): - self.__hyper_parameter = hp + self.__hyper_parameters = hp @property def hyper_parameter_space(self) -> _typing.Iterable[ diff --git a/autogl/module/model/decoders/base_decoder.py b/autogl/module/model/decoders/base_decoder.py index 4b30ce4..7f29990 100644 --- a/autogl/module/model/decoders/base_decoder.py +++ b/autogl/module/model/decoders/base_decoder.py @@ -1,10 +1,10 @@ import torch import typing as _typing -from .._utils import auto_module +from ...hpo import AutoModule from ..encoders import base_encoder -class BaseAutoDecoderMaintainer(auto_module.AutoModule): +class BaseAutoDecoderMaintainer(AutoModule): def _initialize(self) -> _typing.Optional[bool]: """ Abstract initialization method to override """ raise NotImplementedError @@ -24,9 +24,17 @@ class BaseAutoDecoderMaintainer(auto_module.AutoModule): super(BaseAutoDecoderMaintainer, self).__init__( initialize, device, *args, **kwargs ) - self._output_dimension: _typing.Optional[int] = output_dimension + self.output_dimension = output_dimension self._decoder: _typing.Optional[torch.nn.Module] = None + @property + def output_dimension(self): + return self.__output_dimension + + @output_dimension.setter + def output_dimension(self, output_dimension): + self.__output_dimension = 
output_dimension + @property def decoder(self) -> _typing.Optional[torch.nn.Module]: return self._decoder diff --git a/autogl/module/model/encoders/base_encoder.py b/autogl/module/model/encoders/base_encoder.py index 4615f79..b3dec01 100644 --- a/autogl/module/model/encoders/base_encoder.py +++ b/autogl/module/model/encoders/base_encoder.py @@ -1,9 +1,9 @@ import torch import typing as _typing -from .._utils import auto_module +from ...hpo import AutoModule -class BaseAutoEncoderMaintainer(auto_module.AutoModule): +class BaseAutoEncoderMaintainer(AutoModule): def __init__( self, initialize: bool, device: _typing.Union[torch.device, str, int, None] = ..., @@ -47,14 +47,31 @@ class AutoHomogeneousEncoderMaintainer(BaseAutoEncoderMaintainer): device: _typing.Union[torch.device, str, int, None] = ..., *args, **kwargs ): - self._input_dimension: _typing.Optional[int] = input_dimension - self._final_dimension: _typing.Optional[int] = final_dimension + self.input_dimension: _typing.Optional[int] = input_dimension + self.final_dimension: _typing.Optional[int] = final_dimension super(AutoHomogeneousEncoderMaintainer, self).__init__( initialize, device, *args, **kwargs ) self.__args: _typing.Tuple[_typing.Any, ...] 
= args self.__kwargs: _typing.Mapping[str, _typing.Any] = kwargs + @property + def input_dimension(self) -> _typing.Optional[int]: + return self.__input_dimension + + @input_dimension.setter + def input_dimension(self, input_dimension): + self.__input_dimension = input_dimension + + @property + def final_dimension(self): + return self.__final_dimension + + @final_dimension.setter + def final_dimension(self, final_dimension): + # TODO: may mutate search space according to the final dimension + self.__final_dimension = final_dimension + def from_hyper_parameter( self, hyper_parameter: _typing.Mapping[str, _typing.Any], **kwargs ): diff --git a/autogl/module/train/base.py b/autogl/module/train/base.py index 50a4a43..085b76e 100644 --- a/autogl/module/train/base.py +++ b/autogl/module/train/base.py @@ -3,16 +3,17 @@ import typing as _typing import torch import pickle + +from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer + from ..model import ( EncoderRegistry, DecoderRegistry, - BaseAutoModel, - BaseAutoEncoder, - BaseAutoDecoder, - AutoClassifierDecoder, - AutoHomogeneousEncoder + BaseAutoEncoderMaintainer, + BaseAutoDecoderMaintainer, + BaseAutoModel ) -from autogl.utils.autobase import AutoModule +from ..hpo import AutoModule import logging from .evaluation import Evaluation, get_feval, Acc from ...utils import get_logger @@ -94,8 +95,8 @@ class EarlyStopping: class BaseTrainer(AutoModule): def __init__( self, - encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, None], - decoder: _typing.Union[BaseAutoDecoder, None], + encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, None], + decoder: _typing.Union[BaseAutoDecoderMaintainer, None], device: _typing.Union[torch.device, str], init: bool = True, feval: _typing.Union[ @@ -116,39 +117,11 @@ class BaseTrainer(AutoModule): init: `bool` If True(False), the model will (not) be initialized. 
""" - super().__init__() - self.init: bool = init self.encoder = encoder self.decoder = None if isinstance(encoder, BaseAutoModel) else decoder - if type(device) == torch.device or ( - type(device) == str and device.lower() != "auto" - ): - self.__device: torch.device = torch.device(device) - else: - self.__device: torch.device = torch.device( - "cuda" - if torch.cuda.is_available() and torch.cuda.device_count() > 0 - else "cpu" - ) - self.__feval: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(feval) - self.loss: str = loss - - @property - def device(self) -> torch.device: - return self.__device - - @device.setter - def device(self, __device: _typing.Union[torch.device, str]): - if type(__device) == torch.device or ( - type(__device) == str and __device.lower() != "auto" - ): - self.__device: torch.device = torch.device(__device) - else: - self.__device: torch.device = torch.device( - "cuda" - if torch.cuda.is_available() and torch.cuda.device_count() > 0 - else "cpu" - ) + self.feval = feval + self.loss = loss + super().__init__(init, device) @property def feval(self) -> _typing.Sequence[_typing.Type[Evaluation]]: @@ -174,13 +147,9 @@ class BaseTrainer(AutoModule): """ self.device = device if self.encoder is not None: - self.encoder.to(self.device) + self.encoder.to_device(self.device) if self.decoder is not None: - self.decoder.to(self.device) - - def initialize(self): - """Initialize the auto model in trainer.""" - pass + self.decoder.to_device(self.device) def get_feval( self, return_major: bool = False @@ -335,8 +304,8 @@ class _BaseClassificationTrainer(BaseTrainer): def __init__( self, - encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None], - decoder: _typing.Union[BaseAutoDecoder, str, None], + encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None], + decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None], num_features: int, num_classes: int, last_dim: _typing.Union[int, str] = "auto", @@ -371,12 +340,12 @@ 
class _BaseClassificationTrainer(BaseTrainer): return self.__encoder @encoder.setter - def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None]): + def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): if isinstance(enc, str): self.__encoder = EncoderRegistry.get_model(enc)( - self.num_features, last_dim=self.last_dim, device=self.device, init=self.init + self.num_features, last_dim=self.last_dim, device=self.device, init=self.initialized ) - elif isinstance(enc, BaseAutoEncoder): + elif isinstance(enc, BaseAutoEncoderMaintainer): self.__encoder = enc elif isinstance(enc, BaseAutoModel): self.__encoder = enc @@ -393,16 +362,16 @@ class _BaseClassificationTrainer(BaseTrainer): return self.__decoder @decoder.setter - def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): + def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): if isinstance(self.encoder, BaseAutoModel): logging.warn("Ignore passed dec since enc is a whole model") self.__decoder = None return if isinstance(dec, str): self.__decoder = DecoderRegistry.get_model(dec)( - self.num_classes, input_dim=self.last_dim, device=self.device, init=self.init + self.num_classes, input_dim=self.last_dim, device=self.device, init=self.initialized ) - elif isinstance(dec, BaseAutoDecoder): + elif isinstance(dec, BaseAutoDecoderMaintainer): self.__decoder = dec elif dec is None: self.__decoder = None @@ -417,9 +386,19 @@ class _BaseClassificationTrainer(BaseTrainer): def num_classes(self, num_classes): self.__num_classes = num_classes if isinstance(self.encoder, BaseAutoModel): - self.encoder.num_classes = num_classes - elif isinstance(self.decoder, AutoClassifierDecoder): - self.decoder.num_classes = num_classes + self.encoder.output_dimension = num_classes + elif isinstance(self.decoder, BaseAutoDecoderMaintainer): + self.decoder.output_dimension = num_classes + + @property + def last_dim(self): + return self.__last_dim + + 
@last_dim.setter + def last_dim(self, dim): + self.__last_dim = dim + if isinstance(self.encoder, AutoHomogeneousEncoderMaintainer): + self.encoder.final_dimension = self.__last_dim @property def num_features(self): @@ -429,13 +408,13 @@ class _BaseClassificationTrainer(BaseTrainer): def num_features(self, num_features): self.__num_features = num_features if self.encoder is not None: - self.encoder.num_features = num_features + self.encoder.input_dimension = num_features class BaseNodeClassificationTrainer(_BaseClassificationTrainer): def __init__( self, - encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None], - decoder: _typing.Union[BaseAutoDecoder, str, None], + encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None], + decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None], num_features: int, num_classes: int, device: _typing.Union[torch.device, str, None] = None, @@ -448,13 +427,27 @@ class BaseNodeClassificationTrainer(_BaseClassificationTrainer): super(BaseNodeClassificationTrainer, self).__init__( encoder, decoder, num_features, num_classes, num_classes, device, init, feval, loss ) + + # override num_classes property to support last_dim setting + @property + def num_classes(self): + return self.__num_classes + + @num_classes.setter + def num_classes(self, num_classes): + self.__num_classes = num_classes + if isinstance(self.encoder, BaseAutoModel): + self.encoder.output_dimension = num_classes + elif isinstance(self.decoder, BaseAutoDecoderMaintainer): + self.decoder.output_dimension = num_classes + self.last_dim = num_classes class BaseGraphClassificationTrainer(_BaseClassificationTrainer): def __init__( self, - encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None] = None, - decoder: _typing.Union[BaseAutoDecoder, str, None] = None, + encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None] = None, + decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None] = None, num_features: 
_typing.Optional[int] = None, num_classes: _typing.Optional[int] = None, num_graph_features: int = 0, @@ -477,16 +470,16 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): return self.__encoder @encoder.setter - def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None]): + def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): if isinstance(enc, str): self.__encoder = EncoderRegistry.get_model(enc)( self.num_features, last_dim=self.last_dim, num_graph_features=self.num_graph_features, device=self.device, - init=self.init + init=self.initialized ) - elif isinstance(enc, (BaseAutoModel, BaseAutoEncoder)): + elif isinstance(enc, (BaseAutoModel, BaseAutoEncoderMaintainer)): self.__encoder = enc if isinstance(enc, BaseAutoModel) and self.decoder is not None: logging.warn("will disable decoder since a whole model is passed") @@ -502,7 +495,7 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): return self.__decoder @decoder.setter - def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): + def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): if isinstance(self.encoder, BaseAutoModel): logging.warn("Ignore passed dec since enc is a whole model") self.__decoder = None @@ -513,19 +506,34 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): input_dim=self.last_dim, num_graph_features=self.num_graph_features, device=self.device, - init=self.init + init=self.initialized ) - elif isinstance(dec, (BaseAutoDecoder, None)): + elif dec is None or isinstance(dec, BaseAutoDecoderMaintainer): self.__decoder = dec else: raise ValueError("Invalid decoder setting") + # override num_classes property to support last_dim setting + @property + def num_classes(self): + return self.__num_classes + + @num_classes.setter + def num_classes(self, num_classes): + self.__num_classes = num_classes + if isinstance(self.encoder, BaseAutoModel): + 
self.encoder.output_dimension = num_classes + elif isinstance(self.decoder, BaseAutoDecoderMaintainer): + self.decoder.output_dimension = num_classes + self.last_dim = num_classes + + + # TODO: according to discussion, link prediction may not belong to classification tasks class BaseLinkPredictionTrainer(_BaseClassificationTrainer): def __init__( self, - encoder: _typing.Union[BaseAutoModel, BaseAutoEncoder, str, None] = None, - decoder: _typing.Union[BaseAutoDecoder, str, None] = None, + encoder: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None] = None, + decoder: _typing.Union[BaseAutoDecoderMaintainer, str, None] = None, num_features: _typing.Optional[int] = None, last_dim: _typing.Union[int, str] = "auto", device: _typing.Union[torch.device, str, None] = None, @@ -546,7 +554,7 @@ class BaseLinkPredictionTrainer(_BaseClassificationTrainer): return self.__decoder @decoder.setter - def decoder(self, dec: _typing.Union[BaseAutoDecoder, str, None]): + def decoder(self, dec: _typing.Union[BaseAutoDecoderMaintainer, str, None]): if isinstance(self.encoder, BaseAutoModel): logging.warn("Ignore passed dec since enc is a whole model") self.__decoder = None @@ -555,9 +563,9 @@ class BaseLinkPredictionTrainer(_BaseClassificationTrainer): self.__decoder = DecoderRegistry.get_model(dec)( input_dim=self.last_dim, device=self.device, - init=self.init + init=self.initialized ) - elif isinstance(dec, (BaseAutoDecoder, None)): + elif dec is None or isinstance(dec, BaseAutoDecoderMaintainer): self.__decoder = dec else: raise ValueError("Invalid decoder setting") diff --git a/autogl/module/train/node_classification_full.py b/autogl/module/train/node_classification_full.py index 2a36142..a88361b 100644 --- a/autogl/module/train/node_classification_full.py +++ b/autogl/module/train/node_classification_full.py @@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import ( ReduceLROnPlateau, ) import torch.nn.functional as F -from ..model import AutoClassifierDecoder, 
AutoHomogeneousEncoder, BaseAutoModel +from ..model import BaseAutoEncoderMaintainer, BaseAutoDecoderMaintainer, BaseAutoModel from .evaluation import Evaluation, get_feval, Logloss from typing import Callable, Iterable, Optional, Type, Union from copy import deepcopy @@ -58,8 +58,8 @@ class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): def __init__( self, - encoder: Union[BaseAutoModel, AutoHomogeneousEncoder, str, None] = None, - decoder: Union[AutoClassifierDecoder, str, None] = "LogSoftmaxDecoder", + encoder: Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None] = None, + decoder: Union[BaseAutoDecoderMaintainer, str, None] = "LogSoftmaxDecoder", num_features: Optional[int] = None, num_classes: Optional[int] = None, optimizer: Union[str, Type[torch.optim.Optimizer]] = torch.optim.Adam, @@ -146,7 +146,7 @@ class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): }, ] - self.hyper_parameter = { + self.hyper_parameters = { "max_epoch": self.max_epoch, "early_stopping_round": self.early_stopping_round, "lr": self.lr, @@ -163,7 +163,7 @@ class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): self.initialized = True if isinstance(self.encoder, BaseAutoModel): self.encoder.initialize() - elif isinstance(self.encoder, AutoHomogeneousEncoder) and isinstance(self.decoder, AutoClassifierDecoder): + elif isinstance(self.encoder, BaseAutoEncoderMaintainer) and isinstance(self.decoder, BaseAutoDecoderMaintainer): self.encoder.initialize() # pass the necessary message to decoder self.decoder.initialize(self.encoder) @@ -496,7 +496,7 @@ class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): encoder = encoder if encoder != "same" else self.encoder decoder = decoder if decoder != "same" else self.decoder encoder = encoder.from_hyper_parameter(hp_encoder) - if isinstance(encoder, AutoHomogeneousEncoder) and isinstance(decoder, AutoClassifierDecoder): + if isinstance(encoder, BaseAutoEncoderMaintainer) and 
isinstance(decoder, BaseAutoDecoderMaintainer): decoder = decoder.from_hyper_parameter_and_encoder(hp_decoder, encoder) ret = self.__class__(