Browse Source

customize init process

tags/v0.3.1
Frozenmad 5 years ago
parent
commit
d19bc12a22
3 changed files with 202 additions and 91 deletions
  1. +1
    -90
      autogl/solver/classifier/base.py
  2. +104
    -1
      autogl/solver/classifier/graph_classifier.py
  3. +97
    -0
      autogl/solver/classifier/node_classifier.py

+ 1
- 90
autogl/solver/classifier/base.py View File

@@ -5,9 +5,7 @@ Base solver for classification problems
from typing import Any
from ..base import BaseSolver
from ...module.ensemble import ENSEMBLE_DICT
from ...module.train import TRAINER_DICT
from ...module.model import MODEL_DICT
from ...module import BaseEnsembler, BaseModel, BaseTrainer
from ...module import BaseEnsembler


class BaseClassifier(BaseSolver):
@@ -15,93 +13,6 @@ class BaseClassifier(BaseSolver):
Base solver for classification problems
"""

def _init_graph_module(
self,
graph_models,
num_classes,
num_features,
*args,
**kwargs,
) -> "BaseClassifier":
# load graph network module
self.graph_model_list = []
if isinstance(graph_models, list):
for model in graph_models:
if isinstance(model, str):
if model in MODEL_DICT:
self.graph_model_list.append(
MODEL_DICT[model](
num_classes=num_classes,
num_features=num_features,
*args,
**kwargs,
init=False,
)
)
else:
raise KeyError("cannot find model %s" % (model))
elif isinstance(model, type) and issubclass(model, BaseModel):
self.graph_model_list.append(
model(
num_classes=num_classes,
num_features=num_features,
*args,
**kwargs,
init=False,
)
)
elif isinstance(model, BaseModel):
# setup the hp of num_classes and num_features
model.set_num_classes(num_classes)
model.set_num_features(num_features)
self.graph_model_list.append(model.to(self.runtime_device))
elif isinstance(model, BaseTrainer):
# receive a trainer list, put trainer to list
self.graph_model_list.append(model)
else:
raise KeyError("cannot find graph network %s." % (model))
else:
raise ValueError(
"need graph network to be (list of) str or a BaseModel class/instance, get",
graph_models,
"instead.",
)

# wrap all model_cls with specified trainer
for i, model in enumerate(self.graph_model_list):
# set model hp space
if self._model_hp_spaces is not None:
if self._model_hp_spaces[i] is not None:
if isinstance(model, BaseTrainer):
model.model.hyper_parameter_space = self._model_hp_spaces[i]
else:
model.hyper_parameter_space = self._model_hp_spaces[i]
# initialize trainer if needed
if isinstance(model, BaseModel):
name = (
self._default_trainer
if isinstance(self._default_trainer, str)
else self._default_trainer[i]
)
model = TRAINER_DICT[name](
model=model,
num_features=num_features,
num_classes=num_classes,
*args,
**kwargs,
init=False,
)
# set trainer hp space
if self._trainer_hp_space is not None:
if isinstance(self._trainer_hp_space[0], list):
current_hp_for_trainer = self._trainer_hp_space[i]
else:
current_hp_for_trainer = self._trainer_hp_space
model.hyper_parameter_space = current_hp_for_trainer
self.graph_model_list[i] = model

return self

def predict_proba(self, *args, **kwargs) -> Any:
"""
Predict the node probability.


+ 104
- 1
autogl/solver/classifier/graph_classifier.py View File

@@ -12,7 +12,8 @@ import yaml

from .base import BaseClassifier
from ...module.feature import FEATURE_DICT
from ...module.train import get_feval
from ...module.model import BaseModel, MODEL_DICT
from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer
from ..base import _initialize_single_model, _parse_hp_space
from ..utils import Leaderboard, set_seed
from ...datasets import utils
@@ -98,6 +99,108 @@ class AutoGraphClassifier(BaseClassifier):

self.dataset = None

def _init_graph_module(
self,
graph_models,
num_classes,
num_features,
feval,
device,
loss,
num_graph_features
) -> "AutoGraphClassifier":
# load graph network module
self.graph_model_list = []
if isinstance(graph_models, list):
for model in graph_models:
if isinstance(model, str):
if model in MODEL_DICT:
self.graph_model_list.append(
MODEL_DICT[model](
num_classes=num_classes,
num_features=num_features,
num_graph_features=num_graph_features,
device=device,
init=False
)
)
else:
raise KeyError("cannot find model %s" % (model))
elif isinstance(model, type) and issubclass(model, BaseModel):
self.graph_model_list.append(
model(
num_classes=num_classes,
num_features=num_features,
num_graph_features=num_graph_features,
device=device,
init=False
)
)
elif isinstance(model, BaseModel):
# setup the hp of num_classes and num_features
model.set_num_classes(num_classes)
model.set_num_features(num_features)
model.set_num_graph_features(num_graph_features)
self.graph_model_list.append(model.to(device))
elif isinstance(model, GraphClassificationTrainer):
# receive a trainer list, put trainer to list
assert model.get_model() is not None, "Passed trainer should contain a model"
model.set_feval(feval)
model.loss_type = loss
model.to(device)
model.model.set_num_classes(num_classes)
model.model.set_num_features(num_features)
model.model.set_num_graph_features(num_graph_features)
model.num_classes = num_classes
model.num_features = num_features
model.num_graph_features = num_graph_features
self.graph_model_list.append(model)
else:
raise KeyError("cannot find graph network %s." % (model))
else:
raise ValueError(
"need graph network to be (list of) str or a BaseModel class/instance, get",
graph_models,
"instead.",
)

