Browse Source

rebase trainer

tags/v0.3.1
Frozenmad 5 years ago
parent
commit
38ab2b06b9
7 changed files with 221 additions and 131 deletions
  1. +13
    -8
      autogl/module/train/__init__.py
  2. +130
    -31
      autogl/module/train/base.py
  3. +22
    -36
      autogl/module/train/graph_classification_full.py
  4. +25
    -28
      autogl/module/train/node_classification_full.py
  5. +2
    -3
      autogl/solver/base.py
  6. +15
    -13
      autogl/solver/classifier/graph_classifier.py
  7. +14
    -12
      autogl/solver/classifier/node_classifier.py

+ 13
- 8
autogl/module/train/__init__.py View File

@@ -1,8 +1,14 @@
import importlib
import os
from .base import BaseTrainer, Evaluation, EarlyStopping

TRAINER_DICT = {}
EVALUATE_DICT = {}
from .base import (
BaseTrainer,
Evaluation,
BaseNodeClassificationTrainer,
BaseGraphClassificationTrainer,
)


def register_trainer(name):
@@ -19,9 +25,6 @@ def register_trainer(name):
return register_trainer_cls


EVALUATE_DICT = {}


def register_evaluate(*name):
def register_evaluate_cls(cls):
for n in name:
@@ -47,14 +50,16 @@ def get_feval(feval):
raise ValueError("feval argument of type", type(feval), "is not supported!")


from .graph_classification import GraphClassificationTrainer
from .node_classification import NodeClassificationTrainer
from .graph_classification_full import GraphClassificationFullTrainer
from .node_classification_full import NodeClassificationFullTrainer
from .evaluate import Acc, Auc, Logloss

