|
- import numpy as np
- import typing as _typing
-
- import torch
- import pickle
-
- from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
-
- from ..model import (
- EncoderUniversalRegistry,
- DecoderUniversalRegistry,
- BaseEncoderMaintainer,
- BaseDecoderMaintainer,
- BaseAutoModel,
- ModelUniversalRegistry
- )
- from ..hpo import AutoModule
- import logging
- from .evaluation import Evaluation, get_feval, Acc
- from ...utils import get_logger
-
# Module-level logger used by EarlyStopping to report checkpoint/counter events.
LOGGER_ES = get_logger("early-stopping")
-
class _DummyModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` joining an encoder and an optional decoder.

    When a whole ``BaseAutoModel`` is supplied, its underlying model serves as
    the encoder and no decoder is attached; otherwise the inner modules are
    pulled out of the maintainer objects.
    """

    def __init__(self, encoder: _typing.Union[BaseEncoderMaintainer, BaseAutoModel], decoder: _typing.Optional[BaseDecoderMaintainer]):
        super().__init__()
        whole_model = isinstance(encoder, BaseAutoModel)
        self.encoder = encoder.model if whole_model else encoder.encoder
        if whole_model or decoder is None:
            self.decoder = None
        else:
            self.decoder = decoder.decoder

    def __str__(self, ):
        return "DummyModel(encoder={}, decoder={})".format(self.encoder, self.decoder)

    def encode(self, *args, **kwargs):
        # Delegate straight to the wrapped encoder module.
        return self.encoder(*args, **kwargs)

    def decode(self, *args, **kwargs):
        # Without a decoder, act as the identity on the first positional arg.
        if self.decoder is None:
            return args[0]
        return self.decoder(*args, **kwargs)

    def forward(self, *args, **kwargs):
        encoded = self.encode(*args, **kwargs)
        # The original inputs are forwarded alongside the encoded result so
        # decoders that need them (e.g. edge indices) can consume them.
        return self.decode(encoded, *args, **kwargs)
-
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(
        self,
        patience=7,
        verbose=False,
        delta=0,
        path="checkpoint.pt",
        trace_func=None,
    ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                ``None`` falls back to 100. Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                Default: 0
            path (str): Path for the checkpoint to be saved to.
                Default: 'checkpoint.pt'
            trace_func (function): trace print function; ``None`` means the
                early-stopping logger's ``info`` method. Default: None
        """
        self.patience = 100 if patience is None else patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # float("inf") instead of np.Inf — the np.Inf alias was removed in NumPy 2.0.
        self.val_loss_min = float("inf")
        self.delta = delta
        self.path = path
        # Resolve the logger default lazily so it is not evaluated at
        # function-definition (import) time.
        self.trace_func = LOGGER_ES.info if trace_func is None else trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss; sets ``early_stop`` after ``patience`` bad epochs."""
        score = -val_loss  # higher score == lower (better) validation loss

        if self.best_score is None:
            # First evaluation always counts as an improvement.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score <= self.best_score + self.delta:
            # No sufficient improvement: count toward patience.
            self.counter += 1
            if self.verbose:
                self.trace_func(
                    f"EarlyStopping counter: {self.counter} out of {self.patience}"
                )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: checkpoint the model and reset the counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Serialize ``model.state_dict()`` in memory when validation loss decreases."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
            )
        # The checkpoint is held in memory rather than written to self.path.
        self.best_param = pickle.dumps(model.state_dict())
        self.val_loss_min = val_loss

    def load_checkpoint(self, model):
        """Restore the best recorded state dict into ``model``, if one was saved."""
        if hasattr(self, "best_param"):
            model.load_state_dict(pickle.loads(self.best_param))
        else:
            # Logger.warn is a deprecated alias of Logger.warning.
            LOGGER_ES.warning("try to load checkpoint while no checkpoint is saved")
-
-
class BaseTrainer(AutoModule):
    """Abstract base for trainers: holds an encoder/decoder pair, the
    evaluation methods and the loss name, and declares the train/predict
    interface that concrete trainers must implement."""

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, None],
        decoder: _typing.Union[BaseDecoderMaintainer, None],
        device: _typing.Union[torch.device, str],
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        """
        The basic trainer.
        Used to automatically train the problems, e.g., node classification, graph classification, etc.
        Parameters
        ----------
        encoder: `BaseAutoModel`, `BaseEncoderMaintainer` or `None`
            The encoder (or a whole model) used to train and predict.
        decoder: `BaseDecoderMaintainer` or `None`
            The decoder; ignored when a whole model is passed as ``encoder``.
        device: `torch.device` or `str`
            The device this trainer runs on.
        feval: sequence of `str` or of `Evaluation` types
            The evaluation methods.
        loss: `str`
            The name of the loss function.
        """
        super().__init__(device)
        self.encoder = encoder
        # A whole model already contains its own decoding head.
        self.decoder = None if isinstance(encoder, BaseAutoModel) else decoder
        self.feval = feval
        self.loss = loss

    def _compose_model(self):
        # Wrap encoder/decoder into a single torch module on this trainer's device.
        return _DummyModel(self.encoder, self.decoder).to(self.device)

    def _initialize(self):
        # Initialize the encoder first; the decoder may depend on its output shape.
        self.encoder.initialize()
        if self.decoder is not None:
            self.decoder.initialize(self.encoder)

    @property
    def feval(self) -> _typing.Sequence[_typing.Type[Evaluation]]:
        # Evaluation classes, normalized by the setter below.
        return self.__feval

    @feval.setter
    def feval(
        self,
        _feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ],
    ):
        # Normalize names/classes into a sequence of Evaluation types.
        self.__feval: _typing.Sequence[_typing.Type[Evaluation]] = get_feval(_feval)

    @property
    def model(self):
        # compatible with v0.2
        return self.encoder

    @model.setter
    def model(self, model):
        # compatible with v0.2
        self.encoder = model

    def to(self, device: _typing.Union[str, torch.device]):
        """
        Transfer the trainer to another device
        Parameters
        ----------
        device: `str` or `torch.device`
            The device this trainer will use
        """
        self.device = device
        if self.encoder is not None:
            self.encoder.to_device(self.device)
        if self.decoder is not None:
            self.decoder.to_device(self.device)

    def get_feval(
        self, return_major: bool = False
    ) -> _typing.Union[
        _typing.Type[Evaluation], _typing.Sequence[_typing.Type[Evaluation]]
    ]:
        """
        Parameters
        ----------
        return_major: ``bool``
            Whether to return the major ``feval``. Default ``False``.
        Returns
        -------
        ``evaluation`` or list of ``evaluation``:
            If ``return_major=True``, will return the major ``evaluation`` method.
            Otherwise, will return the ``evaluation`` element passed when constructing.
        """
        if return_major:
            # The first element of a sequence is treated as the major metric.
            if isinstance(self.feval, _typing.Sequence):
                return self.feval[0]
            else:
                return self.feval
        return self.feval

    @classmethod
    def save(cls, instance, path):
        # Pickle the whole trainer instance to `path`.
        with open(path, "wb") as output:
            pickle.dump(instance, output, pickle.HIGHEST_PROTOCOL)

    @classmethod
    def load(cls, path):
        # Counterpart of `save`: unpickle a trainer instance from `path`.
        # NOTE(review): pickle.load on untrusted files is unsafe — only load
        # checkpoints you produced yourself.
        with open(path, "rb") as inputs:
            instance = pickle.load(inputs)
            return instance

    def duplicate_from_hyper_parameter(self, *args, **kwargs) -> "BaseTrainer":
        """Create a new trainer with the given hyper parameter."""
        raise NotImplementedError()

    def train(self, dataset, keep_valid_result):
        """
        Train on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in training.
        keep_valid_result: `bool`
            If True(False), save the validation result after training.
        Returns
        -------
        """
        raise NotImplementedError()

    def predict(self, dataset, mask=None):
        """
        Predict on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in predicting.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        Returns
        -------
        prediction result
        """
        raise NotImplementedError()

    def predict_proba(self, dataset, mask=None, in_log_format=False):
        """
        Predict the probability on the given dataset.
        Parameters
        ----------
        dataset: The dataset used in predicting.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        in_log_format: `bool`.
            If True(False), the probability will (not) be log format.
        Returns
        -------
        The prediction result.
        """
        raise NotImplementedError()

    def get_valid_predict_proba(self):
        """Get the valid result (prediction probability)."""
        raise NotImplementedError()

    def get_valid_predict(self):
        """Get the valid result."""
        raise NotImplementedError()

    def get_valid_score(self, return_major=True):
        """Get the validation score."""
        raise NotImplementedError()

    def __repr__(self) -> str:
        raise NotImplementedError

    def evaluate(self, dataset, mask=None, feval=None):
        """
        Parameters
        ----------
        dataset: The dataset used in evaluation.
        mask: `train`, `val`, or `test`.
            The dataset mask.
        feval: The evaluation methods.
        Returns
        -------
        The evaluation result.
        """
        raise NotImplementedError

    def update_parameters(self, **kwargs):
        """
        Update parameters of this trainer
        """
        for k, v in kwargs.items():
            if k == "feval":
                # Normalize before assignment (the setter normalizes again; idempotent).
                self.feval = get_feval(v)
            elif k == "device":
                # Moving device also migrates encoder/decoder via `to`.
                self.to(v)
            elif hasattr(self, k):
                setattr(self, k, v)
            else:
                raise KeyError("Cannot set parameter", k, "for trainer", self.__class__)

    def combined_hyper_parameter_space(self):
        # Merge trainer, encoder and decoder HPO spaces into one dict.
        return {
            "trainer": self.hyper_parameter_space,
            "encoder": self.encoder.hyper_parameter_space,
            "decoder": [] if self.decoder is None else self.decoder.hyper_parameter_space
        }
-
-
class _BaseClassificationTrainer(BaseTrainer):
    """ Base class of trainer for classification tasks """

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None],
        decoder: _typing.Union[BaseDecoderMaintainer, str, None],
        num_features: int,
        num_classes: int,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = "auto",
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        """
        Parameters
        ----------
        encoder: encoder maintainer, whole model, registered name, or None.
        decoder: decoder maintainer, registered name, or None.
        num_features: input (node-feature) dimension.
        num_classes: number of target classes.
        last_dim: dimension handed from encoder to decoder ("auto" to infer).
        device: the device used for training.
        feval: the evaluation method(s).
        loss: name of the loss function.
        """
        # Backing fields must exist before the property setters fire in the
        # base-class constructor.
        self._encoder = None
        self._decoder = None
        self.num_features = num_features
        self.num_classes = num_classes
        self.last_dim: _typing.Union[int, str] = last_dim
        super(_BaseClassificationTrainer, self).__init__(
            encoder, decoder, device, feval, loss
        )

    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None]):
        if isinstance(enc, str):
            if enc in EncoderUniversalRegistry:
                self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                    self.num_features, final_dimension=self.last_dim, device=self.device, init=self.initialized
                )
            else:
                # Fall back to the whole-model registry for legacy names.
                self._encoder = ModelUniversalRegistry.get_model(enc)(
                    self.num_features, final_dimension=self.last_dim, device=self.device
                )
        elif isinstance(enc, BaseEncoderMaintainer):
            self._encoder = enc
        elif isinstance(enc, BaseAutoModel):
            self._encoder = enc
            if self.decoder is not None:
                # logging.warn is a deprecated alias of logging.warning.
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
        # Re-run dependent setters so the new encoder/decoder pick up the
        # current dimensions (the setters push values into the components).
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

    @property
    def decoder(self):
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            # A whole model has its own decoding head — drop any passed decoder.
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                self.num_classes, input_dimension=self.last_dim, device=self.device, init=self.initialized
            )
        elif isinstance(dec, BaseDecoderMaintainer):
            self._decoder = dec
        elif dec is None:
            self._decoder = None
        else:
            raise ValueError("Dec {} is not supported!".format(dec))
        # Re-run dependent setters for the new decoder.
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

    @property
    def num_classes(self):
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        # Propagate the class count to whichever component produces the output.
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes

    @property
    def last_dim(self):
        return self._last_dim

    @last_dim.setter
    def last_dim(self, dim):
        self._last_dim = dim
        # Only homogeneous encoder maintainers expose a final_dimension knob.
        if isinstance(self.encoder, AutoHomogeneousEncoderMaintainer):
            self.encoder.final_dimension = self._last_dim

    @property
    def num_features(self):
        return self._num_features

    @num_features.setter
    def num_features(self, num_features):
        self._num_features = num_features
        if self.encoder is not None:
            self.encoder.input_dimension = num_features
-
class BaseNodeClassificationTrainer(_BaseClassificationTrainer):
    """Base trainer for node classification; ties the encoder's last
    dimension to ``num_classes`` so outputs map directly to class scores."""

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None],
        decoder: _typing.Union[BaseDecoderMaintainer, str, None],
        num_features: int,
        num_classes: int,
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        # last_dim (5th positional argument) is fixed to num_classes here.
        super(BaseNodeClassificationTrainer, self).__init__(
            encoder, decoder, num_features, num_classes, num_classes, device, feval, loss
        )

    # override num_classes property to support last_dim setting
    @property
    def num_classes(self):
        # Name-mangled, so this is distinct from the parent's private field.
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        # Propagate the class count to the output-producing component.
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes
        # Keep the encoder-decoder boundary dimension in sync with the classes.
        self.last_dim = num_classes
-
-
class BaseGraphClassificationTrainer(_BaseClassificationTrainer):
    """Base trainer for graph classification; additionally tracks
    graph-level features and forwards them to encoder and decoder."""

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None] = None,
        decoder: _typing.Union[BaseDecoderMaintainer, str, None] = None,
        num_features: _typing.Optional[int] = None,
        num_classes: _typing.Optional[int] = None,
        num_graph_features: int = 0,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        # Backing fields must exist before property setters fire in super().
        self._encoder = None
        self._decoder = None
        self.num_graph_features: int = num_graph_features
        super(BaseGraphClassificationTrainer, self).__init__(
            encoder, decoder, num_features, num_classes, last_dim, device, feval, loss
        )

    # override encoder and decoder to depend on graph level features
    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None]):
        if isinstance(enc, str):
            if enc in EncoderUniversalRegistry:
                self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                    self.num_features,
                    last_dim=self.last_dim,
                    num_graph_features=self.num_graph_features,
                    device=self.device,
                    init=self.initialized
                )
            else:
                # Fall back to the whole-model registry for legacy names.
                self._encoder = ModelUniversalRegistry.get_model(enc)(
                    self.num_features,
                    self.last_dim,
                    device=self.device,
                    num_graph_features=self.num_graph_features,
                )
        elif isinstance(enc, (BaseAutoModel, BaseEncoderMaintainer)):
            self._encoder = enc
            if isinstance(enc, BaseAutoModel) and self.decoder is not None:
                # logging.warn is a deprecated alias of logging.warning.
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
        # Re-run dependent setters so the new components pick up current dims.
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim
        self.num_graph_features = self.num_graph_features

    @property
    def decoder(self):
        # A whole model embeds its own decoder, so none is exposed.
        if isinstance(self.encoder, BaseAutoModel):
            return None
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                self.num_classes,
                input_dim=self.last_dim,
                num_graph_features=self.num_graph_features,
                device=self.device,
                init=self.initialized
            )
        elif isinstance(dec, BaseDecoderMaintainer):
            self._decoder = dec
        elif dec is None:
            # BUG FIX: the original `isinstance(dec, (BaseDecoderMaintainer, None))`
            # raised TypeError (None is not a type), so assigning decoder=None crashed.
            self._decoder = None
        else:
            raise ValueError("Invalid decoder setting")
        # Re-run dependent setters for the new decoder.
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim

    # override num_classes property to support last_dim setting
    @property
    def num_classes(self):
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self.__num_classes = num_classes
        # Propagate the class count to the output-producing component.
        if isinstance(self.encoder, BaseAutoModel):
            self.encoder.output_dimension = num_classes
        elif isinstance(self.decoder, BaseDecoderMaintainer):
            self.decoder.output_dimension = num_classes

    @property
    def num_graph_features(self):
        return self._num_graph_features

    @num_graph_features.setter
    def num_graph_features(self, num_graph_features):
        self._num_graph_features = num_graph_features
        # Push graph-level feature count into both components if present.
        if self.encoder is not None:
            self.encoder.num_graph_features = self._num_graph_features
        if self.decoder is not None:
            self.decoder.num_graph_features = self._num_graph_features
-
-
- # TODO: according to discussion, link prediction may not belong to classification tasks
class BaseLinkPredictionTrainer(_BaseClassificationTrainer):
    """Base trainer for link prediction; fixes ``num_classes`` to 2
    (edge / no-edge) and builds decoders that need no class count."""

    def __init__(
        self,
        encoder: _typing.Union[BaseAutoModel, BaseEncoderMaintainer, str, None] = None,
        decoder: _typing.Union[BaseDecoderMaintainer, str, None] = None,
        num_features: _typing.Optional[int] = None,
        last_dim: _typing.Union[int, str] = "auto",
        device: _typing.Union[torch.device, str, None] = None,
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        # Link prediction is binary, hence the hard-coded 2 classes.
        super(BaseLinkPredictionTrainer, self).__init__(
            encoder, decoder, num_features, 2, last_dim, device, feval, loss
        )

    # override decoder since no num_classes is needed
    @property
    def decoder(self):
        # A whole model embeds its own decoder, so none is exposed.
        if isinstance(self.encoder, BaseAutoModel):
            return None
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            # logging.warn is a deprecated alias of logging.warning.
            logging.warning("Ignore passed dec since enc is a whole model")
            self._decoder = None
            return
        if isinstance(dec, str):
            # Link-prediction decoders are constructed without a class count.
            self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                input_dim=self.last_dim,
                device=self.device,
                init=self.initialized
            )
        elif isinstance(dec, BaseDecoderMaintainer):
            self._decoder = dec
        elif dec is None:
            self._decoder = None
        else:
            raise ValueError("Invalid decoder setting")
        # Re-run dependent setters so the new decoder sees current dims.
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        self.last_dim = self.last_dim
-
-
- # ============== Het =================
class BaseNodeClassificationHetTrainer(BaseNodeClassificationTrainer):
    """ Base class of trainer for node classification on heterogeneous graphs """

    def __init__(
        self,
        model: _typing.Union[BaseAutoModel, str],
        dataset: _typing.Any,
        num_features: int,
        num_classes: int,
        device: _typing.Union[torch.device, str, None] = "auto",
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: str = "nll_loss",
    ):
        # The dataset must be stored before super().__init__ triggers the
        # encoder setter, which reads self._dataset.
        self._dataset = dataset
        super(BaseNodeClassificationHetTrainer, self).__init__(
            model, None, num_features, num_classes, device, feval, loss
        )
        self.from_dataset(dataset)

    def from_dataset(self, dataset):
        """Bind the heterogeneous dataset so the encoder can derive its metadata."""
        self._dataset = dataset
        if self.encoder is not None:
            self.encoder.from_dataset(self._dataset)

    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseAutoModel, str, None]):
        if isinstance(enc, str):
            # Heterogeneous trainers only use whole models from the registry.
            self._encoder = ModelUniversalRegistry.get_model(enc)(
                self.num_features, self.num_classes, device=self.device
            )
        elif isinstance(enc, BaseAutoModel):
            self._encoder = enc
            if self.decoder is not None:
                # logging.warn is a deprecated alias of logging.warning.
                logging.warning("will disable decoder since a whole model is passed")
                self.decoder = None
        elif enc is None:
            self._encoder = None
        else:
            raise ValueError("Enc {} is not supported!".format(enc))
        # Re-run dependent setters, then re-bind the dataset to the new encoder.
        self.num_features = self.num_features
        self.num_classes = self.num_classes
        if self._dataset is not None:
            self.from_dataset(self._dataset)
|