# wrap all model_cls with specified trainer
for i, model in enumerate(self.graph_model_list):
# set model hp space
if self._model_hp_spaces is not None:
if self._model_hp_spaces[i] is not None:
if isinstance(model, GraphClassificationTrainer):
model.model.hyper_parameter_space = self._model_hp_spaces[i]
else:
model.hyper_parameter_space = self._model_hp_spaces[i]
# initialize trainer if needed
if isinstance(model, BaseModel):
name = (
self._default_trainer
if isinstance(self._default_trainer, str)
else self._default_trainer[i]
)
model = TRAINER_DICT[name](
model=model,
num_features=num_features,
num_classes=num_classes,
loss=loss,
feval=feval,
device=device,
num_graph_features=num_graph_features,
init=False
)
# set trainer hp space
if self._trainer_hp_space is not None:
if isinstance(self._trainer_hp_space[0], list):
current_hp_for_trainer = self._trainer_hp_space[i]
else:
current_hp_for_trainer = self._trainer_hp_space
model.hyper_parameter_space = current_hp_for_trainer
self.graph_model_list[i] = model

return self

# pylint: disable=arguments-differ
def fit(
self,


+ 97
- 0
autogl/solver/classifier/node_classifier.py View File

@@ -13,6 +13,8 @@ import yaml
from .base import BaseClassifier
from ..base import _parse_hp_space, _initialize_single_model
from ...module.feature import FEATURE_DICT
from ...module.model import MODEL_DICT, BaseModel
from ...module.train import TRAINER_DICT, NodeClassificationTrainer
from ...module.train import get_feval
from ..utils import Leaderboard, set_seed
from ...datasets import utils
@@ -100,6 +102,101 @@ class AutoNodeClassifier(BaseClassifier):
# data to be kept when fit
self.data = None

def _init_graph_module(
self,
graph_models,
num_classes,
num_features,
feval,
device,
loss
) -> "AutoNodeClassifier":
# load graph network module
self.graph_model_list = []
if isinstance(graph_models, list):
for model in graph_models:
if isinstance(model, str):
if model in MODEL_DICT:
self.graph_model_list.append(
MODEL_DICT[model](
num_classes=num_classes,
num_features=num_features,
device=device,
init=False
)
)
else:
raise KeyError("cannot find model %s" % (model))
elif isinstance(model, type) and issubclass(model, BaseModel):
self.graph_model_list.append(
model(
num_classes=num_classes,
num_features=num_features,
device=device,
init=False
)
)
elif isinstance(model, BaseModel):
# setup the hp of num_classes and num_features
model.set_num_classes(num_classes)
model.set_num_features(num_features)
self.graph_model_list.append(model.to(device))
elif isinstance(model, NodeClassificationTrainer):
# receive a trainer list, put trainer to list
assert model.get_model() is not None, "Passed trainer should contain a model"
model.set_feval(feval)
model.loss_type = loss
model.to(device)
model.model.set_num_classes(num_classes)
model.model.set_num_features(num_features)
model.num_classes = num_classes
model.num_features = num_features
self.graph_model_list.append(model)
else:
raise KeyError("cannot find graph network %s." % (model))
else:
raise ValueError(
"need graph network to be (list of) str or a BaseModel class/instance, get",
graph_models,
"instead.",
)

# wrap all model_cls with specified trainer
for i, model in enumerate(self.graph_model_list):
# set model hp space
if self._model_hp_spaces is not None:
if self._model_hp_spaces[i] is not None:
if isinstance(model, NodeClassificationTrainer):
model.model.hyper_parameter_space = self._model_hp_spaces[i]
else:
model.hyper_parameter_space = self._model_hp_spaces[i]
# initialize trainer if needed
if isinstance(model, BaseModel):
name = (
self._default_trainer
if isinstance(self._default_trainer, str)
else self._default_trainer[i]
)
model = TRAINER_DICT[name](
model=model,
num_features=num_features,
num_classes=num_classes,
loss=loss,
feval=feval,
device=device,
init=False
)
# set trainer hp space
if self._trainer_hp_space is not None:
if isinstance(self._trainer_hp_space[0], list):
current_hp_for_trainer = self._trainer_hp_space[i]
else:
current_hp_for_trainer = self._trainer_hp_space
model.hyper_parameter_space = current_hp_for_trainer
self.graph_model_list[i] = model

return self

# pylint: disable=arguments-differ
def fit(
self,


Loading…
Cancel
Save