__all__ = [
"BaseTrainer",
"GraphClassificationTrainer",
"NodeClassificationTrainer",
"BaseNodeClassificationTrainer",
"BaseGraphClassificationTrainer",
"GraphClassificationFullTrainer",
"NodeClassificationFullTrainer",
"Evaluation",
"Acc",
"Auc",


+ 130
- 31
autogl/module/train/base.py View File

@@ -1,12 +1,25 @@
import numpy as np
from typing import Union, Iterable
from ..model import BaseModel

import torch
from ..model import BaseModel, MODEL_DICT
import pickle
from ...utils import get_logger
from . import EVALUATE_DICT

LOGGER_ES = get_logger("early-stopping")


def get_feval(feval):
if isinstance(feval, str):
return EVALUATE_DICT[feval]
if isinstance(feval, type) and issubclass(feval, Evaluation):
return feval
if isinstance(feval, list):
return [get_feval(f) for f in feval]
raise ValueError("feval argument of type", type(feval), "is not supported!")


class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""

@@ -81,17 +94,11 @@ class EarlyStopping:
class BaseTrainer:
def __init__(
self,
model: Union[BaseModel, str],
optimizer=None,
lr=None,
max_epoch=None,
early_stopping_round=None,
device=None,
model: BaseModel,
device: Union[torch.device, str],
init=True,
feval=["acc"],
loss="nll_loss",
*args,
**kwargs,
):
"""
The basic trainer.
@@ -103,29 +110,26 @@ class BaseTrainer:
model: `BaseModel` or `str`
The (name of) model used to train and predict.

optimizer: `Optimizer` of `str`
The (name of) optimizer used to train and predict.

lr: `float`
The learning rate.

max_epoch: `int`
The max number of epochs in training.

early_stopping_round: `int`
The round of early stop.

device: `torch.device` or `str`
The device where model will be running on.

init: `bool`
If True(False), the model will (not) be initialized.
"""
super().__init__()
self.model = model
self.to(device)
self.init = init
self.feval = get_feval(feval)
self.loss = loss

args: Other parameters.
def to(self, device):
"""
Migrate trainer to new device

kwargs: Other parameters.
Parameters
----------
device: `str` or `torch.device`
The device this trainer will use
"""
super().__init__()
self.device = torch.device(device)

def initialize(self):
"""Initialize the auto model in trainer."""
@@ -169,8 +173,8 @@ class BaseTrainer:

@classmethod
def load(cls, path):
with open(path, "rb") as input:
instance = pickle.load(input)
with open(path, "rb") as inputs:
instance = pickle.load(inputs)
return instance

@property
@@ -279,7 +283,21 @@ class BaseTrainer:

def set_feval(self, feval):
"""Set the evaluation metrics."""
raise NotImplementedError()
self.feval = get_feval(feval)

def update_parameters(self, **kwargs):
"""
Update parameters of this trainer
"""
for k, v in kwargs.items():
if k == "feval":
self.set_feval(v)
elif k == "device":
self.to(v)
elif hasattr(self, k):
setattr(self, k, v)
else:
raise KeyError("Cannot set parameter", k, "for trainer", self.__class__)


# a static class for evaluating results
@@ -296,7 +314,7 @@ class Evaluation:
"""
Should return whether this evaluation method is higher better (bool)
"""
raise True
return True

@staticmethod
def evaluate(predict, label):
@@ -304,3 +322,84 @@ class Evaluation:
Should return: the evaluation result (float)
"""
raise NotImplementedError()


class BaseNodeClassificationTrainer(BaseTrainer):
def __init__(
self,
model: Union[BaseModel, str],
num_features,
num_classes,
device="auto",
init=True,
feval=["acc"],
loss="nll_loss",
):
self.num_features = num_features
self.num_classes = num_classes
device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "auto"
else torch.device(device)
)
if isinstance(model, str):
assert model in MODEL_DICT, "Cannot parse model name " + model
self.model = MODEL_DICT[model](num_features, num_classes, device, init=init)
elif isinstance(model, BaseModel):
self.model = model
else:
raise TypeError(
"Model argument only support str or BaseModel, get",
type(model),
"instead.",
)
super().__init__(model, device=device, init=init, feval=feval, loss=loss)

@classmethod
def get_task_name(cls):
return "GraphClassification"


class BaseGraphClassificationTrainer(BaseTrainer):
def __init__(
self,
model: Union[BaseModel, str],
num_features,
num_classes,
num_graph_features=0,
device=None,
init=True,
feval=["acc"],
loss="nll_loss",
):
self.num_features = num_features
self.num_classes = num_classes
self.num_graph_features = num_graph_features
device = (
torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "auto"
else torch.device(device)
)
if isinstance(model, str):
assert model in MODEL_DICT, "Cannot parse model name " + model
self.model = MODEL_DICT[model](
num_features,
num_classes,
device,
init=init,
num_graph_features=num_graph_features,
)
elif isinstance(model, BaseModel):
self.model = model
else:
raise TypeError(
"Model argument only support str or BaseModel, get",
type(model),
"instead.",
)

super().__init__(model, device=device, init=init, feval=feval, loss=loss)

@classmethod
def get_task_name(cls):
return "NodeClassification"

autogl/module/train/graph_classification.py → autogl/module/train/graph_classification_full.py View File

@@ -1,4 +1,5 @@
from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
from . import register_trainer, EVALUATE_DICT
from .base import BaseGraphClassificationTrainer, EarlyStopping, Evaluation
import torch
from torch.optim.lr_scheduler import (
StepLR,
@@ -28,8 +29,8 @@ def get_feval(feval):
raise ValueError("feval argument of type", type(feval), "is not supported!")


@register_trainer("GraphClassification")
class GraphClassificationTrainer(BaseTrainer):
@register_trainer("GraphClassificationFull")
class GraphClassificationFullTrainer(BaseGraphClassificationTrainer):
"""
The graph classification trainer.

@@ -73,7 +74,7 @@ class GraphClassificationTrainer(BaseTrainer):
batch_size=None,
early_stopping_round=7,
weight_decay=1e-4,
device=None,
device="auto",
init=True,
feval=[Logloss],
loss="nll_loss",
@@ -81,22 +82,16 @@ class GraphClassificationTrainer(BaseTrainer):
*args,
**kwargs
):
super(GraphClassificationTrainer, self).__init__(model)

