import typing as _typing

import torch
import torch.nn.functional as F
import torch_geometric
from tqdm import trange
from typing import Union, Sequence, Callable

from torch.optim.lr_scheduler import (
    StepLR,
    MultiStepLR,
    ExponentialLR,
    ReduceLROnPlateau,
)

from .losses import NTXent_loss
from .utils import get_view_by_name

from ..base import BaseTrainer, EarlyStopping, _DummyModel
from ..evaluation import Evaluation, get_feval, Acc
from ...model import (
    BaseAutoModel,
    BaseEncoderMaintainer,
    BaseDecoderMaintainer,
    EncoderUniversalRegistry,
    DecoderUniversalRegistry,
)
from ....utils import get_logger
from ....datasets import utils

LOGGER = get_logger("contrastive trainer")

class BaseContrastiveTrainer(BaseTrainer):
    def __init__(
        self,
        encoder: _typing.Union[BaseEncoderMaintainer, str, None],
        decoder: _typing.Union[BaseDecoderMaintainer, str, None],
        decoder_node: _typing.Union[BaseDecoderMaintainer, None],
        num_features: _typing.Union[int, None] = None,
        num_graph_features: _typing.Union[int, None] = None,
        device: _typing.Union[torch.device, str] = "auto",
        feval: _typing.Union[
            _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
        ] = (Acc,),
        loss: Union[str, Callable] = "NT_Xent",
        f_loss: Union[str, Callable] = "nll_loss",
        views_fn: _typing.Union[_typing.Sequence[_typing.Callable], None] = None,
        aug_ratio: Union[float, Sequence[float]] = 0.2,
        graph_level: bool = True,
        node_level: bool = False,
        z_dim: _typing.Union[int, None] = None,
        z_node_dim: _typing.Union[int, None] = None,
        tau: float = 0.5,
        p_optim: Union[torch.optim.Optimizer, str] = "Adam",
        p_lr: float = 0.0001,
        p_lr_scheduler_type: _typing.Union[str, None] = None,
        p_weight_decay: float = 0,
        p_epoch: int = 100,
        p_early_stopping_round: int = 20,
        f_optim: Union[torch.optim.Optimizer, str] = "Adam",
        f_lr: float = 0.001,
        f_lr_scheduler_type: _typing.Union[str, None] = None,
        f_weight_decay: float = 0,
        f_epoch: int = 100,
        f_early_stopping_round: int = 20,
        model_path: _typing.Union[str, None] = "./models",
    ):
- """
- The basic trainer for self-supervised learning with contrastive method.
- Used to automatically train the self-supervised problems.
- Parameters
- ----------
- encoder: `BaseEncoderMaintainer`, `str` or None
- A graph encoder shared by all views.
- decoder: `BaseDecoderMaintainer`, `str` or None
- A decoder which can be understood as a projection head for graph-level representations.
- Only required if `graph_level` = True.
- decoder_node: `BaseDecoderMaintainer`, `str` or None
- A decoder which can be understood as a projection head for node-level representations.
- Only required if `node_level` = True.
- num_features: `int` or None, Optional
- The number of features in dataset.
- num_graph_features: `int` or None, Optional
- The number of graph level features in dataset.
- device: `torch.device` or `str`, Optional
- The device this trainer will use.
- When `device` = "auto", if GPU exists in the device and dependency is installed,
- the trainer will give priority to GPU, otherwise CPU will be used
- feval: a sequence of `str` or a sequence of `Evaluation`, Optional
- The evaluation methods.
- loss: `str` or `Callable`, Optional
- The loss function or the learning objective of contrastive model.
- views_fn: a list of `Callable` or None, Optional
- List of functions or augmentation methods to generate views from give graphs.
- graph_level: `bool`, Optional
- Whether to include graph-level representations
- node_level: `bool`, Optional
- Whether to include node-level representations
- z_dim: `int`, Optional
- The dimension of graph-level representations
- z_node_dim: `int`, Optional
- The dimension of node-level representations
- tau: `int`, Optional
- The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
- model_path: `str` or None, Optional
- The directory to restore the saved model.
- If `model_path` = None, the model will not be saved.
- """
        assert node_level or graph_level, "At least one of node_level and graph_level must be True."
        assert isinstance(encoder, (BaseEncoderMaintainer, str)) or encoder is None
        self.loss = self._get_loss(loss)
        self.node_level = node_level
        self.graph_level = graph_level
        self.z_dim = z_dim
        self.z_node_dim = z_node_dim
        self._encoder = None
        # attributes that record the validation results
        self.valid_result = None
        self.valid_result_prob = None
        self.valid_score = None
        # TODO: methods with both node-level and graph-level representations are
        # not supported yet, so the decoder is either `_decoder` or `_decoder_node`.
        self._decoder = None
        self.views_fn_opt = views_fn
        self._views_fn = None
        self._aug_ratio = aug_ratio
        self.last_dim = z_dim if graph_level else z_node_dim
        self.num_features = num_features
        self.num_graph_features = num_graph_features
        self.tau = tau
        self.model_path = model_path
        if isinstance(device, str):
            if device == "auto":
                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            else:
                self.device = torch.device(device)
        elif isinstance(device, torch.device):
            self.device = device
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.p_opt_received = p_optim
        self.p_optimizer = self._get_optimizer(p_optim)
        self.p_lr_scheduler_type = p_lr_scheduler_type
        self.p_lr = p_lr
        self.p_epoch = p_epoch
        self.p_weight_decay = p_weight_decay
        self.p_early_stopping_round = (
            p_early_stopping_round if p_early_stopping_round is not None else 100
        )
        self.p_early_stopping = EarlyStopping(
            patience=self.p_early_stopping_round, verbose=False
        )
        self.f_opt_received = f_optim
        self.f_optimizer = self._get_optimizer(f_optim)
        self.f_lr_scheduler_type = f_lr_scheduler_type
        self.f_lr = f_lr
        self.f_epoch = f_epoch
        self.f_weight_decay = f_weight_decay
        self.f_early_stopping_round = (
            f_early_stopping_round if f_early_stopping_round is not None else 100
        )
        self.f_early_stopping = EarlyStopping(
            patience=self.f_early_stopping_round, verbose=False
        )
        if isinstance(f_loss, str) and hasattr(F, f_loss):
            self.f_loss = getattr(F, f_loss)
        elif callable(f_loss):
            self.f_loss = f_loss
        else:
            raise NotImplementedError(f"The loss {f_loss} is not supported yet.")
        super().__init__(
            encoder=encoder,
            decoder=decoder if graph_level else decoder_node,
            device=self.device,
            feval=feval,
            loss=self.loss,
        )

    def _get_loss(self, loss):
        if callable(loss):
            return loss
        elif isinstance(loss, str):
            assert loss in ["NT_Xent"]
            return {"NT_Xent": NTXent_loss}[loss]
        else:
            raise NotImplementedError(
                "The argument `loss` should be a str or a callable that returns a loss tensor."
            )
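
    # For reference, NT_Xent (the normalized-temperature cross entropy used by
    # SimCLR/GraphCL) scores agreement between two views of the same graph
    # against all other graphs in the batch. A minimal sketch of the objective
    # for two batches of view embeddings z1, z2 (an illustration only; it is
    # not necessarily identical to `NTXent_loss` imported from `.losses`):
    #
    #     def nt_xent_sketch(z1, z2, tau=0.5):
    #         z1 = F.normalize(z1, dim=1)
    #         z2 = F.normalize(z2, dim=1)
    #         sim = z1 @ z2.t() / tau  # pairwise cosine similarities
    #         # positives sit on the diagonal; all other entries act as negatives
    #         return F.cross_entropy(sim, torch.arange(z1.size(0), device=z1.device))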

    # Override encoder and decoder so that they fit contrastive learning.
    @property
    def encoder(self):
        return self._encoder

    @encoder.setter
    def encoder(self, enc: _typing.Union[BaseEncoderMaintainer, str, None]):
        if isinstance(enc, str):
            if enc in EncoderUniversalRegistry:
                if self.node_level:
                    self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                        self.num_features,
                        final_dimension=self.last_dim,
                        device=self.device,
                        init=self.initialized,
                    )
                elif self.graph_level:
                    self._encoder = EncoderUniversalRegistry.get_encoder(enc)(
                        self.num_features,
                        final_dimension=self.last_dim,
                        num_graph_features=self.num_graph_features,
                        device=self.device,
                        init=self.initialized,
                    )
            else:
                raise NotImplementedError(f"Sorry. Encoder {enc} is not supported yet.")
        elif isinstance(enc, BaseEncoderMaintainer):
            self._encoder = enc
        elif enc is None:
            self._encoder = None
        else:
            raise NotImplementedError(f"Sorry. Encoder {enc} is not supported yet.")
        # Re-assign these attributes so that their property setters propagate
        # the current settings to the newly set encoder.
        self.num_features = self.num_features
        self.last_dim = self.last_dim
        self.num_graph_features = self.num_graph_features

    @property
    def decoder(self):
        if isinstance(self.encoder, BaseAutoModel):
            return None
        return self._decoder

    @decoder.setter
    def decoder(self, dec: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            LOGGER.warning("Ignoring the passed decoder since the encoder is a whole model.")
            self._decoder = None
            return
        if isinstance(dec, str):
            if self.node_level:
                self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                    self.last_dim,
                    input_dimension=self.last_dim,
                    device=self.device,
                    init=self.initialized,
                )
            elif self.graph_level:
                self._decoder = DecoderUniversalRegistry.get_decoder(dec)(
                    self.last_dim,
                    input_dimension=self.last_dim,
                    num_graph_features=self.num_graph_features,
                    device=self.device,
                    init=self.initialized,
                )
        elif isinstance(dec, BaseDecoderMaintainer) or dec is None:
            self._decoder = dec
        else:
            raise NotImplementedError(f"Sorry. The decoder {dec} is not supported yet.")
        # Re-assign to propagate the settings to the new decoder via property setters.
        self.num_features = self.num_features
        self.last_dim = self.last_dim
        self.num_graph_features = self.num_graph_features

    @property
    def num_graph_features(self):
        return self._num_graph_features

    @num_graph_features.setter
    def num_graph_features(self, num_graph_features):
        self._num_graph_features = num_graph_features
        if self.graph_level:
            if self.encoder is not None:
                self.encoder.num_graph_features = self._num_graph_features
            if self.decoder is not None:
                self.decoder.num_graph_features = self._num_graph_features

    @property
    def num_features(self):
        return self._num_features

    @num_features.setter
    def num_features(self, num_features):
        self._num_features = num_features
        if self.encoder is not None:
            self.encoder.input_dimension = num_features

    @property
    def num_classes(self):
        return self._num_classes

    @num_classes.setter
    def num_classes(self, num_classes):
        self._num_classes = num_classes
        if self.prediction_head is not None:
            self.prediction_head.output_dimension = num_classes

    @property
    def views_fn(self):
        return self._views_fn

    @views_fn.setter
    def views_fn(self, views_fn):
        self.views_fn_opt = views_fn
        # set augmentation methods; broadcast a scalar aug_ratio to one per view
        if isinstance(views_fn, list):
            if isinstance(self.aug_ratio, float):
                self.aug_ratio = [self.aug_ratio] * len(views_fn)
            if isinstance(self.aug_ratio, list) and len(self.aug_ratio) != len(views_fn):
                self.aug_ratio = [self.aug_ratio[0]] * len(views_fn)
            self._views_fn = self._get_views_fn(views_fn, self.aug_ratio)
        else:
            self._views_fn = None

    @property
    def aug_ratio(self):
        return self._aug_ratio

    @aug_ratio.setter
    def aug_ratio(self, aug_ratio):
        # keep the augmentation ratios aligned with the view functions
        if isinstance(self.views_fn, list):
            if isinstance(aug_ratio, float):
                aug_ratio = [aug_ratio] * len(self.views_fn)
            assert len(aug_ratio) >= len(self.views_fn)
        self._aug_ratio = aug_ratio
        self._views_fn = self._get_views_fn(self.views_fn_opt, self.aug_ratio)
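
    # The sketch below illustrates the expected shape of a custom view function:
    # a callable taking a (batched) `torch_geometric.data.Data` object and
    # returning an augmented copy. This is a minimal, hypothetical example (the
    # name `mask_node_features` and its `drop_ratio` argument are assumptions,
    # not part of this trainer's API); built-in names such as "dropN" or "maskN"
    # resolve to comparable callables through `get_view_by_name`.
    #
    #     def mask_node_features(data, drop_ratio=0.2):
    #         data = data.clone()  # never mutate the original batch
    #         mask = torch.rand(data.x.size(0), device=data.x.device) < drop_ratio
    #         data.x[mask] = 0.0   # zero out the features of the sampled nodes
    #         return data
    #
    #     trainer.views_fn = [mask_node_features, mask_node_features]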

    @property
    def prediction_head(self):
        return self._prediction_head

    @prediction_head.setter
    def prediction_head(self, head: _typing.Union[BaseDecoderMaintainer, str, None]):
        if isinstance(self.encoder, BaseAutoModel):
            raise ValueError(
                "Encoder shouldn't be a `BaseAutoModel` in GraphCLSemisupervisedTrainer."
            )
        if isinstance(head, str):
            self._prediction_head = DecoderUniversalRegistry.get_decoder(head)(
                self.num_classes,
                input_dimension=self.last_dim,
                num_graph_features=self.num_graph_features,
                device=self.device,
                init=self.initialized,
            )
        elif (
            isinstance(head, BaseDecoderMaintainer)
            or head is None
            or (hasattr(head, "fit") and hasattr(head, "predict"))
        ):
            self._prediction_head = head
        else:
            raise NotImplementedError(f"Sorry. The head {head} is not supported yet.")
        # Re-assign to propagate the settings to the new head via property setters.
        self.num_features = self.num_features
        self.last_dim = self.last_dim
        self.num_graph_features = self.num_graph_features
        self.num_classes = self.num_classes

    def train(self, dataset, keep_valid_result=True):
        """
        Train on the given dataset and, optionally, keep the validation result.

        Parameters
        ----------
        dataset: The graph dataset used for training.
        keep_valid_result: `bool`
            If True, save the validation result after training.

        Returns
        -------
        self: `autogl.train.GraphCLSemisupervisedTrainer`
            A reference to the current trainer.
        """
        try:
            if self.graph_level:
                valid_loader = utils.graph_get_split(
                    dataset, "val", batch_size=self.batch_size, num_workers=self.num_workers
                )
            else:
                # TODO: node-level tasks
                valid_loader = None
        except (ValueError, AttributeError):
            valid_loader = None
        self._train_only(dataset)
        if keep_valid_result and valid_loader:
            # save the validation result after training
            pred = self._predict_only(valid_loader)
            self.valid_result = pred.max(1)[1]
            self.valid_result_prob = pred
            self.valid_score = self.evaluate(dataset, mask="val", feval=self.feval)
        return self

    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, str):
            if optimizer.lower() == "adam":
                return torch.optim.Adam
            elif optimizer.lower() == "sgd":
                return torch.optim.SGD
            raise ValueError("The optimizer {} is currently not supported.".format(optimizer))
        elif isinstance(optimizer, type) and issubclass(optimizer, torch.optim.Optimizer):
            return optimizer
        raise ValueError("The optimizer {} is currently not supported.".format(optimizer))

    def _get_views_fn(self, views_fn, aug_ratio):
        # GraphCL only needs two kinds of augmentation methods
        if views_fn is None:
            return None
        final_views_fn = []
        for i, view in enumerate(views_fn):
            if isinstance(view, str):
                assert view in [
                    "dropN", "permE", "subgraph", "maskN", "random2", "random3", "random4"
                ]
                final_views_fn.append(get_view_by_name(view, aug_ratio[i]))
            else:
                final_views_fn.append(view)
        return final_views_fn

    def _get_scheduler(self, stage, optimizer):
        if stage == "pretraining":
            lr_scheduler_type = self.p_lr_scheduler_type
        else:
            lr_scheduler_type = self.f_lr_scheduler_type
        if lr_scheduler_type == "steplr":
            scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
        elif lr_scheduler_type == "multisteplr":
            scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
        elif lr_scheduler_type == "exponentiallr":
            scheduler = ExponentialLR(optimizer, gamma=0.1)
        elif lr_scheduler_type == "reducelronplateau":
            scheduler = ReduceLROnPlateau(optimizer, "min")
        else:
            scheduler = None
        return scheduler

    def _get_loader_loss(self, loader, optimizer=None, scheduler=None, mode="train"):
        epoch_loss = 0.0
        last_loss = 0.0
        train = mode == "train"
        for data in loader:
            if train:
                optimizer.zero_grad()
            if None in self.views_fn:
                # `None` marks view functions that already return multiple views
                views = []
                for v_fn in self.views_fn:
                    if v_fn is not None:
                        views += [*v_fn(data)]
            else:
                views = [v_fn(data) for v_fn in self.views_fn]
            zs = []
            for view in views:
                view = view.to(self.device)
                z = self._get_embed(view)
                zs.append(self.decoder.decoder(z, view))
            loss = self.loss(zs, tau=self.tau)
            if train:
                loss.backward()
                optimizer.step()
                if self.p_lr_scheduler_type:
                    scheduler.step()
            epoch_loss += loss.item()
            last_loss = loss.item()
        return epoch_loss, last_loss
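
    # The `None in self.views_fn` branch above supports view functions that
    # produce several views per call: registering `[multi_view_fn, None]`
    # signals that `multi_view_fn(data)` returns an iterable of views. A hedged
    # sketch (`multi_view_fn`, `drop_nodes` and `mask_features` are hypothetical
    # names, not part of this module):
    #
    #     def multi_view_fn(data):
    #         return drop_nodes(data), mask_features(data)  # two views per call
    #
    #     trainer.views_fn = [multi_view_fn, None]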

    def _train_pretraining_only(self, dataset, per_epoch=False):
        """
        Pretraining stage.
        In effect this trains the encoder; the decoder only serves as an
        auxiliary projection head.
        """
        if int(torch_geometric.__version__.split(".")[0]) >= 2:
            # torch_geometric 2.x moved DataLoader to the loader subpackage
            from torch_geometric.loader import DataLoader
        else:
            from torch_geometric.data import DataLoader
        if self.graph_level:
            pre_train_loader = DataLoader(
                dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True
            )
        else:
            # TODO: node-level tasks
            pre_train_loader = None
        optimizer = self.p_optimizer(
            self.encoder.encoder.parameters(), lr=self.p_lr, weight_decay=self.p_weight_decay
        )
        optimizer.add_param_group({"params": self.decoder.decoder.parameters()})
        scheduler = self._get_scheduler("pretraining", optimizer)
        with trange(self.p_epoch) as t:
            for epoch in t:
                self.encoder.encoder.train()
                self.decoder.decoder.train()
                t.set_description(f"Pretraining: epoch {epoch + 1}")
                try:
                    epoch_loss, last_loss = self._get_loader_loss(
                        pre_train_loader, optimizer, scheduler, "train"
                    )
                except ValueError:
                    epoch_loss = last_loss = 1e10
                t.set_postfix(loss="{:.6f}".format(float(last_loss)))
                if per_epoch:
                    yield epoch
                else:
                    self.p_early_stopping(epoch_loss, self.encoder.encoder)
                    if self.p_early_stopping.early_stop:
                        LOGGER.debug("Early stopping at epoch %d", epoch)
                        break

        if not per_epoch:
            self.p_early_stopping.load_checkpoint(self.encoder.encoder)
            yield self.p_epoch
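
    # Note that `_train_pretraining_only` is a generator. A hedged sketch of how
    # a subclass might drive it (`trainer` stands for a concrete subclass
    # instance; the surrounding fine-tuning logic is omitted and hypothetical):
    #
    #     for epoch in trainer._train_pretraining_only(dataset, per_epoch=True):
    #         pass  # e.g. interleave fine-tuning or log per-epoch metrics here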

    def _compose_model(self, pretrain=False):
        if pretrain:
            return _DummyModel(self.encoder, self.decoder).to(self.device)
        elif self.prediction_head is not None and isinstance(
            self.prediction_head, BaseDecoderMaintainer
        ):
            return _DummyModel(self.encoder, self.prediction_head).to(self.device)
        else:
            return self.encoder.encoder.to(self.device)

    def _get_embed(self, view):
        # the shared encoder produces the representation for a single view
        return self.encoder.encoder(view)

    def predict(self, dataset, mask="test"):
        """
        Predict on the given dataset.

        Parameters
        ----------
        dataset: The graph classification dataset to predict on.

        mask: `str`
            "train", "val" or "test"
            The dataset mask.

        Returns
        -------
        The predicted labels, i.e. the argmax of `predict_proba`.
        """
        if self.graph_level:
            loader = utils.graph_get_split(
                dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers
            )
        else:
            # TODO: node-level tasks
            loader = None
        return self._predict_proba(loader, in_log_format=True).max(1)[1]

    def predict_proba(self, dataset, mask="test", in_log_format=False):
        """
        Predict the probability on the given dataset.

        Parameters
        ----------
        dataset: The graph dataset to predict on.

        mask: `str`
            "train", "val" or "test"
            The dataset mask.

        in_log_format: `bool`
            If True, return log probabilities; otherwise return probabilities.

        Returns
        -------
        The prediction result.
        """
        if self.graph_level:
            loader = utils.graph_get_split(
                dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers
            )
        else:
            # TODO: node-level tasks
            loader = None
        return self._predict_proba(loader, in_log_format)

    def _predict_proba(self, loader, in_log_format=False, return_label=False):
        if return_label:
            ret, label = self._predict_only(loader, return_label=True)
        else:
            ret = self._predict_only(loader, return_label=False)
        if in_log_format is False:
            ret = torch.exp(ret)
        if return_label:
            return ret, label
        return ret

    def _predict_only(self, loader, return_label=False):
        # implemented by concrete subclasses
        raise NotImplementedError()

    def evaluate(self, dataset, mask="val", feval=None):
        """
        Evaluate the model on the given dataset.

        Parameters
        ----------
        dataset: The graph dataset to evaluate on.

        mask: `str`
            "train", "val" or "test"
            The dataset mask.

        feval: `str`
            The evaluation method used in this function.

        Returns
        -------
        res: The evaluation result on the given dataset.
        """
        if self.graph_level:
            loader = utils.graph_get_split(
                dataset, mask, batch_size=self.batch_size, num_workers=self.num_workers
            )
        else:
            # TODO: node-level tasks
            loader = None
        return self._evaluate(loader, feval)

    def _evaluate(self, loader, feval=None):
        if feval is None:
            feval = self.feval
        else:
            feval = get_feval(feval)

        y_pred_prob, y_true = self._predict_proba(loader=loader, return_label=True)

        if not isinstance(feval, list):
            feval = [feval]
            return_single = True
        else:
            return_single = False

        res = []
        for f in feval:
            # Evaluators differ in whether they accept tensors or numpy arrays,
            # so try progressively more detached representations of the outputs.
            candidates = (
                lambda: f.evaluate(y_pred_prob, y_true),
                lambda: f.evaluate(y_pred_prob.cpu().numpy(), y_true.cpu().numpy()),
                lambda: f.evaluate(y_pred_prob.detach().numpy(), y_true.detach().numpy()),
                lambda: f.evaluate(
                    y_pred_prob.cpu().detach().numpy(), y_true.cpu().detach().numpy()
                ),
            )
            for candidate in candidates:
                try:
                    res.append(candidate())
                    break
                except Exception:
                    continue
            else:
                raise RuntimeError(f"Cannot evaluate the predictions with {f}.")

        if return_single:
            return res[0]
        return res

    def combined_hyper_parameter_space(self):
        return {
            "trainer": self.hyper_parameter_space,
            "encoder": self.encoder.hyper_parameter_space,
            "decoder": [] if self.decoder is None else self.decoder.hyper_parameter_space,
            "prediction_head": []
            if self.prediction_head is None
            else self.prediction_head.hyper_parameter_space,
        }

    def get_valid_predict(self):
        """Get the validation result."""
        return self.valid_result

    def get_valid_predict_proba(self):
        """Get the validation result (prediction probability)."""
        return self.valid_result_prob

    def get_valid_score(self, return_major=True):
        """
        Get the validation score.

        Parameters
        ----------
        return_major: `bool`
            If True, the return only consists of the major result.
            If False, the return consists of all results.

        Returns
        -------
        result: The validation score obtained in the training stage.
        """
        if isinstance(self.feval, list):
            if return_major:
                return self.valid_score[0], self.feval[0].is_higher_better()
            else:
                return self.valid_score, [f.is_higher_better() for f in self.feval]
        else:
            return self.valid_score, self.feval.is_higher_better()

    def __repr__(self) -> str:
        import yaml

        return yaml.dump(
            {
                "trainer_name": self.__class__.__name__,
                "p_optimizer": self.p_optimizer,
                "p_learning_rate": self.p_lr,
                "p_max_epoch": self.p_epoch,
                "p_early_stopping_round": self.p_early_stopping_round,
                "encoder": repr(self.encoder),
                "decoder": repr(self.decoder),
            }
        )
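
# A minimal usage sketch (hypothetical, since `BaseContrastiveTrainer` leaves
# `_train_only` / `_predict_only` to subclasses): a concrete trainer such as
# `GraphCLSemisupervisedTrainer` would be driven roughly like this. The encoder
# and decoder names below are assumptions, standing in for whatever is
# registered in `EncoderUniversalRegistry` / `DecoderUniversalRegistry`.
#
#     trainer = GraphCLSemisupervisedTrainer(
#         encoder="gin",                # resolved via EncoderUniversalRegistry
#         decoder="sumpoolmlp",         # resolved via DecoderUniversalRegistry
#         views_fn=["dropN", "maskN"],  # two augmented views per graph
#         aug_ratio=0.2,
#         z_dim=128,
#         device="auto",
#     )
#     trainer.train(dataset)
#     print(trainer.evaluate(dataset, mask="test"))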