diff --git a/autogl/solver/classifier/base.py b/autogl/solver/classifier/base.py index 2bb8521..96f84a3 100644 --- a/autogl/solver/classifier/base.py +++ b/autogl/solver/classifier/base.py @@ -5,9 +5,7 @@ Base solver for classification problems from typing import Any from ..base import BaseSolver from ...module.ensemble import ENSEMBLE_DICT -from ...module.train import TRAINER_DICT -from ...module.model import MODEL_DICT -from ...module import BaseEnsembler, BaseModel, BaseTrainer +from ...module import BaseEnsembler class BaseClassifier(BaseSolver): @@ -15,93 +13,6 @@ class BaseClassifier(BaseSolver): Base solver for classification problems """ - def _init_graph_module( - self, - graph_models, - num_classes, - num_features, - *args, - **kwargs, - ) -> "BaseClassifier": - # load graph network module - self.graph_model_list = [] - if isinstance(graph_models, list): - for model in graph_models: - if isinstance(model, str): - if model in MODEL_DICT: - self.graph_model_list.append( - MODEL_DICT[model]( - num_classes=num_classes, - num_features=num_features, - *args, - **kwargs, - init=False, - ) - ) - else: - raise KeyError("cannot find model %s" % (model)) - elif isinstance(model, type) and issubclass(model, BaseModel): - self.graph_model_list.append( - model( - num_classes=num_classes, - num_features=num_features, - *args, - **kwargs, - init=False, - ) - ) - elif isinstance(model, BaseModel): - # setup the hp of num_classes and num_features - model.set_num_classes(num_classes) - model.set_num_features(num_features) - self.graph_model_list.append(model.to(self.runtime_device)) - elif isinstance(model, BaseTrainer): - # receive a trainer list, put trainer to list - self.graph_model_list.append(model) - else: - raise KeyError("cannot find graph network %s." % (model)) - else: - raise ValueError( - "need graph network to be (list of) str or a BaseModel class/instance, get", - graph_models, - "instead.", - ) - - # wrap all model_cls with specified trainer - for i, model in enumerate(self.graph_model_list): - # set model hp space - if self._model_hp_spaces is not None: - if self._model_hp_spaces[i] is not None: - if isinstance(model, BaseTrainer): - model.model.hyper_parameter_space = self._model_hp_spaces[i] - else: - model.hyper_parameter_space = self._model_hp_spaces[i] - # initialize trainer if needed - if isinstance(model, BaseModel): - name = ( - self._default_trainer - if isinstance(self._default_trainer, str) - else self._default_trainer[i] - ) - model = TRAINER_DICT[name]( - model=model, - num_features=num_features, - num_classes=num_classes, - *args, - **kwargs, - init=False, - ) - # set trainer hp space - if self._trainer_hp_space is not None: - if isinstance(self._trainer_hp_space[0], list): - current_hp_for_trainer = self._trainer_hp_space[i] - else: - current_hp_for_trainer = self._trainer_hp_space - model.hyper_parameter_space = current_hp_for_trainer - self.graph_model_list[i] = model - - return self - def predict_proba(self, *args, **kwargs) -> Any: """ Predict the node probability. diff --git a/autogl/solver/classifier/graph_classifier.py b/autogl/solver/classifier/graph_classifier.py index f7efedf..4ad0987 100644 --- a/autogl/solver/classifier/graph_classifier.py +++ b/autogl/solver/classifier/graph_classifier.py @@ -12,7 +12,8 @@ import yaml from .base import BaseClassifier from ...module.feature import FEATURE_DICT -from ...module.train import get_feval +from ...module.model import BaseModel, MODEL_DICT +from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer from ..base import _initialize_single_model, _parse_hp_space from ..utils import Leaderboard, set_seed from ...datasets import utils @@ -98,6 +99,108 @@ class AutoGraphClassifier(BaseClassifier): self.dataset = None + def _init_graph_module( + self, + graph_models, + num_classes, + num_features, + feval, + device, + loss, + num_graph_features + ) -> "AutoGraphClassifier": + # load graph network module + self.graph_model_list = [] + if isinstance(graph_models, list): + for model in graph_models: + if isinstance(model, str): + if model in MODEL_DICT: + self.graph_model_list.append( + MODEL_DICT[model]( + num_classes=num_classes, + num_features=num_features, + num_graph_features=num_graph_features, + device=device, + init=False + ) + ) + else: + raise KeyError("cannot find model %s" % (model)) + elif isinstance(model, type) and issubclass(model, BaseModel): + self.graph_model_list.append( + model( + num_classes=num_classes, + num_features=num_features, + num_graph_features=num_graph_features, + device=device, + init=False + ) + ) + elif isinstance(model, BaseModel): + # setup the hp of num_classes and num_features + model.set_num_classes(num_classes) + model.set_num_features(num_features) + model.set_num_graph_features(num_graph_features) + self.graph_model_list.append(model.to(device)) + elif isinstance(model, GraphClassificationTrainer): + # receive a trainer list, put trainer to list + assert model.get_model() is not None, "Passed trainer should contain a model" + model.set_feval(feval) + model.loss_type = loss + model.to(device) + model.model.set_num_classes(num_classes) + model.model.set_num_features(num_features) + model.model.set_num_graph_features(num_graph_features) + model.num_classes = num_classes + model.num_features = num_features + model.num_graph_features = num_graph_features + self.graph_model_list.append(model) + else: + raise KeyError("cannot find graph network %s." % (model)) + else: + raise ValueError( + "need graph network to be (list of) str or a BaseModel class/instance, get", + graph_models, + "instead.", + ) + + # wrap all model_cls with specified trainer + for i, model in enumerate(self.graph_model_list): + # set model hp space + if self._model_hp_spaces is not None: + if self._model_hp_spaces[i] is not None: + if isinstance(model, GraphClassificationTrainer): + model.model.hyper_parameter_space = self._model_hp_spaces[i] + else: + model.hyper_parameter_space = self._model_hp_spaces[i] + # initialize trainer if needed + if isinstance(model, BaseModel): + name = ( + self._default_trainer + if isinstance(self._default_trainer, str) + else self._default_trainer[i] + ) + model = TRAINER_DICT[name]( + model=model, + num_features=num_features, + num_classes=num_classes, + loss=loss, + feval=feval, + device=device, + num_graph_features=num_graph_features, + init=False + ) + # set trainer hp space + if self._trainer_hp_space is not None: + if isinstance(self._trainer_hp_space[0], list): + current_hp_for_trainer = self._trainer_hp_space[i] + else: + current_hp_for_trainer = self._trainer_hp_space + model.hyper_parameter_space = current_hp_for_trainer + self.graph_model_list[i] = model + + return self + # pylint: disable=arguments-differ def fit( self, diff --git a/autogl/solver/classifier/node_classifier.py b/autogl/solver/classifier/node_classifier.py index 25d9560..587e513 100644 --- a/autogl/solver/classifier/node_classifier.py +++ b/autogl/solver/classifier/node_classifier.py @@ -13,6 +13,8 @@ import yaml from .base import BaseClassifier from ..base import _parse_hp_space, _initialize_single_model from ...module.feature import FEATURE_DICT +from ...module.model import MODEL_DICT, BaseModel +from ...module.train import TRAINER_DICT, NodeClassificationTrainer from ...module.train import get_feval from ..utils import Leaderboard, set_seed from ...datasets import utils @@ -100,6 +102,101 @@ class AutoNodeClassifier(BaseClassifier): # data to be kept when fit self.data = None + def _init_graph_module( + self, + graph_models, + num_classes, + num_features, + feval, + device, + loss + ) -> "AutoNodeClassifier": + # load graph network module + self.graph_model_list = [] + if isinstance(graph_models, list): + for model in graph_models: + if isinstance(model, str): + if model in MODEL_DICT: + self.graph_model_list.append( + MODEL_DICT[model]( + num_classes=num_classes, + num_features=num_features, + device=device, + init=False + ) + ) + else: + raise KeyError("cannot find model %s" % (model)) + elif isinstance(model, type) and issubclass(model, BaseModel): + self.graph_model_list.append( + model( + num_classes=num_classes, + num_features=num_features, + device=device, + init=False + ) + ) + elif isinstance(model, BaseModel): + # setup the hp of num_classes and num_features + model.set_num_classes(num_classes) + model.set_num_features(num_features) + self.graph_model_list.append(model.to(device)) + elif isinstance(model, NodeClassificationTrainer): + # receive a trainer list, put trainer to list + assert model.get_model() is not None, "Passed trainer should contain a model" + model.set_feval(feval) + model.loss_type = loss + model.to(device) + model.model.set_num_classes(num_classes) + model.model.set_num_features(num_features) + model.num_classes = num_classes + model.num_features = num_features + self.graph_model_list.append(model) + else: + raise KeyError("cannot find graph network %s." % (model)) + else: + raise ValueError( + "need graph network to be (list of) str or a BaseModel class/instance, get", + graph_models, + "instead.", + ) + + # wrap all model_cls with specified trainer + for i, model in enumerate(self.graph_model_list): + # set model hp space + if self._model_hp_spaces is not None: + if self._model_hp_spaces[i] is not None: + if isinstance(model, NodeClassificationTrainer): + model.model.hyper_parameter_space = self._model_hp_spaces[i] + else: + model.hyper_parameter_space = self._model_hp_spaces[i] + # initialize trainer if needed + if isinstance(model, BaseModel): + name = ( + self._default_trainer + if isinstance(self._default_trainer, str) + else self._default_trainer[i] + ) + model = TRAINER_DICT[name]( + model=model, + num_features=num_features, + num_classes=num_classes, + loss=loss, + feval=feval, + device=device, + init=False + ) + # set trainer hp space + if self._trainer_hp_space is not None: + if isinstance(self._trainer_hp_space[0], list): + current_hp_for_trainer = self._trainer_hp_space[i] + else: + current_hp_for_trainer = self._trainer_hp_space + model.hyper_parameter_space = current_hp_for_trainer + self.graph_model_list[i] = model + + return self + # pylint: disable=arguments-differ def fit( self,