self.loss_type = loss

# init model
if isinstance(model, str):
assert model in MODEL_DICT, "Cannot parse model name " + model
self.model = MODEL_DICT[model](
num_features,
num_classes,
device,
init=init,
num_graph_features=num_graph_features,
)
elif isinstance(model, BaseModel):
self.model = model
super().__init__(
model,
num_features,
num_classes,
num_graph_features=num_graph_features,
device=device,
init=init,
feval=feval,
loss=loss,
)

self.opt_received = optimizer
if type(optimizer) == str and optimizer.lower() == "adam":
@@ -108,9 +103,6 @@ class GraphClassificationTrainer(BaseTrainer):

self.lr_scheduler_type = lr_scheduler_type

self.num_features = num_features
self.num_classes = num_classes
self.num_graph_features = num_graph_features
self.lr = lr if lr is not None else 1e-4
self.max_epoch = max_epoch if max_epoch is not None else 100
self.batch_size = batch_size if batch_size is not None else 64
@@ -135,8 +127,6 @@ class GraphClassificationTrainer(BaseTrainer):
self.valid_score = None

self.initialized = False
self.num_features = num_features
self.num_classes = num_classes
self.device = device

self.space = [
@@ -176,8 +166,6 @@ class GraphClassificationTrainer(BaseTrainer):
"scalingType": "LOG",
},
]
# self.space += self.model.space
GraphClassificationTrainer.space = self.space

self.hyperparams = {
"max_epoch": self.max_epoch,
@@ -186,7 +174,6 @@ class GraphClassificationTrainer(BaseTrainer):
"lr": self.lr,
"weight_decay": self.weight_decay,
}
self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()}

if init is True:
self.initialize()
@@ -207,9 +194,9 @@ class GraphClassificationTrainer(BaseTrainer):
# """Get task name, i.e., `GraphClassification`."""
return "GraphClassification"

def to(self, new_device):
assert isinstance(new_device, torch.device)
self.device = new_device
def to(self, device):
assert isinstance(device, torch.device)
self.device = device
if self.model is not None:
self.model.to(self.device)

@@ -255,11 +242,11 @@ class GraphClassificationTrainer(BaseTrainer):
optimizer.zero_grad()
output = self.model.model(data)
# loss = F.nll_loss(output, data.y)
if hasattr(F, self.loss_type):
loss = getattr(F, self.loss_type)(output, data.y)
if hasattr(F, self.loss):
loss = getattr(F, self.loss)(output, data.y)
else:
raise TypeError(
"PyTorch does not support loss type {}".format(self.loss_type)
"PyTorch does not support loss type {}".format(self.loss)
)
loss.backward()
loss_all += data.num_graphs * loss.item()
@@ -569,7 +556,7 @@ class GraphClassificationTrainer(BaseTrainer):
weight_decay=hp["weight_decay"],
device=self.device,
feval=self.feval,
loss=self.loss_type,
loss=self.loss,
lr_scheduler_type=self.lr_scheduler_type,
init=True,
*self.args,
@@ -591,7 +578,6 @@ class GraphClassificationTrainer(BaseTrainer):
def hyper_parameter_space(self, space):
# """Set the space of hyperparameter."""
self.space = space
GraphClassificationTrainer.space = space

def get_hyper_parameter(self):
# """Get the hyperparameter in this trainer."""

autogl/module/train/node_classification.py → autogl/module/train/node_classification_full.py View File

@@ -1,4 +1,10 @@
from . import register_trainer, BaseTrainer, Evaluation, EVALUATE_DICT, EarlyStopping
"""
Node classification Full Trainer Implementation
"""

from . import register_trainer, EVALUATE_DICT

