From 8c39d440950283eeedd4bbe2aaa366d4f4adebb2 Mon Sep 17 00:00:00 2001 From: Frozenmad Date: Sat, 18 Dec 2021 08:04:48 +0800 Subject: [PATCH] revise logic of parsing models --- autogl/solver/base.py | 38 ++++++++++++++++---- autogl/solver/classifier/graph_classifier.py | 13 ++++--- autogl/solver/classifier/link_predictor.py | 12 ++++--- autogl/solver/classifier/node_classifier.py | 12 ++++--- 4 files changed, 56 insertions(+), 19 deletions(-) diff --git a/autogl/solver/base.py b/autogl/solver/base.py index 4c88b21..a61afb4 100644 --- a/autogl/solver/base.py +++ b/autogl/solver/base.py @@ -11,7 +11,7 @@ import torch from ..module.feature import FEATURE_DICT from ..module.hpo import HPO_DICT -from ..module.model import MODEL_DICT +from ..module.model import EncoderUniversalRegistry, DecoderUniversalRegistry, ModelUniversalRegistry from ..module.nas.algorithm import NAS_ALGO_DICT from ..module.nas.estimator import NAS_ESTIMATOR_DICT from ..module.nas.space import NAS_SPACE_DICT @@ -22,11 +22,21 @@ from ..utils import get_logger LOGGER = get_logger("BaseSolver") -def _initialize_single_model(model_name, parameters=None): - if parameters: - return MODEL_DICT[model_name](**parameters) - return MODEL_DICT[model_name]() - +def _initialize_single_model(model): + encoder, decoder = None, None + if "encoder" in model: + # initialize encoder + name = model["encoder"].pop("name") + encoder = EncoderUniversalRegistry.get_encoder(name)(**model["encoder"]) + if "decoder" in model: + # initialize decoder + name = model["decoder"].pop("name") + decoder = DecoderUniversalRegistry.get_decoder(name)(**model["decoder"]) + if "name" in model: + # whole model + name = model.pop("name") + encoder = ModelUniversalRegistry.get_model(name)(**model) + return (encoder, decoder) def _parse_hp_space(spaces): if spaces is None: @@ -36,6 +46,22 @@ def _parse_hp_space(spaces): space["cutFunc"] = eval(space["cutFunc"]) return spaces +def _parse_model_hp(model): + assert isinstance(model, dict) + output = [] + if "encoder" in model and "decoder" in model: + output.append({ + "encoder": _parse_hp_space(model["encoder"].pop("hp_space", None)), + "decoder": _parse_hp_space(model["decoder"].pop("hp_space", None)), + }) + elif "encoder" in model: + output.append({ + "encoder": _parse_hp_space(model["encoder"].pop("hp_space", None)), + "decoder": None, + }) + else: + output.append(_parse_hp_space(model.pop("hp_space", None))) + return output class BaseSolver: r""" diff --git a/autogl/solver/classifier/graph_classifier.py b/autogl/solver/classifier/graph_classifier.py index 0f8347f..54e7cbc 100644 --- a/autogl/solver/classifier/graph_classifier.py +++ b/autogl/solver/classifier/graph_classifier.py @@ -12,9 +12,8 @@ import yaml from .base import BaseClassifier from ...module.feature import FEATURE_DICT -from ...module.model import BaseAutoModel, MODEL_DICT from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer -from ..base import _initialize_single_model, _parse_hp_space +from ..base import _initialize_single_model, _parse_hp_space, _parse_model_hp from ..utils import LeaderBoard, get_dataset_labels, set_seed, get_graph_from_dataset, get_graph_node_features, convert_dataset from ...datasets import utils from ..utils import get_logger @@ -656,12 +655,16 @@ class AutoGraphClassifier(BaseClassifier): if fe_list_ele != []: solver.set_feature_module(fe_list_ele) - models = path_or_dict.pop("models", [{"name": "gin"}, {"name": "topkpool"}]) + models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) + # models should be a list of model + # with each element in two cases + # * a dict describing a certain model + # * a dict containing {"encoder": encoder, "decoder": decoder} model_hp_space = [ - _parse_hp_space(model.pop("hp_space", None)) for model in models + _parse_model_hp(model) for model in models ] model_list = [ - _initialize_single_model(model.pop("name"), model) for model in models + _initialize_single_model(model) for model in models ] trainer = path_or_dict.pop("trainer", None) diff --git a/autogl/solver/classifier/link_predictor.py b/autogl/solver/classifier/link_predictor.py index c654a5a..680c455 100644 --- a/autogl/solver/classifier/link_predictor.py +++ b/autogl/solver/classifier/link_predictor.py @@ -12,7 +12,7 @@ import yaml from ...data import Data from .base import BaseClassifier -from ..base import _parse_hp_space, _initialize_single_model +from ..base import _parse_hp_space, _initialize_single_model, _parse_model_hp from ...module.feature import FEATURE_DICT from ...module.train import TRAINER_DICT, BaseLinkPredictionTrainer from ...module.train import get_feval @@ -703,12 +703,16 @@ class AutoLinkPredictor(BaseClassifier): if fe_list_ele != []: solver.set_feature_module(fe_list_ele) - models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}]) + models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) + # models should be a list of model + # with each element in two cases + # * a dict describing a certain model + # * a dict containing {"encoder": encoder, "decoder": decoder} model_hp_space = [ - _parse_hp_space(model.pop("hp_space", None)) for model in models + _parse_model_hp(model) for model in models ] model_list = [ - _initialize_single_model(model.pop("name"), model) for model in models + _initialize_single_model(model) for model in models ] trainer = path_or_dict.pop("trainer", None) diff --git a/autogl/solver/classifier/node_classifier.py b/autogl/solver/classifier/node_classifier.py index 881fa9c..e251e12 100644 --- a/autogl/solver/classifier/node_classifier.py +++ b/autogl/solver/classifier/node_classifier.py @@ -12,7 +12,7 @@ import numpy as np import yaml from .base import BaseClassifier -from ..base import _parse_hp_space, _initialize_single_model +from ..base import _parse_hp_space, _initialize_single_model, _parse_model_hp from ...module.feature import FEATURE_DICT from ...module.model import BaseEncoderMaintainer, BaseDecoderMaintainer, BaseAutoModel from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer @@ -732,12 +732,16 @@ class AutoNodeClassifier(BaseClassifier): if fe_list_ele != []: solver.set_feature_module(fe_list_ele) - models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}]) + models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) + # models should be a list of model + # with each element in two cases + # * a dict describing a certain model + # * a dict containing {"encoder": encoder, "decoder": decoder} model_hp_space = [ - _parse_hp_space(model.pop("hp_space", None)) for model in models + _parse_model_hp(model) for model in models ] model_list = [ - _initialize_single_model(model.pop("name"), model) for model in models + _initialize_single_model(model) for model in models ] trainer = path_or_dict.pop("trainer", None)