| @@ -11,7 +11,7 @@ import torch | |||
| from ..module.feature import FEATURE_DICT | |||
| from ..module.hpo import HPO_DICT | |||
| from ..module.model import MODEL_DICT | |||
| from ..module.model import EncoderUniversalRegistry, DecoderUniversalRegistry, ModelUniversalRegistry | |||
| from ..module.nas.algorithm import NAS_ALGO_DICT | |||
| from ..module.nas.estimator import NAS_ESTIMATOR_DICT | |||
| from ..module.nas.space import NAS_SPACE_DICT | |||
| @@ -22,11 +22,21 @@ from ..utils import get_logger | |||
| LOGGER = get_logger("BaseSolver") | |||
| def _initialize_single_model(model_name, parameters=None): | |||
| if parameters: | |||
| return MODEL_DICT[model_name](**parameters) | |||
| return MODEL_DICT[model_name]() | |||
| def _initialize_single_model(model): | |||
| encoder, decoder = None, None | |||
| if "encoder" in model: | |||
| # initialize encoder | |||
| name = model["encoder"].pop("name") | |||
| encoder = EncoderUniversalRegistry.get_encoder(name)(**model["encoder"]) | |||
| if "decoder" in model: | |||
| # initialize decoder | |||
| name = model["decoder"].pop("name") | |||
| decoder = DecoderUniversalRegistry.get_decoder(name)(**model["decoder"]) | |||
| if "name" in model: | |||
| # whole model | |||
| name = model.pop("name") | |||
| encoder = ModelUniversalRegistry.get_model(name)(**model) | |||
| return (encoder, decoder) | |||
| def _parse_hp_space(spaces): | |||
| if spaces is None: | |||
| @@ -36,6 +46,22 @@ def _parse_hp_space(spaces): | |||
| space["cutFunc"] = eval(space["cutFunc"]) | |||
| return spaces | |||
| def _parse_model_hp(model): | |||
| assert isinstance(model, dict) | |||
| output = [] | |||
| if "encoder" in model and "decoder" in model: | |||
| output.append({ | |||
| "encoder": _parse_hp_space(model["encoder"].pop("hp_space", None)), | |||
| "decoder": _parse_hp_space(model["decoder"].pop("hp_space", None)), | |||
| }) | |||
| elif "encoder" in model: | |||
| output.append({ | |||
| "encoder": _parse_hp_space(model["encoder"].pop("hp_space", None)), | |||
| "decoder": None, | |||
| }) | |||
| else: | |||
| output.append(_parse_hp_space(model.pop("hp_space", None))) | |||
| return output | |||
| class BaseSolver: | |||
| r""" | |||
| @@ -12,9 +12,8 @@ import yaml | |||
| from .base import BaseClassifier | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.model import BaseAutoModel, MODEL_DICT | |||
| from ...module.train import TRAINER_DICT, get_feval, BaseGraphClassificationTrainer | |||
| from ..base import _initialize_single_model, _parse_hp_space | |||
| from ..base import _initialize_single_model, _parse_hp_space, _parse_model_hp | |||
| from ..utils import LeaderBoard, get_dataset_labels, set_seed, get_graph_from_dataset, get_graph_node_features, convert_dataset | |||
| from ...datasets import utils | |||
| from ..utils import get_logger | |||
| @@ -656,12 +655,16 @@ class AutoGraphClassifier(BaseClassifier): | |||
| if fe_list_ele != []: | |||
| solver.set_feature_module(fe_list_ele) | |||
| models = path_or_dict.pop("models", [{"name": "gin"}, {"name": "topkpool"}]) | |||
| models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) | |||
| # models should be a list of model | |||
| # with each element in two cases | |||
| # * a dict describing a certain model | |||
| # * a dict containing {"encoder": encoder, "decoder": decoder} | |||
| model_hp_space = [ | |||
| _parse_hp_space(model.pop("hp_space", None)) for model in models | |||
| _parse_model_hp(model) for model in models | |||
| ] | |||
| model_list = [ | |||
| _initialize_single_model(model.pop("name"), model) for model in models | |||
| _initialize_single_model(model) for model in models | |||
| ] | |||
| trainer = path_or_dict.pop("trainer", None) | |||
| @@ -12,7 +12,7 @@ import yaml | |||
| from ...data import Data | |||
| from .base import BaseClassifier | |||
| from ..base import _parse_hp_space, _initialize_single_model | |||
| from ..base import _parse_hp_space, _initialize_single_model, _parse_model_hp | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.train import TRAINER_DICT, BaseLinkPredictionTrainer | |||
| from ...module.train import get_feval | |||
| @@ -703,12 +703,16 @@ class AutoLinkPredictor(BaseClassifier): | |||
| if fe_list_ele != []: | |||
| solver.set_feature_module(fe_list_ele) | |||
| models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}]) | |||
| models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) | |||
| # models should be a list of model | |||
| # with each element in two cases | |||
| # * a dict describing a certain model | |||
| # * a dict containing {"encoder": encoder, "decoder": decoder} | |||
| model_hp_space = [ | |||
| _parse_hp_space(model.pop("hp_space", None)) for model in models | |||
| _parse_model_hp(model) for model in models | |||
| ] | |||
| model_list = [ | |||
| _initialize_single_model(model.pop("name"), model) for model in models | |||
| _initialize_single_model(model) for model in models | |||
| ] | |||
| trainer = path_or_dict.pop("trainer", None) | |||
| @@ -12,7 +12,7 @@ import numpy as np | |||
| import yaml | |||
| from .base import BaseClassifier | |||
| from ..base import _parse_hp_space, _initialize_single_model | |||
| from ..base import _parse_hp_space, _initialize_single_model, _parse_model_hp | |||
| from ...module.feature import FEATURE_DICT | |||
| from ...module.model import BaseEncoderMaintainer, BaseDecoderMaintainer, BaseAutoModel | |||
| from ...module.train import TRAINER_DICT, BaseNodeClassificationTrainer | |||
| @@ -732,12 +732,16 @@ class AutoNodeClassifier(BaseClassifier): | |||
| if fe_list_ele != []: | |||
| solver.set_feature_module(fe_list_ele) | |||
| models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}]) | |||
| models = path_or_dict.pop("models", [{"name": "gcn"}, {"name": "gat"}, {"name": "sage"}, {"name": "gin"}]) | |||
| # models should be a list of model | |||
| # with each element in two cases | |||
| # * a dict describing a certain model | |||
| # * a dict containing {"encoder": encoder, "decoder": decoder} | |||
| model_hp_space = [ | |||
| _parse_hp_space(model.pop("hp_space", None)) for model in models | |||
| _parse_model_hp(model) for model in models | |||
| ] | |||
| model_list = [ | |||
| _initialize_single_model(model.pop("name"), model) for model in models | |||
| _initialize_single_model(model) for model in models | |||
| ] | |||
| trainer = path_or_dict.pop("trainer", None) | |||