diff --git a/autogl/module/train/__init__.py b/autogl/module/train/__init__.py index 87fa030..93cbbf0 100644 --- a/autogl/module/train/__init__.py +++ b/autogl/module/train/__init__.py @@ -1,8 +1,14 @@ import importlib import os -from .base import BaseTrainer, Evaluation, EarlyStopping TRAINER_DICT = {} +EVALUATE_DICT = {} +from .base import ( + BaseTrainer, + Evaluation, + BaseNodeClassificationTrainer, + BaseGraphClassificationTrainer, +) def register_trainer(name): @@ -19,9 +25,6 @@ def register_trainer(name): return register_trainer_cls -EVALUATE_DICT = {} - - def register_evaluate(*name): def register_evaluate_cls(cls): for n in name: @@ -47,14 +50,16 @@ def get_feval(feval): raise ValueError("feval argument of type", type(feval), "is not supported!") -from .graph_classification import GraphClassificationTrainer -from .node_classification import NodeClassificationTrainer +from .graph_classification_full import GraphClassificationFullTrainer +from .node_classification_full import NodeClassificationFullTrainer from .evaluate import Acc, Auc, Logloss __all__ = [ "BaseTrainer", - "GraphClassificationTrainer", - "NodeClassificationTrainer", + "BaseNodeClassificationTrainer", + "BaseGraphClassificationTrainer", + "GraphClassificationFullTrainer", + "NodeClassificationFullTrainer", "Evaluation", "Acc", "Auc", diff --git a/autogl/module/train/base.py b/autogl/module/train/base.py index b765d80..b0ad872 100644 --- a/autogl/module/train/base.py +++ b/autogl/module/train/base.py @@ -1,12 +1,25 @@ import numpy as np from typing import Union, Iterable -from ..model import BaseModel + +import torch +from ..model import BaseModel, MODEL_DICT import pickle from ...utils import get_logger +from . import EVALUATE_DICT LOGGER_ES = get_logger("early-stopping") +def get_feval(feval): + if isinstance(feval, str): + return EVALUATE_DICT[feval] + if isinstance(feval, type) and issubclass(feval, Evaluation): + return feval + if isinstance(feval, list): + return [get_feval(f) for f in feval] + raise ValueError("feval argument of type", type(feval), "is not supported!") + + class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" @@ -81,17 +94,11 @@ class EarlyStopping: class BaseTrainer: def __init__( self, - model: Union[BaseModel, str], - optimizer=None, - lr=None, - max_epoch=None, - early_stopping_round=None, - device=None, + model: BaseModel, + device: Union[torch.device, str], init=True, feval=["acc"], loss="nll_loss", - *args, - **kwargs, ): """ The basic trainer. @@ -103,29 +110,26 @@ class BaseTrainer: model: `BaseModel` or `str` The (name of) model used to train and predict. - optimizer: `Optimizer` of `str` - The (name of) optimizer used to train and predict. - - lr: `float` - The learning rate. - - max_epoch: `int` - The max number of epochs in training. - - early_stopping_round: `int` - The round of early stop. - - device: `torch.device` or `str` - The device where model will be running on. - init: `bool` If True(False), the model will (not) be initialized. + """ + super().__init__() + self.model = model + self.to(device) + self.init = init + self.feval = get_feval(feval) + self.loss = loss - args: Other parameters. + def to(self, device): + """ + Migrate trainer to new device - kwargs: Other parameters. + Parameters + ---------- + device: `str` or `torch.device` + The device this trainer will use """ - super().__init__() + self.device = torch.device(device) def initialize(self): """Initialize the auto model in trainer.""" @@ -169,8 +173,8 @@ class BaseTrainer: @classmethod def load(cls, path): - with open(path, "rb") as input: - instance = pickle.load(input) + with open(path, "rb") as inputs: + instance = pickle.load(inputs) return instance @property @@ -279,7 +283,21 @@ class BaseTrainer: def set_feval(self, feval): """Set the evaluation metrics.""" - raise NotImplementedError() + self.feval = get_feval(feval) + + def update_parameters(self, **kwargs): + """ + Update parameters of this trainer + """ + for k, v in kwargs.items(): + if k == "feval": + self.set_feval(v) + elif k == "device": + self.to(v) + elif hasattr(self, k): + setattr(self, k, v) + else: + raise KeyError("Cannot set parameter", k, "for trainer", self.__class__) # a static class for evaluating results @@ -296,7 +314,7 @@ class Evaluation: """ Should return whether this evaluation method is higher better (bool) """ - raise True + return True @staticmethod def evaluate(predict, label): @@ -304,3 +322,84 @@ class Evaluation: Should return: the evaluation result (float) """ raise NotImplementedError() + + +class BaseNodeClassificationTrainer(BaseTrainer): + def __init__( + self, + model: Union[BaseModel, str], + num_features, + num_classes, + device="auto", + init=True, + feval=["acc"], + loss="nll_loss", + ): + self.num_features = num_features + self.num_classes = num_classes + device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == "auto" + else torch.device(device) + ) + if isinstance(model, str): + assert model in MODEL_DICT, "Cannot parse model name " + model + self.model = MODEL_DICT[model](num_features, num_classes, device, init=init) + elif isinstance(model, BaseModel): + self.model = model + else: + raise TypeError( + "Model argument only support str or BaseModel, get", + type(model), + "instead.", + ) + super().__init__(model, device=device, init=init, feval=feval, loss=loss) + + @classmethod + def get_task_name(cls): + return "GraphClassification" + + +class BaseGraphClassificationTrainer(BaseTrainer): + def __init__( + self, + model: Union[BaseModel, str], + num_features, + num_classes, + num_graph_features=0, + device=None, + init=True, + feval=["acc"], + loss="nll_loss", + ): + self.num_features = num_features + self.num_classes = num_classes + self.num_graph_features = num_graph_features + device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device == "auto" + else torch.device(device) + ) + if isinstance(model, str): + assert model in MODEL_DICT, "Cannot parse model name " + model + self.model = MODEL_DICT[model]( + num_features, + num_classes, + device, + init=init, + num_graph_features=num_graph_features, + ) + elif isinstance(model, BaseModel): + self.model = model + else: + raise TypeError( + "Model argument only support str or BaseModel, get", + type(model), + "instead.", + ) + + super().__init__(model, device=device, init=init, feval=feval, loss=loss) + + @classmethod + def get_task_name(cls): + return "NodeClassification" diff --git a/autogl/module/train/graph_classification.py b/autogl/module/train/graph_classification_full.py similarity index 92% rename from autogl/module/train/graph_classification.py rename to autogl/module/train/graph_classification_full.py index 6f0e3bb..ac3a05d 100644 --- a/autogl/module/train/graph_classification.py +++ b/autogl/module/train/graph_classification_full.py @@ -1,4 +1,5 @@ -from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping +from . import register_trainer, EVALUATE_DICT +from .base import BaseGraphClassificationTrainer, EarlyStopping, Evaluation import torch from torch.optim.lr_scheduler import ( StepLR, @@ -28,8 +29,8 @@ def get_feval(feval): raise ValueError("feval argument of type", type(feval), "is not supported!") -@register_trainer("GraphClassification") -class GraphClassificationTrainer(BaseTrainer): +@register_trainer("GraphClassificationFull") +class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): """ The graph classification trainer. @@ -73,7 +74,7 @@ class GraphClassificationTrainer(BaseTrainer): batch_size=None, early_stopping_round=7, weight_decay=1e-4, - device=None, + device="auto", init=True, feval=[Logloss], loss="nll_loss", @@ -81,22 +82,16 @@ class GraphClassificationTrainer(BaseTrainer): *args, **kwargs ): - super(GraphClassificationTrainer, self).__init__(model) - - self.loss_type = loss - - # init model - if isinstance(model, str): - assert model in MODEL_DICT, "Cannot parse model name " + model - self.model = MODEL_DICT[model]( - num_features, - num_classes, - device, - init=init, - num_graph_features=num_graph_features, - ) - elif isinstance(model, BaseModel): - self.model = model + super().__init__( + model, + num_features, + num_classes, + num_graph_features=num_graph_features, + device=device, + init=init, + feval=feval, + loss=loss, + ) self.opt_received = optimizer if type(optimizer) == str and optimizer.lower() == "adam": @@ -108,9 +103,6 @@ class GraphClassificationTrainer(BaseTrainer): self.lr_scheduler_type = lr_scheduler_type - self.num_features = num_features - self.num_classes = num_classes - self.num_graph_features = num_graph_features self.lr = lr if lr is not None else 1e-4 self.max_epoch = max_epoch if max_epoch is not None else 100 self.batch_size = batch_size if batch_size is not None else 64 @@ -135,8 +127,6 @@ class GraphClassificationTrainer(BaseTrainer): self.valid_score = None self.initialized = False - self.num_features = num_features - self.num_classes = num_classes self.device = device self.space = [ @@ -176,8 +166,6 @@ class GraphClassificationTrainer(BaseTrainer): "scalingType": "LOG", }, ] - # self.space += self.model.space - GraphClassificationTrainer.space = self.space self.hyperparams = { "max_epoch": self.max_epoch, @@ -186,7 +174,6 @@ class GraphClassificationTrainer(BaseTrainer): "lr": self.lr, "weight_decay": self.weight_decay, } - self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()} if init is True: self.initialize() @@ -207,9 +194,9 @@ class GraphClassificationTrainer(BaseTrainer): # """Get task name, i.e., `GraphClassification`.""" return "GraphClassification" - def to(self, new_device): - assert isinstance(new_device, torch.device) - self.device = new_device + def to(self, device): + assert isinstance(device, torch.device) + self.device = device if self.model is not None: self.model.to(self.device) @@ -255,11 +242,11 @@ class GraphClassificationTrainer(BaseTrainer): optimizer.zero_grad() output = self.model.model(data) # loss = F.nll_loss(output, data.y) - if hasattr(F, self.loss_type): - loss = getattr(F, self.loss_type)(output, data.y) + if hasattr(F, self.loss): + loss = getattr(F, self.loss)(output, data.y) else: raise TypeError( - "PyTorch does not support loss type {}".format(self.loss_type) + "PyTorch does not support loss type {}".format(self.loss) ) loss.backward() loss_all += data.num_graphs * loss.item() @@ -569,7 +556,7 @@ class GraphClassificationTrainer(BaseTrainer): weight_decay=hp["weight_decay"], device=self.device, feval=self.feval, - loss=self.loss_type, + loss=self.loss, lr_scheduler_type=self.lr_scheduler_type, init=True, *self.args, @@ -591,7 +578,6 @@ class GraphClassificationTrainer(BaseTrainer): def hyper_parameter_space(self, space): # """Set the space of hyperparameter.""" self.space = space - GraphClassificationTrainer.space = space def get_hyper_parameter(self): # """Get the hyperparameter in this trainer.""" diff --git a/autogl/module/train/node_classification.py b/autogl/module/train/node_classification_full.py similarity index 93% rename from autogl/module/train/node_classification.py rename to autogl/module/train/node_classification_full.py index a4994fc..361a391 100644 --- a/autogl/module/train/node_classification.py +++ b/autogl/module/train/node_classification_full.py @@ -1,4 +1,10 @@ -from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping +""" +Node classification Full Trainer Implementation +""" + +from . import register_trainer, EVALUATE_DICT + +from .base import BaseNodeClassificationTrainer, EarlyStopping, Evaluation import torch from torch.optim.lr_scheduler import ( StepLR, @@ -27,8 +33,8 @@ def get_feval(feval): raise ValueError("feval argument of type", type(feval), "is not supported!") -@register_trainer("NodeClassification") -class NodeClassificationTrainer(BaseTrainer): +@register_trainer("NodeClassificationFull") +class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): """ The node classification trainer. @@ -58,8 +64,6 @@ class NodeClassificationTrainer(BaseTrainer): If True(False), the model will (not) be initialized. """ - space = None - def __init__( self, model: Union[BaseModel, str], @@ -70,7 +74,7 @@ class NodeClassificationTrainer(BaseTrainer): max_epoch=None, early_stopping_round=None, weight_decay=1e-4, - device=None, + device="auto", init=True, feval=[Logloss], loss="nll_loss", @@ -78,12 +82,15 @@ class NodeClassificationTrainer(BaseTrainer): *args, **kwargs ): - super(NodeClassificationTrainer, self).__init__(model) - - self.loss_type = loss - - if device is None: - device = "cpu" + super().__init__( + model, + num_features, + num_classes, + device=device, + init=init, + feval=feval, + loss=loss, + ) # init model if isinstance(model, str): @@ -102,14 +109,11 @@ class NodeClassificationTrainer(BaseTrainer): self.lr_scheduler_type = lr_scheduler_type - self.num_features = num_features - self.num_classes = num_classes self.lr = lr if lr is not None else 1e-4 self.max_epoch = max_epoch if max_epoch is not None else 100 self.early_stopping_round = ( early_stopping_round if early_stopping_round is not None else 100 ) - self.device = device self.args = args self.kwargs = kwargs @@ -126,9 +130,6 @@ class NodeClassificationTrainer(BaseTrainer): self.valid_score = None self.initialized = False - self.num_features = num_features - self.num_classes = num_classes - self.device = device self.space = [ { @@ -160,8 +161,6 @@ class NodeClassificationTrainer(BaseTrainer): "scalingType": "LOG", }, ] - # self.space += self.model.space - NodeClassificationTrainer.space = self.space self.hyperparams = { "max_epoch": self.max_epoch, @@ -169,7 +168,6 @@ class NodeClassificationTrainer(BaseTrainer): "lr": self.lr, "weight_decay": self.weight_decay, } - self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()} if init is True: self.initialize() @@ -229,11 +227,11 @@ class NodeClassificationTrainer(BaseTrainer): self.model.model.train() optimizer.zero_grad() res = self.model.model.forward(data) - if hasattr(F, self.loss_type): - loss = getattr(F, self.loss_type)(res[mask], data.y[mask]) + if hasattr(F, self.loss): + loss = getattr(F, self.loss)(res[mask], data.y[mask]) else: raise TypeError( - "PyTorch does not support loss type {}".format(self.loss_type) + "PyTorch does not support loss type {}".format(self.loss) ) loss.backward() @@ -241,7 +239,7 @@ class NodeClassificationTrainer(BaseTrainer): if self.lr_scheduler_type: scheduler.step() - if hasattr(data, 'val_mask') and data.val_mask is not None: + if hasattr(data, "val_mask") and data.val_mask is not None: if type(self.feval) is list: feval = self.feval[0] else: @@ -253,7 +251,7 @@ class NodeClassificationTrainer(BaseTrainer): if self.early_stopping.early_stop: LOGGER.debug("Early stopping at %d", epoch) break - if hasattr(data, 'val_mask') and data.val_mask is not None: + if hasattr(data, "val_mask") and data.val_mask is not None: self.early_stopping.load_checkpoint(self.model.model) def predict_only(self, data, test_mask=None): @@ -516,7 +514,7 @@ class NodeClassificationTrainer(BaseTrainer): device=self.device, weight_decay=hp["weight_decay"], feval=self.feval, - loss=self.loss_type, + loss=self.loss, lr_scheduler_type=self.lr_scheduler_type, init=True, *self.args, @@ -538,7 +536,6 @@ class NodeClassificationTrainer(BaseTrainer): def hyper_parameter_space(self, space): # """Set the space of hyperparameter.""" self.space = space - NodeClassificationTrainer.space = space def get_hyper_parameter(self): # """Get the hyperparameter in this trainer.""" diff --git a/autogl/solver/base.py b/autogl/solver/base.py index 421306f..94e7c0a 100644 --- a/autogl/solver/base.py +++ b/autogl/solver/base.py @@ -11,7 +11,6 @@ import torch from ..module.feature import FEATURE_DICT from ..module.hpo import HPO_DICT from ..module.model import MODEL_DICT -from ..module.train import NodeClassificationTrainer from ..module import BaseFeatureAtom, BaseHPOptimizer, BaseTrainer from .utils import Leaderboard from ..utils import get_logger @@ -336,7 +335,7 @@ class BaseSolver: assert name in self.trained_models, "cannot find model by name" + name return self.trained_models[name] - def get_model_by_performance(self, index) -> Tuple[NodeClassificationTrainer, str]: + def get_model_by_performance(self, index) -> Tuple[BaseTrainer, str]: r""" Find and get the model instance by performance. @@ -347,7 +346,7 @@ class BaseSolver: Returns ------- - trainer: autogl.module.train.NodeClassificationTrainer + trainer: autogl.module.train.BaseTrainer A trainer instance containing the trained models and training status. name: str The name of current trainer. diff --git a/autogl/solver/classifier/graph_classifier.py b/autogl/solver/classifier/graph_classifier.py index bc5bcdb..82c5254 100644 --- a/autogl/solver/classifier/graph_classifier.py +++ b/autogl/solver/classifier/graph_classifier.py @@ -13,7 +13,7 @@ import yaml from .base import BaseClassifier from ...module.feature import FEATURE_DICT from ...module.model import BaseModel, MODEL_DICT -from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer +from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer from ..base import _initialize_single_model, _parse_hp_space from ..utils import Leaderboard, set_seed from ...datasets import utils @@ -90,7 +90,7 @@ class AutoGraphClassifier(BaseClassifier): hpo_module=hpo_module, ensemble_module=ensemble_module, max_evals=max_evals, - default_trainer=default_trainer or "GraphClassification", + default_trainer=default_trainer or "GraphClassificationFull", trainer_hp_space=trainer_hp_space, model_hp_spaces=model_hp_spaces, size=size, @@ -142,20 +142,22 @@ class AutoGraphClassifier(BaseClassifier): model.set_num_features(num_features) model.set_num_graph_features(num_graph_features) self.graph_model_list.append(model.to(device)) - elif isinstance(model, GraphClassificationTrainer): + elif isinstance(model, BaseGraphClassificationTrainer): # receive a trainer list, put trainer to list assert ( model.get_model() is not None ), "Passed trainer should contain a model" - model.set_feval(feval) - model.loss_type = loss - model.to(device) model.model.set_num_classes(num_classes) model.model.set_num_features(num_features) model.model.set_num_graph_features(num_graph_features) - model.num_classes = num_classes - model.num_features = num_features - model.num_graph_features = num_graph_features + model.update_parameters( + num_classes=num_classes, + num_features=num_features, + num_graph_features=num_graph_features, + loss=loss, + feval=feval, + device=device, + ) self.graph_model_list.append(model) else: raise KeyError("cannot find graph network %s." % (model)) @@ -171,7 +173,7 @@ class AutoGraphClassifier(BaseClassifier): # set model hp space if self._model_hp_spaces is not None: if self._model_hp_spaces[i] is not None: - if isinstance(model, GraphClassificationTrainer): + if isinstance(model, BaseGraphClassificationTrainer): model.model.hyper_parameter_space = self._model_hp_spaces[i] else: model.hyper_parameter_space = self._model_hp_spaces[i] @@ -770,11 +772,11 @@ class AutoGraphClassifier(BaseClassifier): ] trainer = path_or_dict.pop("trainer", None) - default_trainer = "GraphClassification" + default_trainer = "GraphClassificationFull" trainer_space = None if isinstance(trainer, dict): # global default - default_trainer = trainer.pop("name", "GraphClassification") + default_trainer = trainer.pop("name", "GraphClassificationFull") trainer_space = _parse_hp_space(trainer.pop("hp_space", None)) default_kwargs = {"num_features": None, "num_classes": None} default_kwargs.update(trainer) @@ -793,7 +795,7 @@ class AutoGraphClassifier(BaseClassifier): trainer_space = [] for i in range(len(model_list)): train, model = trainer[i], model_list[i] - default_trainer = train.pop("name", "GraphClassification") + default_trainer = train.pop("name", "GraphClassificationFull") trainer_space.append(_parse_hp_space(train.pop("hp_space", None))) default_kwargs = {"num_features": None, "num_classes": None} default_kwargs.update(train) diff --git a/autogl/solver/classifier/node_classifier.py b/autogl/solver/classifier/node_classifier.py index 4fb37a0..20b915c 100644 --- a/autogl/solver/classifier/node_classifier.py +++ b/autogl/solver/classifier/node_classifier.py @@ -14,7 +14,7 @@ from .base import BaseClassifier from ..base import _parse_hp_space, _initialize_single_model from ...module.feature import FEATURE_DICT from ...module.model import MODEL_DICT, BaseModel -from ...module.train import TRAINER_DICT, NodeClassificationTrainer +from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer from ...module.train import get_feval from ..utils import Leaderboard, set_seed from ...datasets import utils @@ -92,7 +92,7 @@ class AutoNodeClassifier(BaseClassifier): hpo_module=hpo_module, ensemble_module=ensemble_module, max_evals=max_evals, - default_trainer=default_trainer or "NodeClassification", + default_trainer=default_trainer or "NodeClassificationFull", trainer_hp_space=trainer_hp_space, model_hp_spaces=model_hp_spaces, size=size, @@ -135,18 +135,20 @@ class AutoNodeClassifier(BaseClassifier): model.set_num_classes(num_classes) model.set_num_features(num_features) self.graph_model_list.append(model.to(device)) - elif isinstance(model, NodeClassificationTrainer): + elif isinstance(model, BaseNodeClassificationTrainer): # receive a trainer list, put trainer to list assert ( model.get_model() is not None ), "Passed trainer should contain a model" - model.set_feval(feval) - model.loss_type = loss - model.to(device) model.model.set_num_classes(num_classes) model.model.set_num_features(num_features) - model.num_classes = num_classes - model.num_features = num_features + model.update_parameters( + num_classes=num_classes, + num_features=num_features, + loss=loss, + feval=feval, + device=device, + ) self.graph_model_list.append(model) else: raise KeyError("cannot find graph network %s." % (model)) @@ -162,7 +164,7 @@ class AutoNodeClassifier(BaseClassifier): # set model hp space if self._model_hp_spaces is not None: if self._model_hp_spaces[i] is not None: - if isinstance(model, NodeClassificationTrainer): + if isinstance(model, BaseNodeClassificationTrainer): model.model.hyper_parameter_space = self._model_hp_spaces[i] else: model.hyper_parameter_space = self._model_hp_spaces[i] @@ -689,11 +691,11 @@ class AutoNodeClassifier(BaseClassifier): ] trainer = path_or_dict.pop("trainer", None) - default_trainer = "NodeClassification" + default_trainer = "NodeClassificationFull" trainer_space = None if isinstance(trainer, dict): # global default - default_trainer = trainer.pop("name", "NodeClassification") + default_trainer = trainer.pop("name", "NodeClassificationFull") trainer_space = _parse_hp_space(trainer.pop("hp_space", None)) default_kwargs = {"num_features": None, "num_classes": None} default_kwargs.update(trainer) @@ -712,7 +714,7 @@ class AutoNodeClassifier(BaseClassifier): trainer_space = [] for i in range(len(model_list)): train, model = trainer[i], model_list[i] - default_trainer = train.pop("name", "NodeClassification") + default_trainer = train.pop("name", "NodeClassificationFull") trainer_space.append(_parse_hp_space(train.pop("hp_space", None))) default_kwargs = {"num_features": None, "num_classes": None} default_kwargs.update(train)