| @@ -5,9 +5,7 @@ Base solver for classification problems | |||
| from typing import Any | |||
| from ..base import BaseSolver | |||
| from ...module.ensemble import ENSEMBLE_DICT | |||
| from ...module.train import TRAINER_DICT | |||
| from ...module.model import MODEL_DICT | |||
| from ...module import BaseEnsembler, BaseModel, BaseTrainer | |||
| from ...module import BaseEnsembler | |||
| class BaseClassifier(BaseSolver): | |||
| @@ -15,93 +13,6 @@ class BaseClassifier(BaseSolver): | |||
| Base solver for classification problems | |||
| """ | |||
| def _init_graph_module( | |||
| self, | |||
| graph_models, | |||
| num_classes, | |||
| num_features, | |||
| *args, | |||
| **kwargs, | |||
| ) -> "BaseClassifier": | |||
| # load graph network module | |||
| self.graph_model_list = [] | |||
| if isinstance(graph_models, list): | |||
| for model in graph_models: | |||
| if isinstance(model, str): | |||
| if model in MODEL_DICT: | |||
| self.graph_model_list.append( | |||
| MODEL_DICT[model]( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| *args, | |||
| **kwargs, | |||
| init=False, | |||
| ) | |||
| ) | |||
| else: | |||
| raise KeyError("cannot find model %s" % (model)) | |||
| elif isinstance(model, type) and issubclass(model, BaseModel): | |||
| self.graph_model_list.append( | |||
| model( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| *args, | |||
| **kwargs, | |||
| init=False, | |||
| ) | |||
| ) | |||
| elif isinstance(model, BaseModel): | |||
| # setup the hp of num_classes and num_features | |||
| model.set_num_classes(num_classes) | |||
| model.set_num_features(num_features) | |||
| self.graph_model_list.append(model.to(self.runtime_device)) | |||
| elif isinstance(model, BaseTrainer): | |||
| # receive a trainer list, put trainer to list | |||
| self.graph_model_list.append(model) | |||
| else: | |||
| raise KeyError("cannot find graph network %s." % (model)) | |||
| else: | |||
| raise ValueError( | |||
| "need graph network to be (list of) str or a BaseModel class/instance, get", | |||
| graph_models, | |||
| "instead.", | |||
| ) | |||
| # wrap all model_cls with specified trainer | |||
| for i, model in enumerate(self.graph_model_list): | |||
| # set model hp space | |||
| if self._model_hp_spaces is not None: | |||
| if self._model_hp_spaces[i] is not None: | |||
| if isinstance(model, BaseTrainer): | |||
| model.model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| else: | |||
| model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| # initialize trainer if needed | |||
| if isinstance(model, BaseModel): | |||
| name = ( | |||
| self._default_trainer | |||
| if isinstance(self._default_trainer, str) | |||
| else self._default_trainer[i] | |||
| ) | |||
| model = TRAINER_DICT[name]( | |||
| model=model, | |||
| num_features=num_features, | |||
| num_classes=num_classes, | |||
| *args, | |||
| **kwargs, | |||
| init=False, | |||
| ) | |||
| # set trainer hp space | |||
| if self._trainer_hp_space is not None: | |||
| if isinstance(self._trainer_hp_space[0], list): | |||
| current_hp_for_trainer = self._trainer_hp_space[i] | |||
| else: | |||
| current_hp_for_trainer = self._trainer_hp_space | |||
| model.hyper_parameter_space = current_hp_for_trainer | |||
| self.graph_model_list[i] = model | |||
| return self | |||
| def predict_proba(self, *args, **kwargs) -> Any: | |||
| """ | |||
| Predict the node probability. | |||
| @@ -12,7 +12,8 @@ import yaml | |||
| from .base import BaseClassifier | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.train import get_feval | |||
| from ...module.model import BaseModel, MODEL_DICT | |||
| from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer | |||
| from ..base import _initialize_single_model, _parse_hp_space | |||
| from ..utils import Leaderboard, set_seed | |||
| from ...datasets import utils | |||
| @@ -98,6 +99,108 @@ class AutoGraphClassifier(BaseClassifier): | |||
| self.dataset = None | |||
| def _init_graph_module( | |||
| self, | |||
| graph_models, | |||
| num_classes, | |||
| num_features, | |||
| feval, | |||
| device, | |||
| loss, | |||
| num_graph_features | |||
| ) -> "AutoGraphClassifier": | |||
| # load graph network module | |||
| self.graph_model_list = [] | |||
| if isinstance(graph_models, list): | |||
| for model in graph_models: | |||
| if isinstance(model, str): | |||
| if model in MODEL_DICT: | |||
| self.graph_model_list.append( | |||
| MODEL_DICT[model]( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| num_graph_features=num_graph_features, | |||
| device=device, | |||
| init=False | |||
| ) | |||
| ) | |||
| else: | |||
| raise KeyError("cannot find model %s" % (model)) | |||
| elif isinstance(model, type) and issubclass(model, BaseModel): | |||
| self.graph_model_list.append( | |||
| model( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| num_graph_features=num_graph_features, | |||
| device=device, | |||
| init=False | |||
| ) | |||
| ) | |||
| elif isinstance(model, BaseModel): | |||
| # setup the hp of num_classes and num_features | |||
| model.set_num_classes(num_classes) | |||
| model.set_num_features(num_features) | |||
| model.set_num_graph_features(num_graph_features) | |||
| self.graph_model_list.append(model.to(device)) | |||
| elif isinstance(model, GraphClassificationTrainer): | |||
| # receive a trainer list, put trainer to list | |||
| assert model.get_model() is not None, "Passed trainer should contain a model" | |||
| model.set_feval(feval) | |||
| model.loss_type = loss | |||
| model.to(device) | |||
| model.model.set_num_classes(num_classes) | |||
| model.model.set_num_features(num_features) | |||
| model.model.set_num_graph_features(num_graph_features) | |||
| model.num_classes = num_classes | |||
| model.num_features = num_features | |||
| model.num_graph_features = num_graph_features | |||
| self.graph_model_list.append(model) | |||
| else: | |||
| raise KeyError("cannot find graph network %s." % (model)) | |||
| else: | |||
| raise ValueError( | |||
| "need graph network to be (list of) str or a BaseModel class/instance, get", | |||
| graph_models, | |||
| "instead.", | |||
| ) | |||
| # wrap all model_cls with specified trainer | |||
| for i, model in enumerate(self.graph_model_list): | |||
| # set model hp space | |||
| if self._model_hp_spaces is not None: | |||
| if self._model_hp_spaces[i] is not None: | |||
| if isinstance(model, GraphClassificationTrainer): | |||
| model.model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| else: | |||
| model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| # initialize trainer if needed | |||
| if isinstance(model, BaseModel): | |||
| name = ( | |||
| self._default_trainer | |||
| if isinstance(self._default_trainer, str) | |||
| else self._default_trainer[i] | |||
| ) | |||
| model = TRAINER_DICT[name]( | |||
| model=model, | |||
| num_features=num_features, | |||
| num_classes=num_classes, | |||
| loss=loss, | |||
| feval=feval, | |||
| device=device, | |||
| num_graph_features=num_graph_features, | |||
| init=False | |||
| ) | |||
| # set trainer hp space | |||
| if self._trainer_hp_space is not None: | |||
| if isinstance(self._trainer_hp_space[0], list): | |||
| current_hp_for_trainer = self._trainer_hp_space[i] | |||
| else: | |||
| current_hp_for_trainer = self._trainer_hp_space | |||
| model.hyper_parameter_space = current_hp_for_trainer | |||
| self.graph_model_list[i] = model | |||
| return self | |||
| # pylint: disable=arguments-differ | |||
| def fit( | |||
| self, | |||
| @@ -13,6 +13,8 @@ import yaml | |||
| from .base import BaseClassifier | |||
| from ..base import _parse_hp_space, _initialize_single_model | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.model import MODEL_DICT, BaseModel | |||
| from ...module.train import TRAINER_DICT, NodeClassificationTrainer | |||
| from ...module.train import get_feval | |||
| from ..utils import Leaderboard, set_seed | |||
| from ...datasets import utils | |||
| @@ -100,6 +102,101 @@ class AutoNodeClassifier(BaseClassifier): | |||
| # data to be kept when fit | |||
| self.data = None | |||
| def _init_graph_module( | |||
| self, | |||
| graph_models, | |||
| num_classes, | |||
| num_features, | |||
| feval, | |||
| device, | |||
| loss | |||
| ) -> "AutoNodeClassifier": | |||
| # load graph network module | |||
| self.graph_model_list = [] | |||
| if isinstance(graph_models, list): | |||
| for model in graph_models: | |||
| if isinstance(model, str): | |||
| if model in MODEL_DICT: | |||
| self.graph_model_list.append( | |||
| MODEL_DICT[model]( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| device=device, | |||
| init=False | |||
| ) | |||
| ) | |||
| else: | |||
| raise KeyError("cannot find model %s" % (model)) | |||
| elif isinstance(model, type) and issubclass(model, BaseModel): | |||
| self.graph_model_list.append( | |||
| model( | |||
| num_classes=num_classes, | |||
| num_features=num_features, | |||
| device=device, | |||
| init=False | |||
| ) | |||
| ) | |||
| elif isinstance(model, BaseModel): | |||
| # setup the hp of num_classes and num_features | |||
| model.set_num_classes(num_classes) | |||
| model.set_num_features(num_features) | |||
| self.graph_model_list.append(model.to(device)) | |||
| elif isinstance(model, NodeClassificationTrainer): | |||
| # receive a trainer list, put trainer to list | |||
| assert model.get_model() is not None, "Passed trainer should contain a model" | |||
| model.set_feval(feval) | |||
| model.loss_type = loss | |||
| model.to(device) | |||
| model.model.set_num_classes(num_classes) | |||
| model.model.set_num_features(num_features) | |||
| model.num_classes = num_classes | |||
| model.num_features = num_features | |||
| self.graph_model_list.append(model) | |||
| else: | |||
| raise KeyError("cannot find graph network %s." % (model)) | |||
| else: | |||
| raise ValueError( | |||
| "need graph network to be (list of) str or a BaseModel class/instance, get", | |||
| graph_models, | |||
| "instead.", | |||
| ) | |||
| # wrap all model_cls with specified trainer | |||
| for i, model in enumerate(self.graph_model_list): | |||
| # set model hp space | |||
| if self._model_hp_spaces is not None: | |||
| if self._model_hp_spaces[i] is not None: | |||
| if isinstance(model, NodeClassificationTrainer): | |||
| model.model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| else: | |||
| model.hyper_parameter_space = self._model_hp_spaces[i] | |||
| # initialize trainer if needed | |||
| if isinstance(model, BaseModel): | |||
| name = ( | |||
| self._default_trainer | |||
| if isinstance(self._default_trainer, str) | |||
| else self._default_trainer[i] | |||
| ) | |||
| model = TRAINER_DICT[name]( | |||
| model=model, | |||
| num_features=num_features, | |||
| num_classes=num_classes, | |||
| loss=loss, | |||
| feval=feval, | |||
| device=device, | |||
| init=False | |||
| ) | |||
| # set trainer hp space | |||
| if self._trainer_hp_space is not None: | |||
| if isinstance(self._trainer_hp_space[0], list): | |||
| current_hp_for_trainer = self._trainer_hp_space[i] | |||
| else: | |||
| current_hp_for_trainer = self._trainer_hp_space | |||
| model.hyper_parameter_space = current_hp_for_trainer | |||
| self.graph_model_list[i] = model | |||
| return self | |||
| # pylint: disable=arguments-differ | |||
| def fit( | |||
| self, | |||