| @@ -1,8 +1,14 @@ | |||
| import importlib | |||
| import os | |||
| from .base import BaseTrainer, Evaluation, EarlyStopping | |||
| TRAINER_DICT = {} | |||
| EVALUATE_DICT = {} | |||
| from .base import ( | |||
| BaseTrainer, | |||
| Evaluation, | |||
| BaseNodeClassificationTrainer, | |||
| BaseGraphClassificationTrainer, | |||
| ) | |||
| def register_trainer(name): | |||
| @@ -19,9 +25,6 @@ def register_trainer(name): | |||
| return register_trainer_cls | |||
| EVALUATE_DICT = {} | |||
| def register_evaluate(*name): | |||
| def register_evaluate_cls(cls): | |||
| for n in name: | |||
| @@ -47,14 +50,16 @@ def get_feval(feval): | |||
| raise ValueError("feval argument of type", type(feval), "is not supported!") | |||
| from .graph_classification import GraphClassificationTrainer | |||
| from .node_classification import NodeClassificationTrainer | |||
| from .graph_classification_full import GraphClassificationFullTrainer | |||
| from .node_classification_full import NodeClassificationFullTrainer | |||
| from .evaluate import Acc, Auc, Logloss | |||
| __all__ = [ | |||
| "BaseTrainer", | |||
| "GraphClassificationTrainer", | |||
| "NodeClassificationTrainer", | |||
| "BaseNodeClassificationTrainer", | |||
| "BaseGraphClassificationTrainer", | |||
| "GraphClassificationFullTrainer", | |||
| "NodeClassificationFullTrainer", | |||
| "Evaluation", | |||
| "Acc", | |||
| "Auc", | |||
| @@ -1,12 +1,25 @@ | |||
| import numpy as np | |||
| from typing import Union, Iterable | |||
| from ..model import BaseModel | |||
| import torch | |||
| from ..model import BaseModel, MODEL_DICT | |||
| import pickle | |||
| from ...utils import get_logger | |||
| from . import EVALUATE_DICT | |||
| LOGGER_ES = get_logger("early-stopping") | |||
| def get_feval(feval): | |||
| if isinstance(feval, str): | |||
| return EVALUATE_DICT[feval] | |||
| if isinstance(feval, type) and issubclass(feval, Evaluation): | |||
| return feval | |||
| if isinstance(feval, list): | |||
| return [get_feval(f) for f in feval] | |||
| raise ValueError("feval argument of type", type(feval), "is not supported!") | |||
| class EarlyStopping: | |||
| """Early stops the training if validation loss doesn't improve after a given patience.""" | |||
| @@ -81,17 +94,11 @@ class EarlyStopping: | |||
| class BaseTrainer: | |||
| def __init__( | |||
| self, | |||
| model: Union[BaseModel, str], | |||
| optimizer=None, | |||
| lr=None, | |||
| max_epoch=None, | |||
| early_stopping_round=None, | |||
| device=None, | |||
| model: BaseModel, | |||
| device: Union[torch.device, str], | |||
| init=True, | |||
| feval=["acc"], | |||
| loss="nll_loss", | |||
| *args, | |||
| **kwargs, | |||
| ): | |||
| """ | |||
| The basic trainer. | |||
| @@ -103,29 +110,26 @@ class BaseTrainer: | |||
| model: `BaseModel` or `str` | |||
| The (name of) model used to train and predict. | |||
| optimizer: `Optimizer` of `str` | |||
| The (name of) optimizer used to train and predict. | |||
| lr: `float` | |||
| The learning rate. | |||
| max_epoch: `int` | |||
| The max number of epochs in training. | |||
| early_stopping_round: `int` | |||
| The round of early stop. | |||
| device: `torch.device` or `str` | |||
| The device where model will be running on. | |||
| init: `bool` | |||
| If True(False), the model will (not) be initialized. | |||
| """ | |||
| super().__init__() | |||
| self.model = model | |||
| self.to(device) | |||
| self.init = init | |||
| self.feval = get_feval(feval) | |||
| self.loss = loss | |||
| args: Other parameters. | |||
| def to(self, device): | |||
| """ | |||
| Migrate trainer to new device | |||
| kwargs: Other parameters. | |||
| Parameters | |||
| ---------- | |||
| device: `str` or `torch.device` | |||
| The device this trainer will use | |||
| """ | |||
| super().__init__() | |||
| self.device = torch.device(device) | |||
| def initialize(self): | |||
| """Initialize the auto model in trainer.""" | |||
| @@ -169,8 +173,8 @@ class BaseTrainer: | |||
| @classmethod | |||
| def load(cls, path): | |||
| with open(path, "rb") as input: | |||
| instance = pickle.load(input) | |||
| with open(path, "rb") as inputs: | |||
| instance = pickle.load(inputs) | |||
| return instance | |||
| @property | |||
| @@ -279,7 +283,21 @@ class BaseTrainer: | |||
| def set_feval(self, feval): | |||
| """Set the evaluation metrics.""" | |||
| raise NotImplementedError() | |||
| self.feval = get_feval(feval) | |||
| def update_parameters(self, **kwargs): | |||
| """ | |||
| Update parameters of this trainer | |||
| """ | |||
| for k, v in kwargs.items(): | |||
| if k == "feval": | |||
| self.set_feval(v) | |||
| elif k == "device": | |||
| self.to(v) | |||
| elif hasattr(self, k): | |||
| setattr(self, k, v) | |||
| else: | |||
| raise KeyError("Cannot set parameter", k, "for trainer", self.__class__) | |||
| # a static class for evaluating results | |||
| @@ -296,7 +314,7 @@ class Evaluation: | |||
| """ | |||
| Should return whether this evaluation method is higher better (bool) | |||
| """ | |||
| raise True | |||
| return True | |||
| @staticmethod | |||
| def evaluate(predict, label): | |||
| @@ -304,3 +322,84 @@ class Evaluation: | |||
| Should return: the evaluation result (float) | |||
| """ | |||
| raise NotImplementedError() | |||
| class BaseNodeClassificationTrainer(BaseTrainer): | |||
| def __init__( | |||
| self, | |||
| model: Union[BaseModel, str], | |||
| num_features, | |||
| num_classes, | |||
| device="auto", | |||
| init=True, | |||
| feval=["acc"], | |||
| loss="nll_loss", | |||
| ): | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| device = ( | |||
| torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| if device == "auto" | |||
| else torch.device(device) | |||
| ) | |||
| if isinstance(model, str): | |||
| assert model in MODEL_DICT, "Cannot parse model name " + model | |||
| self.model = MODEL_DICT[model](num_features, num_classes, device, init=init) | |||
| elif isinstance(model, BaseModel): | |||
| self.model = model | |||
| else: | |||
| raise TypeError( | |||
| "Model argument only support str or BaseModel, get", | |||
| type(model), | |||
| "instead.", | |||
| ) | |||
| super().__init__(model, device=device, init=init, feval=feval, loss=loss) | |||
| @classmethod | |||
| def get_task_name(cls): | |||
| return "GraphClassification" | |||
| class BaseGraphClassificationTrainer(BaseTrainer): | |||
| def __init__( | |||
| self, | |||
| model: Union[BaseModel, str], | |||
| num_features, | |||
| num_classes, | |||
| num_graph_features=0, | |||
| device=None, | |||
| init=True, | |||
| feval=["acc"], | |||
| loss="nll_loss", | |||
| ): | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.num_graph_features = num_graph_features | |||
| device = ( | |||
| torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| if device == "auto" | |||
| else torch.device(device) | |||
| ) | |||
| if isinstance(model, str): | |||
| assert model in MODEL_DICT, "Cannot parse model name " + model | |||
| self.model = MODEL_DICT[model]( | |||
| num_features, | |||
| num_classes, | |||
| device, | |||
| init=init, | |||
| num_graph_features=num_graph_features, | |||
| ) | |||
| elif isinstance(model, BaseModel): | |||
| self.model = model | |||
| else: | |||
| raise TypeError( | |||
| "Model argument only support str or BaseModel, get", | |||
| type(model), | |||
| "instead.", | |||
| ) | |||
| super().__init__(model, device=device, init=init, feval=feval, loss=loss) | |||
| @classmethod | |||
| def get_task_name(cls): | |||
| return "NodeClassification" | |||
| @@ -1,4 +1,5 @@ | |||
| from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping | |||
| from . import register_trainer, EVALUATE_DICT | |||
| from .base import BaseGraphClassificationTrainer, EarlyStopping, Evaluation | |||
| import torch | |||
| from torch.optim.lr_scheduler import ( | |||
| StepLR, | |||
| @@ -28,8 +29,8 @@ def get_feval(feval): | |||
| raise ValueError("feval argument of type", type(feval), "is not supported!") | |||
| @register_trainer("GraphClassification") | |||
| class GraphClassificationTrainer(BaseTrainer): | |||
| @register_trainer("GraphClassificationFull") | |||
| class GraphClassificationFullTrainer(BaseGraphClassificationTrainer): | |||
| """ | |||
| The graph classification trainer. | |||
| @@ -73,7 +74,7 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| batch_size=None, | |||
| early_stopping_round=7, | |||
| weight_decay=1e-4, | |||
| device=None, | |||
| device="auto", | |||
| init=True, | |||
| feval=[Logloss], | |||
| loss="nll_loss", | |||
| @@ -81,22 +82,16 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| super(GraphClassificationTrainer, self).__init__(model) | |||
| self.loss_type = loss | |||
| # init model | |||
| if isinstance(model, str): | |||
| assert model in MODEL_DICT, "Cannot parse model name " + model | |||
| self.model = MODEL_DICT[model]( | |||
| num_features, | |||
| num_classes, | |||
| device, | |||
| init=init, | |||
| num_graph_features=num_graph_features, | |||
| ) | |||
| elif isinstance(model, BaseModel): | |||
| self.model = model | |||
| super().__init__( | |||
| model, | |||
| num_features, | |||
| num_classes, | |||
| num_graph_features=num_graph_features, | |||
| device=device, | |||
| init=init, | |||
| feval=feval, | |||
| loss=loss, | |||
| ) | |||
| self.opt_received = optimizer | |||
| if type(optimizer) == str and optimizer.lower() == "adam": | |||
| @@ -108,9 +103,6 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| self.lr_scheduler_type = lr_scheduler_type | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.num_graph_features = num_graph_features | |||
| self.lr = lr if lr is not None else 1e-4 | |||
| self.max_epoch = max_epoch if max_epoch is not None else 100 | |||
| self.batch_size = batch_size if batch_size is not None else 64 | |||
| @@ -135,8 +127,6 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| self.valid_score = None | |||
| self.initialized = False | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.device = device | |||
| self.space = [ | |||
| @@ -176,8 +166,6 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| "scalingType": "LOG", | |||
| }, | |||
| ] | |||
| # self.space += self.model.space | |||
| GraphClassificationTrainer.space = self.space | |||
| self.hyperparams = { | |||
| "max_epoch": self.max_epoch, | |||
| @@ -186,7 +174,6 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| "lr": self.lr, | |||
| "weight_decay": self.weight_decay, | |||
| } | |||
| self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()} | |||
| if init is True: | |||
| self.initialize() | |||
| @@ -207,9 +194,9 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| # """Get task name, i.e., `GraphClassification`.""" | |||
| return "GraphClassification" | |||
| def to(self, new_device): | |||
| assert isinstance(new_device, torch.device) | |||
| self.device = new_device | |||
| def to(self, device): | |||
| assert isinstance(device, torch.device) | |||
| self.device = device | |||
| if self.model is not None: | |||
| self.model.to(self.device) | |||
| @@ -255,11 +242,11 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| optimizer.zero_grad() | |||
| output = self.model.model(data) | |||
| # loss = F.nll_loss(output, data.y) | |||
| if hasattr(F, self.loss_type): | |||
| loss = getattr(F, self.loss_type)(output, data.y) | |||
| if hasattr(F, self.loss): | |||
| loss = getattr(F, self.loss)(output, data.y) | |||
| else: | |||
| raise TypeError( | |||
| "PyTorch does not support loss type {}".format(self.loss_type) | |||
| "PyTorch does not support loss type {}".format(self.loss) | |||
| ) | |||
| loss.backward() | |||
| loss_all += data.num_graphs * loss.item() | |||
| @@ -569,7 +556,7 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| weight_decay=hp["weight_decay"], | |||
| device=self.device, | |||
| feval=self.feval, | |||
| loss=self.loss_type, | |||
| loss=self.loss, | |||
| lr_scheduler_type=self.lr_scheduler_type, | |||
| init=True, | |||
| *self.args, | |||
| @@ -591,7 +578,6 @@ class GraphClassificationTrainer(BaseTrainer): | |||
| def hyper_parameter_space(self, space): | |||
| # """Set the space of hyperparameter.""" | |||
| self.space = space | |||
| GraphClassificationTrainer.space = space | |||
| def get_hyper_parameter(self): | |||
| # """Get the hyperparameter in this trainer.""" | |||
| @@ -1,4 +1,10 @@ | |||
| from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping | |||
| """ | |||
| Node classification Full Trainer Implementation | |||
| """ | |||
| from . import register_trainer, EVALUATE_DICT | |||
| from .base import BaseNodeClassificationTrainer, EarlyStopping, Evaluation | |||
| import torch | |||
| from torch.optim.lr_scheduler import ( | |||
| StepLR, | |||
| @@ -27,8 +33,8 @@ def get_feval(feval): | |||
| raise ValueError("feval argument of type", type(feval), "is not supported!") | |||
| @register_trainer("NodeClassification") | |||
| class NodeClassificationTrainer(BaseTrainer): | |||
| @register_trainer("NodeClassificationFull") | |||
| class NodeClassificationFullTrainer(BaseNodeClassificationTrainer): | |||
| """ | |||
| The node classification trainer. | |||
| @@ -58,8 +64,6 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| If True(False), the model will (not) be initialized. | |||
| """ | |||
| space = None | |||
| def __init__( | |||
| self, | |||
| model: Union[BaseModel, str], | |||
| @@ -70,7 +74,7 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| max_epoch=None, | |||
| early_stopping_round=None, | |||
| weight_decay=1e-4, | |||
| device=None, | |||
| device="auto", | |||
| init=True, | |||
| feval=[Logloss], | |||
| loss="nll_loss", | |||
| @@ -78,12 +82,15 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| *args, | |||
| **kwargs | |||
| ): | |||
| super(NodeClassificationTrainer, self).__init__(model) | |||
| self.loss_type = loss | |||
| if device is None: | |||
| device = "cpu" | |||
| super().__init__( | |||
| model, | |||
| num_features, | |||
| num_classes, | |||
| device=device, | |||
| init=init, | |||
| feval=feval, | |||
| loss=loss, | |||
| ) | |||
| # init model | |||
| if isinstance(model, str): | |||
| @@ -102,14 +109,11 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| self.lr_scheduler_type = lr_scheduler_type | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.lr = lr if lr is not None else 1e-4 | |||
| self.max_epoch = max_epoch if max_epoch is not None else 100 | |||
| self.early_stopping_round = ( | |||
| early_stopping_round if early_stopping_round is not None else 100 | |||
| ) | |||
| self.device = device | |||
| self.args = args | |||
| self.kwargs = kwargs | |||
| @@ -126,9 +130,6 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| self.valid_score = None | |||
| self.initialized = False | |||
| self.num_features = num_features | |||
| self.num_classes = num_classes | |||
| self.device = device | |||
| self.space = [ | |||
| { | |||
| @@ -160,8 +161,6 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| "scalingType": "LOG", | |||
| }, | |||
| ] | |||
| # self.space += self.model.space | |||
| NodeClassificationTrainer.space = self.space | |||
| self.hyperparams = { | |||
| "max_epoch": self.max_epoch, | |||
| @@ -169,7 +168,6 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| "lr": self.lr, | |||
| "weight_decay": self.weight_decay, | |||
| } | |||
| self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()} | |||
| if init is True: | |||
| self.initialize() | |||
| @@ -229,11 +227,11 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| self.model.model.train() | |||
| optimizer.zero_grad() | |||
| res = self.model.model.forward(data) | |||
| if hasattr(F, self.loss_type): | |||
| loss = getattr(F, self.loss_type)(res[mask], data.y[mask]) | |||
| if hasattr(F, self.loss): | |||
| loss = getattr(F, self.loss)(res[mask], data.y[mask]) | |||
| else: | |||
| raise TypeError( | |||
| "PyTorch does not support loss type {}".format(self.loss_type) | |||
| "PyTorch does not support loss type {}".format(self.loss) | |||
| ) | |||
| loss.backward() | |||
| @@ -241,7 +239,7 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| if self.lr_scheduler_type: | |||
| scheduler.step() | |||
| if hasattr(data, 'val_mask') and data.val_mask is not None: | |||
| if hasattr(data, "val_mask") and data.val_mask is not None: | |||
| if type(self.feval) is list: | |||
| feval = self.feval[0] | |||
| else: | |||
| @@ -253,7 +251,7 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| if self.early_stopping.early_stop: | |||
| LOGGER.debug("Early stopping at %d", epoch) | |||
| break | |||
| if hasattr(data, 'val_mask') and data.val_mask is not None: | |||
| if hasattr(data, "val_mask") and data.val_mask is not None: | |||
| self.early_stopping.load_checkpoint(self.model.model) | |||
| def predict_only(self, data, test_mask=None): | |||
| @@ -516,7 +514,7 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| device=self.device, | |||
| weight_decay=hp["weight_decay"], | |||
| feval=self.feval, | |||
| loss=self.loss_type, | |||
| loss=self.loss, | |||
| lr_scheduler_type=self.lr_scheduler_type, | |||
| init=True, | |||
| *self.args, | |||
| @@ -538,7 +536,6 @@ class NodeClassificationTrainer(BaseTrainer): | |||
| def hyper_parameter_space(self, space): | |||
| # """Set the space of hyperparameter.""" | |||
| self.space = space | |||
| NodeClassificationTrainer.space = space | |||
| def get_hyper_parameter(self): | |||
| # """Get the hyperparameter in this trainer.""" | |||
| @@ -11,7 +11,6 @@ import torch | |||
| from ..module.feature import FEATURE_DICT | |||
| from ..module.hpo import HPO_DICT | |||
| from ..module.model import MODEL_DICT | |||
| from ..module.train import NodeClassificationTrainer | |||
| from ..module import BaseFeatureAtom, BaseHPOptimizer, BaseTrainer | |||
| from .utils import Leaderboard | |||
| from ..utils import get_logger | |||
| @@ -336,7 +335,7 @@ class BaseSolver: | |||
| assert name in self.trained_models, "cannot find model by name" + name | |||
| return self.trained_models[name] | |||
| def get_model_by_performance(self, index) -> Tuple[NodeClassificationTrainer, str]: | |||
| def get_model_by_performance(self, index) -> Tuple[BaseTrainer, str]: | |||
| r""" | |||
| Find and get the model instance by performance. | |||
| @@ -347,7 +346,7 @@ class BaseSolver: | |||
| Returns | |||
| ------- | |||
| trainer: autogl.module.train.NodeClassificationTrainer | |||
| trainer: autogl.module.train.BaseTrainer | |||
| A trainer instance containing the trained models and training status. | |||
| name: str | |||
| The name of current trainer. | |||
| @@ -13,7 +13,7 @@ import yaml | |||
| from .base import BaseClassifier | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.model import BaseModel, MODEL_DICT | |||
| from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer | |||
| from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer | |||
| from ..base import _initialize_single_model, _parse_hp_space | |||
| from ..utils import Leaderboard, set_seed | |||
| from ...datasets import utils | |||
| @@ -90,7 +90,7 @@ class AutoGraphClassifier(BaseClassifier): | |||
| hpo_module=hpo_module, | |||
| ensemble_module=ensemble_module, | |||
| max_evals=max_evals, | |||
| default_trainer=default_trainer or "GraphClassification", | |||
| default_trainer=default_trainer or "GraphClassificationFull", | |||
| trainer_hp_space=trainer_hp_space, | |||
| model_hp_spaces=model_hp_spaces, | |||
| size=size, | |||
| @@ -142,20 +142,22 @@ class AutoGraphClassifier(BaseClassifier): | |||
| model.set_num_features(num_features) | |||
| model.set_num_graph_features(num_graph_features) | |||
| self.graph_model_list.append(model.to(device)) | |||
| elif isinstance(model, GraphClassificationTrainer): | |||
| elif isinstance(model, BaseGraphClassificationTrainer): | |||
| # receive a trainer list, put trainer to list | |||
| assert ( | |||
| model.get_model() is not None | |||
| ), "Passed trainer should contain a model" | |||
| model.set_feval(feval) | |||
| model.loss_type = loss | |||
| model.to(device) | |||
| model.model.set_num_classes(num_classes) | |||
| model.model.set_num_features(num_features) | |||
| model.model.set_num_graph_features(num_graph_features) | |||
| model.num_classes = num_classes | |||
| model.num_features = num_features | |||
| model.num_graph_features = num_graph_features | |||
| model.update_parameters( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| num_graph_features=num_graph_features, | |||
| loss=loss, | |||
| feval=feval, | |||
| device=device, | |||
| ) | |||
| self.graph_model_list.append(model) | |||
| else: | |||
| raise KeyError("cannot find graph network %s." % (model)) | |||
| @@ -171,7 +173,7 @@ class AutoGraphClassifier(BaseClassifier): | |||
| # set model hp space | |||
| if self._model_hp_spaces is not None: | |||
| if self._model_hp_spaces[i] is not None: | |||
| if isinstance(model, GraphClassificationTrainer): | |||
| if isinstance(model, BaseGraphClassificationTrainer): | |||
| model.model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| else: | |||
| model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| @@ -770,11 +772,11 @@ class AutoGraphClassifier(BaseClassifier): | |||
| ] | |||
| trainer = path_or_dict.pop("trainer", None) | |||
| default_trainer = "GraphClassification" | |||
| default_trainer = "GraphClassificationFull" | |||
| trainer_space = None | |||
| if isinstance(trainer, dict): | |||
| # global default | |||
| default_trainer = trainer.pop("name", "GraphClassification") | |||
| default_trainer = trainer.pop("name", "GraphClassificationFull") | |||
| trainer_space = _parse_hp_space(trainer.pop("hp_space", None)) | |||
| default_kwargs = {"num_features": None, "num_classes": None} | |||
| default_kwargs.update(trainer) | |||
| @@ -793,7 +795,7 @@ class AutoGraphClassifier(BaseClassifier): | |||
| trainer_space = [] | |||
| for i in range(len(model_list)): | |||
| train, model = trainer[i], model_list[i] | |||
| default_trainer = train.pop("name", "GraphClassification") | |||
| default_trainer = train.pop("name", "GraphClassificationFull") | |||
| trainer_space.append(_parse_hp_space(train.pop("hp_space", None))) | |||
| default_kwargs = {"num_features": None, "num_classes": None} | |||
| default_kwargs.update(train) | |||
| @@ -14,7 +14,7 @@ from .base import BaseClassifier | |||
| from ..base import _parse_hp_space, _initialize_single_model | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.model import MODEL_DICT, BaseModel | |||
| from ...module.train import TRAINER_DICT, NodeClassificationTrainer | |||
| from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer | |||
| from ...module.train import get_feval | |||
| from ..utils import Leaderboard, set_seed | |||
| from ...datasets import utils | |||
| @@ -92,7 +92,7 @@ class AutoNodeClassifier(BaseClassifier): | |||
| hpo_module=hpo_module, | |||
| ensemble_module=ensemble_module, | |||
| max_evals=max_evals, | |||
| default_trainer=default_trainer or "NodeClassification", | |||
| default_trainer=default_trainer or "NodeClassificationFull", | |||
| trainer_hp_space=trainer_hp_space, | |||
| model_hp_spaces=model_hp_spaces, | |||
| size=size, | |||
| @@ -135,18 +135,20 @@ class AutoNodeClassifier(BaseClassifier): | |||
| model.set_num_classes(num_classes) | |||
| model.set_num_features(num_features) | |||
| self.graph_model_list.append(model.to(device)) | |||
| elif isinstance(model, NodeClassificationTrainer): | |||
| elif isinstance(model, BaseNodeClassificationTrainer): | |||
| # receive a trainer list, put trainer to list | |||
| assert ( | |||
| model.get_model() is not None | |||
| ), "Passed trainer should contain a model" | |||
| model.set_feval(feval) | |||
| model.loss_type = loss | |||
| model.to(device) | |||
| model.model.set_num_classes(num_classes) | |||
| model.model.set_num_features(num_features) | |||
| model.num_classes = num_classes | |||
| model.num_features = num_features | |||
| model.update_parameters( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| loss=loss, | |||
| feval=feval, | |||
| device=device, | |||
| ) | |||
| self.graph_model_list.append(model) | |||
| else: | |||
| raise KeyError("cannot find graph network %s." % (model)) | |||
| @@ -162,7 +164,7 @@ class AutoNodeClassifier(BaseClassifier): | |||
| # set model hp space | |||
| if self._model_hp_spaces is not None: | |||
| if self._model_hp_spaces[i] is not None: | |||
| if isinstance(model, NodeClassificationTrainer): | |||
| if isinstance(model, BaseNodeClassificationTrainer): | |||
| model.model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| else: | |||
| model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| @@ -689,11 +691,11 @@ class AutoNodeClassifier(BaseClassifier): | |||
| ] | |||
| trainer = path_or_dict.pop("trainer", None) | |||
| default_trainer = "NodeClassification" | |||
| default_trainer = "NodeClassificationFull" | |||
| trainer_space = None | |||
| if isinstance(trainer, dict): | |||
| # global default | |||
| default_trainer = trainer.pop("name", "NodeClassification") | |||
| default_trainer = trainer.pop("name", "NodeClassificationFull") | |||
| trainer_space = _parse_hp_space(trainer.pop("hp_space", None)) | |||
| default_kwargs = {"num_features": None, "num_classes": None} | |||
| default_kwargs.update(trainer) | |||
| @@ -712,7 +714,7 @@ class AutoNodeClassifier(BaseClassifier): | |||
| trainer_space = [] | |||
| for i in range(len(model_list)): | |||
| train, model = trainer[i], model_list[i] | |||
| default_trainer = train.pop("name", "NodeClassification") | |||
| default_trainer = train.pop("name", "NodeClassificationFull") | |||
| trainer_space.append(_parse_hp_space(train.pop("hp_space", None))) | |||
| default_kwargs = {"num_features": None, "num_classes": None} | |||
| default_kwargs.update(train) | |||