from .base import BaseNodeClassificationTrainer, EarlyStopping, Evaluation
import torch
from torch.optim.lr_scheduler import (
StepLR,
@@ -27,8 +33,8 @@ def get_feval(feval):
raise ValueError("feval argument of type", type(feval), "is not supported!")


@register_trainer("NodeClassification")
class NodeClassificationTrainer(BaseTrainer):
@register_trainer("NodeClassificationFull")
class NodeClassificationFullTrainer(BaseNodeClassificationTrainer):
"""
The node classification trainer.

@@ -58,8 +64,6 @@ class NodeClassificationTrainer(BaseTrainer):
If True(False), the model will (not) be initialized.
"""

space = None

def __init__(
self,
model: Union[BaseModel, str],
@@ -70,7 +74,7 @@ class NodeClassificationTrainer(BaseTrainer):
max_epoch=None,
early_stopping_round=None,
weight_decay=1e-4,
device=None,
device="auto",
init=True,
feval=[Logloss],
loss="nll_loss",
@@ -78,12 +82,15 @@ class NodeClassificationTrainer(BaseTrainer):
*args,
**kwargs
):
super(NodeClassificationTrainer, self).__init__(model)

self.loss_type = loss

if device is None:
device = "cpu"
super().__init__(
model,
num_features,
num_classes,
device=device,
init=init,
feval=feval,
loss=loss,
)

# init model
if isinstance(model, str):
@@ -102,14 +109,11 @@ class NodeClassificationTrainer(BaseTrainer):

self.lr_scheduler_type = lr_scheduler_type

self.num_features = num_features
self.num_classes = num_classes
self.lr = lr if lr is not None else 1e-4
self.max_epoch = max_epoch if max_epoch is not None else 100
self.early_stopping_round = (
early_stopping_round if early_stopping_round is not None else 100
)
self.device = device
self.args = args
self.kwargs = kwargs

@@ -126,9 +130,6 @@ class NodeClassificationTrainer(BaseTrainer):
self.valid_score = None

self.initialized = False
self.num_features = num_features
self.num_classes = num_classes
self.device = device

self.space = [
{
@@ -160,8 +161,6 @@ class NodeClassificationTrainer(BaseTrainer):
"scalingType": "LOG",
},
]
# self.space += self.model.space
NodeClassificationTrainer.space = self.space

self.hyperparams = {
"max_epoch": self.max_epoch,
@@ -169,7 +168,6 @@ class NodeClassificationTrainer(BaseTrainer):
"lr": self.lr,
"weight_decay": self.weight_decay,
}
self.hyperparams = {**self.hyperparams, **self.model.get_hyper_parameter()}

if init is True:
self.initialize()
@@ -229,11 +227,11 @@ class NodeClassificationTrainer(BaseTrainer):
self.model.model.train()
optimizer.zero_grad()
res = self.model.model.forward(data)
if hasattr(F, self.loss_type):
loss = getattr(F, self.loss_type)(res[mask], data.y[mask])
if hasattr(F, self.loss):
loss = getattr(F, self.loss)(res[mask], data.y[mask])
else:
raise TypeError(
"PyTorch does not support loss type {}".format(self.loss_type)
"PyTorch does not support loss type {}".format(self.loss)
)

loss.backward()
@@ -241,7 +239,7 @@ class NodeClassificationTrainer(BaseTrainer):
if self.lr_scheduler_type:
scheduler.step()

if hasattr(data, 'val_mask') and data.val_mask is not None:
if hasattr(data, "val_mask") and data.val_mask is not None:
if type(self.feval) is list:
feval = self.feval[0]
else:
@@ -253,7 +251,7 @@ class NodeClassificationTrainer(BaseTrainer):
if self.early_stopping.early_stop:
LOGGER.debug("Early stopping at %d", epoch)
break
if hasattr(data, 'val_mask') and data.val_mask is not None:
if hasattr(data, "val_mask") and data.val_mask is not None:
self.early_stopping.load_checkpoint(self.model.model)

def predict_only(self, data, test_mask=None):
@@ -516,7 +514,7 @@ class NodeClassificationTrainer(BaseTrainer):
device=self.device,
weight_decay=hp["weight_decay"],
feval=self.feval,
loss=self.loss_type,
loss=self.loss,
lr_scheduler_type=self.lr_scheduler_type,
init=True,
*self.args,
@@ -538,7 +536,6 @@ class NodeClassificationTrainer(BaseTrainer):
def hyper_parameter_space(self, space):
# """Set the space of hyperparameter."""
self.space = space
NodeClassificationTrainer.space = space

def get_hyper_parameter(self):
# """Get the hyperparameter in this trainer."""

+ 2
- 3
autogl/solver/base.py View File

@@ -11,7 +11,6 @@ import torch
from ..module.feature import FEATURE_DICT
from ..module.hpo import HPO_DICT
from ..module.model import MODEL_DICT
from ..module.train import NodeClassificationTrainer
from ..module import BaseFeatureAtom, BaseHPOptimizer, BaseTrainer
from .utils import Leaderboard
from ..utils import get_logger
@@ -336,7 +335,7 @@ class BaseSolver:
assert name in self.trained_models, "cannot find model by name" + name
return self.trained_models[name]

def get_model_by_performance(self, index) -> Tuple[NodeClassificationTrainer, str]:
def get_model_by_performance(self, index) -> Tuple[BaseTrainer, str]:
r"""
Find and get the model instance by performance.

@@ -347,7 +346,7 @@ class BaseSolver:

Returns
-------
trainer: autogl.module.train.NodeClassificationTrainer
trainer: autogl.module.train.BaseTrainer
A trainer instance containing the trained models and training status.
name: str
The name of current trainer.


+ 15
- 13
autogl/solver/classifier/graph_classifier.py View File

@@ -13,7 +13,7 @@ import yaml
from .base import BaseClassifier
from ...module.feature import FEATURE_DICT
from ...module.model import BaseModel, MODEL_DICT
from ...module.train import TRAINER_DICT, get_feval, GraphClassificationTrainer
from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer
from ..base import _initialize_single_model, _parse_hp_space
from ..utils import Leaderboard, set_seed
from ...datasets import utils
@@ -90,7 +90,7 @@ class AutoGraphClassifier(BaseClassifier):
hpo_module=hpo_module,
ensemble_module=ensemble_module,
max_evals=max_evals,
default_trainer=default_trainer or "GraphClassification",
default_trainer=default_trainer or "GraphClassificationFull",
trainer_hp_space=trainer_hp_space,
model_hp_spaces=model_hp_spaces,
size=size,
@@ -142,20 +142,22 @@ class AutoGraphClassifier(BaseClassifier):
model.set_num_features(num_features)
model.set_num_graph_features(num_graph_features)
self.graph_model_list.append(model.to(device))
elif isinstance(model, GraphClassificationTrainer):
elif isinstance(model, BaseGraphClassificationTrainer):
# receive a trainer list, put trainer to list
assert (
model.get_model() is not None
), "Passed trainer should contain a model"
model.set_feval(feval)
model.loss_type = loss
model.to(device)
model.model.set_num_classes(num_classes)
model.model.set_num_features(num_features)
model.model.set_num_graph_features(num_graph_features)
model.num_classes = num_classes
model.num_features = num_features
model.num_graph_features = num_graph_features
model.update_parameters(
num_classes=num_classes,
num_features=num_features,
num_graph_features=num_graph_features,
loss=loss,
feval=feval,
device=device,
)
self.graph_model_list.append(model)
else:
raise KeyError("cannot find graph network %s." % (model))
@@ -171,7 +173,7 @@ class AutoGraphClassifier(BaseClassifier):
# set model hp space
if self._model_hp_spaces is not None:
if self._model_hp_spaces[i] is not None:
if isinstance(model, GraphClassificationTrainer):
if isinstance(model, BaseGraphClassificationTrainer):
model.model.hyper_parameter_space = self._model_hp_spaces[i]
else:
model.hyper_parameter_space = self._model_hp_spaces[i]
@@ -770,11 +772,11 @@ class AutoGraphClassifier(BaseClassifier):
]

trainer = path_or_dict.pop("trainer", None)
default_trainer = "GraphClassification"
default_trainer = "GraphClassificationFull"
trainer_space = None
if isinstance(trainer, dict):
# global default
default_trainer = trainer.pop("name", "GraphClassification")
default_trainer = trainer.pop("name", "GraphClassificationFull")
trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
default_kwargs = {"num_features": None, "num_classes": None}
default_kwargs.update(trainer)
@@ -793,7 +795,7 @@ class AutoGraphClassifier(BaseClassifier):
trainer_space = []
for i in range(len(model_list)):
train, model = trainer[i], model_list[i]
default_trainer = train.pop("name", "GraphClassification")
default_trainer = train.pop("name", "GraphClassificationFull")
trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
default_kwargs = {"num_features": None, "num_classes": None}
default_kwargs.update(train)


+ 14
- 12
autogl/solver/classifier/node_classifier.py View File

@@ -14,7 +14,7 @@ from .base import BaseClassifier
from ..base import _parse_hp_space, _initialize_single_model
from ...module.feature import FEATURE_DICT
from ...module.model import MODEL_DICT, BaseModel
from ...module.train import TRAINER_DICT, NodeClassificationTrainer
from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer
from ...module.train import get_feval
from ..utils import Leaderboard, set_seed
from ...datasets import utils
@@ -92,7 +92,7 @@ class AutoNodeClassifier(BaseClassifier):
hpo_module=hpo_module,
ensemble_module=ensemble_module,
max_evals=max_evals,
default_trainer=default_trainer or "NodeClassification",
default_trainer=default_trainer or "NodeClassificationFull",
trainer_hp_space=trainer_hp_space,
model_hp_spaces=model_hp_spaces,
size=size,
@@ -135,18 +135,20 @@ class AutoNodeClassifier(BaseClassifier):
model.set_num_classes(num_classes)
model.set_num_features(num_features)
self.graph_model_list.append(model.to(device))
elif isinstance(model, NodeClassificationTrainer):
elif isinstance(model, BaseNodeClassificationTrainer):
# receive a trainer list, put trainer to list
assert (
model.get_model() is not None
), "Passed trainer should contain a model"
model.set_feval(feval)
model.loss_type = loss
model.to(device)
model.model.set_num_classes(num_classes)
model.model.set_num_features(num_features)
model.num_classes = num_classes
model.num_features = num_features
model.update_parameters(
num_classes=num_classes,
num_features=num_features,
loss=loss,
feval=feval,
device=device,
)
self.graph_model_list.append(model)
else:
raise KeyError("cannot find graph network %s." % (model))
@@ -162,7 +164,7 @@ class AutoNodeClassifier(BaseClassifier):
# set model hp space
if self._model_hp_spaces is not None:
if self._model_hp_spaces[i] is not None:
if isinstance(model, NodeClassificationTrainer):
if isinstance(model, BaseNodeClassificationTrainer):
model.model.hyper_parameter_space = self._model_hp_spaces[i]
else:
model.hyper_parameter_space = self._model_hp_spaces[i]
@@ -689,11 +691,11 @@ class AutoNodeClassifier(BaseClassifier):
]

trainer = path_or_dict.pop("trainer", None)
default_trainer = "NodeClassification"
default_trainer = "NodeClassificationFull"
trainer_space = None
if isinstance(trainer, dict):
# global default
default_trainer = trainer.pop("name", "NodeClassification")
default_trainer = trainer.pop("name", "NodeClassificationFull")
trainer_space = _parse_hp_space(trainer.pop("hp_space", None))
default_kwargs = {"num_features": None, "num_classes": None}
default_kwargs.update(trainer)
@@ -712,7 +714,7 @@ class AutoNodeClassifier(BaseClassifier):
trainer_space = []
for i in range(len(model_list)):
train, model = trainer[i], model_list[i]
default_trainer = train.pop("name", "NodeClassification")
default_trainer = train.pop("name", "NodeClassificationFull")
trainer_space.append(_parse_hp_space(train.pop("hp_space", None)))
default_kwargs = {"num_features": None, "num_classes": None}
default_kwargs.update(train)


Loading…
Cancel
Save