From 01900fbd6c68dfdf904d64042e85d67bd1b66cb3 Mon Sep 17 00:00:00 2001 From: Frozenmad Date: Mon, 13 Dec 2021 12:13:35 +0000 Subject: [PATCH] Adjust node classification solver --- autogl/solver/classifier/node_classifier.py | 134 +++++------- test/solver/nodeclf_pyg.py | 218 ++++++++++++++++++++ 2 files changed, 265 insertions(+), 87 deletions(-) create mode 100644 test/solver/nodeclf_pyg.py diff --git a/autogl/solver/classifier/node_classifier.py b/autogl/solver/classifier/node_classifier.py index ddc6841..68e59a8 100644 --- a/autogl/solver/classifier/node_classifier.py +++ b/autogl/solver/classifier/node_classifier.py @@ -84,14 +84,14 @@ class AutoNodeClassifier(BaseClassifier): def __init__( self, feature_module=None, - graph_models=("gat", "gcn"), + graph_models=("gat", "gcn"), # TODO: support a list of model nas_algorithms=None, nas_spaces=None, nas_estimators=None, hpo_module="anneal", ensemble_module="voting", max_evals=50, - default_trainer=None, + default_trainer="NodeClassificationFull", trainer_hp_space=None, model_hp_spaces=None, size=4, @@ -107,7 +107,7 @@ class AutoNodeClassifier(BaseClassifier): hpo_module=hpo_module, ensemble_module=ensemble_module, max_evals=max_evals, - default_trainer=default_trainer or "NodeClassificationFull", + default_trainer=default_trainer, trainer_hp_space=trainer_hp_space, model_hp_spaces=model_hp_spaces, size=size, @@ -122,91 +122,52 @@ class AutoNodeClassifier(BaseClassifier): ) -> "AutoNodeClassifier": # load graph network module self.graph_model_list = [] - if isinstance(graph_models, (list, tuple)): - for model in graph_models: - if isinstance(model, str): - if model in MODEL_DICT: - self.graph_model_list.append( - MODEL_DICT[model]( - num_classes=num_classes, - num_features=num_features, - device=device, - init=False, - ) - ) - else: - raise KeyError("cannot find model %s" % (model)) - elif isinstance(model, type) and issubclass(model, BaseAutoModel): - self.graph_model_list.append( - model( - num_classes=num_classes, - num_features=num_features, - device=device, - init=False, - ) - ) - elif isinstance(model, BaseAutoModel): - # setup the hp of num_classes and num_features - model.set_num_classes(num_classes) - model.set_num_features(num_features) - self.graph_model_list.append(model.to(device)) - elif isinstance(model, BaseNodeClassificationTrainer): - # receive a trainer list, put trainer to list - assert ( - model.get_model() is not None - ), "Passed trainer should contain a model" - model.model.set_num_classes(num_classes) - model.model.set_num_features(num_features) - model.update_parameters( - num_classes=num_classes, - num_features=num_features, - loss=loss, - feval=feval, - device=device, - ) - self.graph_model_list.append(model) + + for i, model in enumerate(graph_models): + # init the trainer + if not isinstance(model, BaseNodeClassificationTrainer): + trainer = ( + self._default_trainer if not isinstance(self._default_trainer, (tuple, list)) + else self._default_trainer[i] + ) + if isinstance(trainer, str): + trainer = TRAINER_DICT[trainer]() + if isinstance(model, (tuple, list)): + trainer.encoder = model[0] + trainer.decoder = model[1] else: - raise KeyError("cannot find graph network %s." % (model)) - else: - raise ValueError( - "need graph network to be (list of) str or a BaseModel class/instance, get", - graph_models, - "instead.", - ) + trainer.encoder = model + else: + trainer = model - # wrap all model_cls with specified trainer - for i, model in enumerate(self.graph_model_list): # set model hp space if self._model_hp_spaces is not None: if self._model_hp_spaces[i] is not None: - if isinstance(model, BaseNodeClassificationTrainer): - model.model.hyper_parameter_space = self._model_hp_spaces[i] + if isinstance(self._model_hp_spaces[i], dict): + encoder_hp_space = self._model_hp_spaces[i].get('encoder', None) + decoder_hp_space = self._model_hp_spaces[i].get('decoder', None) else: - model.hyper_parameter_space = self._model_hp_spaces[i] - # initialize trainer if needed - if isinstance(model, BaseAutoModel): - name = ( - self._default_trainer - if isinstance(self._default_trainer, str) - else self._default_trainer[i] - ) - model = TRAINER_DICT[name]( - model=model, - num_features=num_features, - num_classes=num_classes, - loss=loss, - feval=feval, - device=device, - init=False, - ) + encoder_hp_space = self._model_hp_spaces[i] + decoder_hp_space = None + if encoder_hp_space is not None: + trainer.encoder.hyper_parameter_space = encoder_hp_space + if decoder_hp_space is not None: + trainer.decoder.hyper_parameter_space = decoder_hp_space + # set trainer hp space if self._trainer_hp_space is not None: if isinstance(self._trainer_hp_space[0], list): current_hp_for_trainer = self._trainer_hp_space[i] else: current_hp_for_trainer = self._trainer_hp_space - model.hyper_parameter_space = current_hp_for_trainer - self.graph_model_list[i] = model + trainer.hyper_parameter_space = current_hp_for_trainer + + trainer.num_features = num_features + trainer.num_classes = num_classes + trainer.loss = loss + trainer.feval = feval + trainer.to(device) + self.graph_model_list.append(trainer) return self @@ -400,24 +361,23 @@ class AutoNodeClassifier(BaseClassifier): device=self.runtime_device, init=False, ) - else: + elif isinstance(train_name, BaseNodeClassificationTrainer): trainer = train_name - trainer.model = model - trainer.update_parameters( - num_features=num_features, - num_classes=num_classes, - loss="nll_loss" - if not hasattr(dataset, "loss") - else dataset.loss, - feval=evaluator_list, - device=self.runtime_device, - ) + trainer.encoder = model + trainer.num_features = num_features + trainer.num_classes = num_classes + trainer.loss = "nll_loss" if not hasattr(dataset, "loss") else dataset.loss + trainer.feval = evaluator_list + trainer.to(self.runtime_device) + else: + raise ValueError() self.graph_model_list.append(trainer) # train the models and tune hpo result_valid = [] names = [] for idx, model in enumerate(self.graph_model_list): + model: BaseNodeClassificationTrainer time_for_each_model = (time_limit - time.time() + time_begin) / ( len(self.graph_model_list) - idx ) diff --git a/test/solver/nodeclf_pyg.py b/test/solver/nodeclf_pyg.py new file mode 100644 index 0000000..73c8f46 --- /dev/null +++ b/test/solver/nodeclf_pyg.py @@ -0,0 +1,218 @@ +from autogl.datasets import build_dataset_from_name +from autogl.solver import AutoNodeClassifier +from torch_geometric.datasets import Planetoid +from autogl.module.model import BaseAutoModel, BaseEncoder, AutoClassifierDecoder, BaseDecoder, AutoHomogeneousEncoder +from autogl.module.train import NodeClassificationFullTrainer +import torch +import torch_geometric.nn as gnn +import torch.nn.functional as F + +def activate(act, x): + if hasattr(torch, act): return getattr(torch, act)(x) + return getattr(F, act)(x) + +class GCN(torch.nn.Module): + def __init__(self, num_features, num_classes): + super(GCN, self).__init__() + self.conv1 = gnn.GCNConv(num_features, 16) + self.conv2 = gnn.GCNConv(16, num_classes) + + def forward(self, data): + x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr + x = F.relu(self.conv1(x, edge_index, edge_weight)) + x = F.dropout(x, training=self.training) + x = self.conv2(x, edge_index, edge_weight) + return F.log_softmax(x, dim=1) + +class AutoGCN(BaseAutoModel): + def __init__(self, num_features=None, num_classes=None, device="cpu"): + super().__init__(device=device) + self.device = device + self.num_features = num_features + self.num_classes = num_classes + self.hyper_parameter_space = [] + self.hyper_parameter = {} + + def initialize(self): + self.model = GCN(self.num_features, self.num_classes).to(self.device) + + @property + def num_features(self): + return self.__num_features + + @num_features.setter + def num_features(self, num_features): + self.__num_features = num_features + + @property + def num_classes(self): + return self.__num_classes + + @num_classes.setter + def num_classes(self, num_classes): + self.__num_classes = num_classes + + def from_hyper_parameter(self, hp): + model = AutoGCN(self.num_features, self.num_classes, self.device) + model.initialize() + return model + +class GCNEncoder(BaseEncoder): + def __init__(self, num_features, last_dim, num_layers=2, hidden=(16,), dropout=0.6, act="relu"): + super().__init__() + self.core = torch.nn.ModuleList() + + # first layer + if num_layers == 1: + self.core.append(gnn.GCNConv(num_features, last_dim)) + else: + self.core.append(gnn.GCNConv(num_features, hidden[0])) + + # middle layer + for layer in range(num_layers - 2): + self.core.append(gnn.GCNConv(hidden[layer], hidden[layer + 1])) + + # last layer + if num_layers > 1: + self.core.append(gnn.GCNConv(hidden[-1], last_dim)) + + self.act = act + self.dropout = dropout + + def forward(self, data): + x, edge_index = data.x, data.edge_index + features = [] + for i, layer in enumerate(self.core): + if i > 0: + x = F.dropout(x, p=self.dropout, training=self.training) + x = activate(self.act, x) + x = layer(x, edge_index) + features.append(x) + return features + +class AutoGCNEncoder(AutoHomogeneousEncoder): + def __init__(self, num_features=None, last_dim="auto", device="auto"): + super().__init__(device) + + self.num_features = num_features + + self.hyper_parameter_space = [ + { + "parameterName": "num_layers", + "type": "DISCRETE", + "feasiblePoints": "2,3,4", + }, + { + "parameterName": "hidden", + "type": "NUMERICAL_LIST", + "numericalType": "INTEGER", + "length": 3, + "minValue": [8, 8, 8], + "maxValue": [128, 128, 128], + "scalingType": "LOG", + "cutPara": ("num_layers",), + "cutFunc": lambda x: x[0] - 1, + }, + { + "parameterName": "dropout", + "type": "DOUBLE", + "maxValue": 0.8, + "minValue": 0.2, + "scalingType": "LINEAR", + }, + { + "parameterName": "act", + "type": "CATEGORICAL", + "feasiblePoints": ["leaky_relu", "relu", "elu", "tanh"], + }, + ] + + self.hyper_parameter = { + "num_layers": 2, + "hidden": [16], + "dropout": 0.6, + "act": "tanh" + } + + if last_dim == "auto": + self.register_hyper_parameter_space({ + "parameterName": "last_dim", + "type": "INTEGER", + "scalingType": "LOG", + "minValue": 8, + "maxValue": 128 + }) + self.register_hyper_parameter("last_dim", 16) + else: + self.last_dim = last_dim + + def initialize(self): + self.model = GCNEncoder( + self.num_features, self.last_dim, self.num_layers, self.hidden, self.dropout, self.act + ) + self.model.to(self.device) + + @property + def num_features(self): + return self.__num_features + + @num_features.setter + def num_features(self, num_features): + self.__num_features = num_features + + def from_hyper_parameter(self, hp): + automodel = AutoGCNEncoder(self.num_features, self.last_dim, self.device) + automodel.hyper_parameter = hp + automodel.initialize() + return automodel + +class JKDecoder(BaseDecoder): + def __init__(self, num_classes, input_dims): + super().__init__() + self.out = torch.nn.Linear(sum(input_dims), num_classes) + + def forward(self, features, data): + return F.log_softmax(self.out(torch.cat(features, dim=1)), dim=1) + +class AutoJKDecoder(AutoClassifierDecoder): + def __init__(self, input_dim="auto", num_classes=None, device="auto"): + super().__init__(device) + self.num_classes = num_classes + + def initialize(self, encoder): + self.model = JKDecoder(self.num_classes, [*encoder.hidden, encoder.last_dim]) + self.model.to(self.device) + + def from_hyper_parameter_and_encoder(self, hp, encoder): + autodecoder = AutoJKDecoder(num_classes=self.num_classes) + autodecoder.initialize(encoder) + return autodecoder + + @property + def num_classes(self): + return self.__num_classes + + @num_classes.setter + def num_classes(self, num_classes): + self.__num_classes = num_classes + +cora = build_dataset_from_name("cora") +# cora = Planetoid("/home/guancy/data", "cora") + +solver = AutoNodeClassifier( + graph_models=((AutoGCNEncoder(), AutoJKDecoder()), ), + default_trainer=NodeClassificationFullTrainer( + decoder=None, + init=False, + max_epoch=200, + early_stopping_round=201, + lr=0.01, + weight_decay=0.0, + ), + hpo_module=None, + device="auto" +) + +solver.fit(cora, evaluation_method=["acc"]) +result = solver.predict(cora) +print((result == cora[0].nodes.data["y"][cora[0].nodes.data["test_mask"]].cpu().numpy()).astype('float').mean())