diff --git a/autogl/module/model/__init__.py b/autogl/module/model/__init__.py index 8ad37e2..de96fea 100644 --- a/autogl/module/model/__init__.py +++ b/autogl/module/model/__init__.py @@ -3,8 +3,8 @@ import sys from ...backend import DependentBackend from . import _utils -from .decoders import BaseAutoDecoderMaintainer, DecoderRegistry -from .encoders import BaseAutoEncoderMaintainer, AutoHomogeneousEncoderMaintainer, EncoderRegistry +from .decoders import BaseAutoDecoderMaintainer, DecoderUniversalRegistry +from .encoders import BaseAutoEncoderMaintainer, AutoHomogeneousEncoderMaintainer, EncoderUniversalRegistry # load corresponding backend model of subclass def _load_subclass_backend(backend): diff --git a/autogl/module/model/decoders/__init__.py b/autogl/module/model/decoders/__init__.py index 1749409..e3583c9 100644 --- a/autogl/module/model/decoders/__init__.py +++ b/autogl/module/model/decoders/__init__.py @@ -1,6 +1,4 @@ from .base_decoder import BaseAutoDecoderMaintainer +from .decoder_registry import DecoderUniversalRegistry -# TODO: Implement this -DecoderRegistry = None - -__all__ = ["BaseAutoDecoderMaintainer", "DecoderRegistry"] +__all__ = ["BaseAutoDecoderMaintainer", "DecoderUniversalRegistry"] diff --git a/autogl/module/model/decoders/_pyg/_pyg_decoders.py b/autogl/module/model/decoders/_pyg/_pyg_decoders.py index 3c5ba15..fec7a3d 100644 --- a/autogl/module/model/decoders/_pyg/_pyg_decoders.py +++ b/autogl/module/model/decoders/_pyg/_pyg_decoders.py @@ -274,3 +274,24 @@ class DiffPoolDecoderMaintainer(base_decoder.BaseAutoDecoderMaintainer): "dropout": 0.5, "act": "relu" } + +class _DotProductLinkPredictonDecoder(torch.nn.Module): + def forward(self, + features: _typing.Sequence[torch.Tensor], + graph: torch_geometric.data.Data, + pos_edge: torch.Tensor, + neg_edge: torch.Tensor, + **__kwargs + ): + z = features[-1] + edge_index = torch.cat([pos_edge, neg_edge], dim=-1) + logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) + return logits + +@decoder_registry.DecoderUniversalRegistry.register_decoder('lpdecoder'.lower()) +@decoder_registry.DecoderUniversalRegistry.register_decoder('dotproduct'.lower()) +@decoder_registry.DecoderUniversalRegistry.register_decoder('lp-decoder'.lower()) +@decoder_registry.DecoderUniversalRegistry.register_decoder('dot-product'.lower()) +class DotProductLinkPredictonDecoderMaintainer(base_decoder.BaseAutoDecoderMaintainer): + def _initialize(self): + self._decoder = _DotProductLinkPredictonDecoder() diff --git a/autogl/module/model/encoders/__init__.py b/autogl/module/model/encoders/__init__.py index b3a1202..489d403 100644 --- a/autogl/module/model/encoders/__init__.py +++ b/autogl/module/model/encoders/__init__.py @@ -1,10 +1,8 @@ from .base_encoder import BaseAutoEncoderMaintainer, AutoHomogeneousEncoderMaintainer - -# TODO: implement this -EncoderRegistry = None +from .encoder_registry import EncoderUniversalRegistry __all__ = [ "BaseAutoEncoderMaintainer", - "EncoderRegistry", + "EncoderUniversalRegistry", "AutoHomogeneousEncoderMaintainer", ] diff --git a/autogl/module/train/base.py b/autogl/module/train/base.py index 47c2465..553fcf1 100644 --- a/autogl/module/train/base.py +++ b/autogl/module/train/base.py @@ -7,8 +7,8 @@ import pickle from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer from ..model import ( - EncoderRegistry, - DecoderRegistry, + EncoderUniversalRegistry, + DecoderUniversalRegistry, BaseAutoEncoderMaintainer, BaseAutoDecoderMaintainer, BaseAutoModel @@ -330,7 +330,7 @@ class _BaseClassificationTrainer(BaseTrainer): @encoder.setter def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): if isinstance(enc, str): - self._encoder = EncoderRegistry.get_model(enc)( + self._encoder = EncoderUniversalRegistry.get_encoder(enc)( self.num_features, last_dim=self.last_dim, device=self.device, init=self.initialized ) elif isinstance(enc, BaseAutoEncoderMaintainer): @@ -356,7 +356,7 @@ class _BaseClassificationTrainer(BaseTrainer): self._decoder = None return if isinstance(dec, str): - self._decoder = DecoderRegistry.get_model(dec)( + self._decoder = DecoderUniversalRegistry.get_decoder(dec)( self.num_classes, input_dim=self.last_dim, device=self.device, init=self.initialized ) elif isinstance(dec, BaseAutoDecoderMaintainer): @@ -458,7 +458,7 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): @encoder.setter def encoder(self, enc: _typing.Union[BaseAutoModel, BaseAutoEncoderMaintainer, str, None]): if isinstance(enc, str): - self._encoder = EncoderRegistry.get_model(enc)( + self._encoder = EncoderUniversalRegistry.get_encoder(enc)( self.num_features, last_dim=self.last_dim, num_graph_features=self.num_graph_features, @@ -487,7 +487,7 @@ class BaseGraphClassificationTrainer(_BaseClassificationTrainer): self._decoder = None return if isinstance(dec, str): - self._decoder = DecoderRegistry.get_model(dec)( + self._decoder = DecoderUniversalRegistry.get_decoder(dec)( self.num_classes, input_dim=self.last_dim, num_graph_features=self.num_graph_features, @@ -545,7 +545,7 @@ class BaseLinkPredictionTrainer(_BaseClassificationTrainer): self._decoder = None return if isinstance(dec, str): - self._decoder = DecoderRegistry.get_model(dec)( + self._decoder = DecoderUniversalRegistry.get_decoder(dec)( input_dim=self.last_dim, device=self.device, init=self.initialized diff --git a/autogl/module/train/link_prediction.py b/autogl/module/train/link_prediction.py deleted file mode 100644 index b0b12e6..0000000 --- a/autogl/module/train/link_prediction.py +++ /dev/null @@ -1,528 +0,0 @@ -from . import register_trainer, Evaluation -import torch -from torch.optim.lr_scheduler import StepLR -import torch.nn.functional as F -from ..model import MODEL_DICT, BaseAutoModel -from .evaluation import Auc, EVALUATE_DICT -from .base import EarlyStopping, BaseLinkPredictionTrainer -from typing import Union -from copy import deepcopy -from torch_geometric.utils import negative_sampling - -from ...utils import get_logger - -LOGGER = get_logger("link prediction trainer") - - -def get_feval(feval): - if isinstance(feval, str): - return EVALUATE_DICT[feval] - if isinstance(feval, type) and issubclass(feval, Evaluation): - return feval - if isinstance(feval, list): - return [get_feval(f) for f in feval] - raise ValueError("feval argument of type", type(feval), "is not supported!") - - -@register_trainer("LinkPredictionFull") -class LinkPredictionTrainer(BaseLinkPredictionTrainer): - """ - The link prediction trainer. - - Used to automatically train the link prediction problem. - - Parameters - ---------- - model: ``BaseAutoModel`` or ``str`` - The (name of) model used to train and predict. - - optimizer: ``Optimizer`` of ``str`` - The (name of) optimizer used to train and predict. - - lr: ``float`` - The learning rate of link prediction task. - - max_epoch: ``int`` - The max number of epochs in training. - - early_stopping_round: ``int`` - The round of early stop. - - device: ``torch.device`` or ``str`` - The device where model will be running on. - - init: ``bool`` - If True(False), the model will (not) be initialized. - """ - - space = None - - def __init__( - self, - model: Union[BaseAutoModel, str] = None, - num_features=None, - optimizer=None, - lr=1e-4, - max_epoch=100, - early_stopping_round=101, - weight_decay=1e-4, - device="auto", - init=True, - feval=[Auc], - loss="binary_cross_entropy_with_logits", - *args, - **kwargs, - ): - super().__init__(model, num_features, device, init, feval, loss) - - if type(optimizer) == str and optimizer.lower() == "adam": - self.optimizer = torch.optim.Adam - elif type(optimizer) == str and optimizer.lower() == "sgd": - self.optimizer = torch.optim.SGD - else: - self.optimizer = torch.optim.Adam - - self.lr = lr - self.max_epoch = max_epoch - self.early_stopping_round = early_stopping_round - self.device = device - self.args = args - self.kwargs = kwargs - self.weight_decay = weight_decay - - self.early_stopping = EarlyStopping( - patience=early_stopping_round, verbose=False - ) - - self.valid_result = None - self.valid_result_prob = None - self.valid_score = None - - self.initialized = False - self.device = device - - self.space = [ - { - "parameterName": "max_epoch", - "type": "INTEGER", - "maxValue": 500, - "minValue": 10, - "scalingType": "LINEAR", - }, - { - "parameterName": "early_stopping_round", - "type": "INTEGER", - "maxValue": 30, - "minValue": 10, - "scalingType": "LINEAR", - }, - { - "parameterName": "lr", - "type": "DOUBLE", - "maxValue": 1e-1, - "minValue": 1e-4, - "scalingType": "LOG", - }, - { - "parameterName": "weight_decay", - "type": "DOUBLE", - "maxValue": 1e-2, - "minValue": 1e-4, - "scalingType": "LOG", - }, - ] - - LinkPredictionTrainer.space = self.space - - self.hyperparams = { - "max_epoch": self.max_epoch, - "early_stopping_round": self.early_stopping_round, - "lr": self.lr, - "weight_decay": self.weight_decay, - } - - if init is True: - self.initialize() - - def initialize(self): - # Initialize the auto model in trainer. - if self.initialized is True: - return - self.initialized = True - self.model.set_num_classes(self.num_classes) - self.model.set_num_features(self.num_features) - self.model.initialize() - - def get_model(self): - # Get auto model used in trainer. - return self.model - - @classmethod - def get_task_name(cls): - # Get task name, i.e., `LinkPrediction`. - return "LinkPrediction" - - def train_only(self, data, train_mask=None): - """ - The function of training on the given dataset and mask. - - Parameters - ---------- - data: The link prediction dataset used to be trained. It should consist of masks, including train_mask, and etc. - train_mask: The mask used in training stage. - - Returns - ------- - self: ``autogl.train.LinkPredictionTrainer`` - A reference of current trainer. - - """ - - # data.train_mask = data.val_mask = data.test_mask = data.y = None - # data = train_test_split_edges(data) - data = data.to(self.device) - # mask = data.train_mask if train_mask is None else train_mask - optimizer = self.optimizer( - self.model.model.parameters(), lr=self.lr, weight_decay=self.weight_decay - ) - scheduler = StepLR(optimizer, step_size=100, gamma=0.1) - for epoch in range(1, self.max_epoch): - self.model.model.train() - - neg_edge_index = negative_sampling( - edge_index=data.train_pos_edge_index, - num_nodes=data.num_nodes, - num_neg_samples=data.train_pos_edge_index.size(1), - ) - - optimizer.zero_grad() - # res = self.model.model.forward(data) - z = self.model.model.lp_encode(data) - link_logits = self.model.model.lp_decode( - z, data.train_pos_edge_index, neg_edge_index - ) - link_labels = self.get_link_labels( - data.train_pos_edge_index, neg_edge_index - ) - # loss = F.binary_cross_entropy_with_logits(link_logits, link_labels) - if hasattr(F, self.loss): - loss = getattr(F, self.loss)(link_logits, link_labels) - else: - raise TypeError( - "PyTorch does not support loss type {}".format(self.loss) - ) - - loss.backward() - optimizer.step() - scheduler.step() - - if type(self.feval) is list: - feval = self.feval[0] - else: - feval = self.feval - val_loss = self.evaluate([data], mask="val", feval=feval) - if feval.is_higher_better() is True: - val_loss = -val_loss - self.early_stopping(val_loss, self.model.model) - if self.early_stopping.early_stop: - LOGGER.debug("Early stopping at %d", epoch) - break - self.early_stopping.load_checkpoint(self.model.model) - - def predict_only(self, data, test_mask=None): - """ - The function of predicting on the given dataset and mask. - - Parameters - ---------- - data: The link prediction dataset used to be predicted. - train_mask: The mask used in training stage. - - Returns - ------- - res: The result of predicting on the given dataset. - - """ - try: - mask = data.test_mask if test_mask is None else test_mask - except: - mask = None - data = data.to(self.device) - self.model.model.eval() - with torch.no_grad(): - res = self.model.model.lp_encode(data) - - if mask is None: - return res - else: - return res[mask] - - def train(self, dataset, keep_valid_result=True): - """ - The function of training on the given dataset and keeping valid result. - - Parameters - ---------- - dataset: The link prediction dataset used to be trained. - - keep_valid_result: ``bool`` - If True(False), save the validation result after training. - - Returns - ------- - self: ``autogl.train.LinkPredictionTrainer`` - A reference of current trainer. - - """ - data = dataset[0] - data.edge_index = data.train_pos_edge_index - self.train_only(data) - if keep_valid_result: - self.valid_result = self.predict_only(data) - self.valid_result_prob = self.predict_proba(dataset, "val") - self.valid_score = self.evaluate(dataset, mask="val", feval=self.feval) - - def predict(self, dataset, mask=None): - """ - The function of predicting on the given dataset. - - Parameters - ---------- - dataset: The link prediction dataset used to be predicted. - - mask: ``train``, ``val``, or ``test``. - The dataset mask. - - Returns - ------- - The prediction result of ``predict_proba``. - """ - return self.predict_proba(dataset, mask=mask, in_log_format=False) - - def predict_proba(self, dataset, mask=None, in_log_format=False): - """ - The function of predicting the probability on the given dataset. - - Parameters - ---------- - dataset: The link prediction dataset used to be predicted. - - mask: ``train``, ``val``, or ``test``. - The dataset mask. - - in_log_format: ``bool``. - If True(False), the probability will (not) be log format. - - Returns - ------- - The prediction result. - """ - data = dataset[0] - data.edge_index = data.train_pos_edge_index - data = data.to(self.device) - if mask in ["train", "val", "test"]: - pos_edge_index = data[f"{mask}_pos_edge_index"] - neg_edge_index = data[f"{mask}_neg_edge_index"] - else: - pos_edge_index = data[f"test_pos_edge_index"] - neg_edge_index = data[f"test_neg_edge_index"] - - self.model.model.eval() - with torch.no_grad(): - z = self.predict_only(data) - link_logits = self.model.model.lp_decode(z, pos_edge_index, neg_edge_index) - link_probs = link_logits.sigmoid() - - return link_probs - - def get_valid_predict(self): - # """Get the valid result.""" - return self.valid_result - - def get_valid_predict_proba(self): - # """Get the valid result (prediction probability).""" - return self.valid_result_prob - - def get_valid_score(self, return_major=True): - """ - The function of getting the valid score. - - Parameters - ---------- - return_major: ``bool``. - If True, the return only consists of the major result. - If False, the return consists of the all results. - - Returns - ------- - result: The valid score in training stage. - """ - if isinstance(self.feval, list): - if return_major: - return self.valid_score[0], self.feval[0].is_higher_better() - else: - return self.valid_score, [f.is_higher_better() for f in self.feval] - else: - return self.valid_score, self.feval.is_higher_better() - - def get_name_with_hp(self): - # """Get the name of hyperparameter.""" - name = "-".join( - [ - str(self.optimizer), - str(self.lr), - str(self.max_epoch), - str(self.early_stopping_round), - str(self.model), - str(self.device), - ] - ) - name = ( - name - + "|" - + "-".join( - [ - str(x[0]) + "-" + str(x[1]) - for x in self.model.get_hyper_parameter().items() - ] - ) - ) - return name - - def evaluate(self, dataset, mask=None, feval=None): - """ - The function of training on the given dataset and keeping valid result. - - Parameters - ---------- - dataset: The link prediction dataset used to be evaluated. - - mask: ``train``, ``val``, or ``test``. - The dataset mask. - - feval: ``str``. - The evaluation method used in this function. - - Returns - ------- - res: The evaluation result on the given dataset. - - """ - data = dataset[0] - data = data.to(self.device) - test_mask = mask - if feval is None: - feval = self.feval - else: - feval = get_feval(feval) - - if mask in ["train", "val", "test"]: - pos_edge_index = data[f"{mask}_pos_edge_index"] - neg_edge_index = data[f"{mask}_neg_edge_index"] - else: - pos_edge_index = data[f"test_pos_edge_index"] - neg_edge_index = data[f"test_neg_edge_index"] - - self.model.model.eval() - with torch.no_grad(): - link_probs = self.predict_proba(dataset, mask) - link_labels = self.get_link_labels(pos_edge_index, neg_edge_index) - - if not isinstance(feval, list): - feval = [feval] - return_signle = True - else: - return_signle = False - - res = [] - for f in feval: - res.append(f.evaluate(link_probs.cpu().numpy(), link_labels.cpu().numpy())) - if return_signle: - return res[0] - return res - - def to(self, new_device): - assert isinstance(new_device, torch.device) - self.device = new_device - if self.model is not None: - self.model.to(self.device) - - def duplicate_from_hyper_parameter(self, hp: dict, model=None, restricted=True): - """ - The function of duplicating a new instance from the given hyperparameter. - - Parameters - ---------- - hp: ``dict``. - The hyperparameter used in the new instance. - - model: The model used in the new instance of trainer. - - restricted: ``bool``. - If False(True), the hyperparameter should (not) be updated from origin hyperparameter. - - Returns - ------- - self: ``autogl.train.LinkPredictionTrainer`` - A new instance of trainer. - - """ - if not restricted: - origin_hp = deepcopy(self.hyperparams) - origin_hp.update(hp) - hp = origin_hp - if model is None: - model = self.model - model.set_num_classes(self.num_classes) - model.set_num_features(self.num_features) - model = model.from_hyper_parameter( - dict( - [ - x - for x in hp.items() - if x[0] in [y["parameterName"] for y in model.space] - ] - ) - ) - - ret = self.__class__( - model=model, - num_features=self.num_features, - optimizer=self.optimizer, - lr=hp["lr"], - max_epoch=hp["max_epoch"], - early_stopping_round=hp["early_stopping_round"], - device=self.device, - weight_decay=hp["weight_decay"], - feval=self.feval, - init=True, - *self.args, - **self.kwargs, - ) - - return ret - - def set_feval(self, feval): - # """Set the evaluation metrics.""" - self.feval = get_feval(feval) - - @property - def hyper_parameter_space(self): - # """Get the space of hyperparameter.""" - return self.space - - @hyper_parameter_space.setter - def hyper_parameter_space(self, space): - # """Set the space of hyperparameter.""" - self.space = space - LinkPredictionTrainer.space = space - - def get_hyper_parameter(self): - # """Get the hyperparameter in this trainer.""" - return self.hyperparams - - def get_link_labels(self, pos_edge_index, neg_edge_index): - E = pos_edge_index.size(1) + neg_edge_index.size(1) - link_labels = torch.zeros(E, dtype=torch.float, device=self.device) - link_labels[: pos_edge_index.size(1)] = 1.0 - return link_labels diff --git a/autogl/module/train/link_prediction_full.py b/autogl/module/train/link_prediction_full.py index a8662fe..3a979e9 100644 --- a/autogl/module/train/link_prediction_full.py +++ b/autogl/module/train/link_prediction_full.py @@ -100,7 +100,7 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer): elif isinstance(model, BaseAutoModel): encoder, decoder = model, None else: - encoder, decoder = model, "LPDecoder" + encoder, decoder = model, "lp-decoder" super().__init__(encoder, decoder, num_features, "auto", device, feval, loss) self.opt_received = optimizer diff --git a/test/trainer/pyg/link_prediction.py b/test/trainer/pyg/link_prediction.py index a450d6e..a087767 100644 --- a/test/trainer/pyg/link_prediction.py +++ b/test/trainer/pyg/link_prediction.py @@ -1,21 +1,10 @@ import torch -from autogl.module.model.encoders._pyg._gcn import GCNEncoderMaintainer from autogl.module.model import BaseAutoDecoderMaintainer from autogl.module.train import LinkPredictionTrainer from autogl.datasets import build_dataset_from_name from autogl.datasets.utils.conversion._to_pyg_dataset import general_static_graphs_to_pyg_dataset from torch_geometric.utils import train_test_split_edges -class LPDecoder(torch.nn.Module): - def forward(self, features, graph, pos_edge, neg_edge): - z = features[-1] - edge_index = torch.cat([pos_edge, neg_edge], dim=-1) - logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) - return logits - -class LPDecoderMaintainer(BaseAutoDecoderMaintainer): - def _initialize(self): - self._decoder = LPDecoder() def test_lp_trainer(): @@ -25,9 +14,7 @@ def test_lp_trainer(): data = train_test_split_edges(data, 0.1, 0.1) dataset = [data] - lp_trainer = LinkPredictionTrainer( - model=(GCNEncoderMaintainer(), LPDecoderMaintainer()), init=False - ) + lp_trainer = LinkPredictionTrainer(model='gcn', init=False) lp_trainer.num_features = data.x.size(1) lp_trainer.initialize()