Browse Source

Merge branch 'fix_examples' into revise-gin-dgl-encoder-decoder

tags/v0.3.1
Frozenmad 4 years ago
parent
commit
0c8b39fc7d
10 changed files with 54 additions and 24 deletions
  1. +2
    -2
      autogl/module/model/dgl/_model_registry.py
  2. +2
    -2
      autogl/module/model/pyg/_model_registry.py
  3. +13
    -4
      autogl/module/train/__init__.py
  4. +3
    -2
      autogl/module/train/link_prediction_full.py
  5. +11
    -10
      autogl/module/train/node_classification_trainer/__init__.py
  6. +3
    -1
      autogl/solver/classifier/graph_classifier.py
  7. +2
    -0
      autogl/solver/classifier/hetero/node_classifier.py
  8. +2
    -0
      autogl/solver/classifier/link_predictor.py
  9. +2
    -0
      autogl/solver/classifier/node_classifier.py
  10. +14
    -3
      examples/link_prediction.py

+ 2
- 2
autogl/module/model/dgl/_model_registry.py View File

@@ -22,7 +22,7 @@ class ModelUniversalRegistry:
@classmethod
def get_model(cls, name: str) -> _typing.Type[BaseAutoModel]:
if type(name) != str:
raise TypeError
raise TypeError(f"Expect model type str, but get {type(name)}.")
if name not in MODEL_DICT:
raise KeyError
raise KeyError(f"Do not support {name} model in pyg backend")
return MODEL_DICT.get(name)

+ 2
- 2
autogl/module/model/pyg/_model_registry.py View File

@@ -22,7 +22,7 @@ class ModelUniversalRegistry:
@classmethod
def get_model(cls, name: str) -> _typing.Type[BaseAutoModel]:
if type(name) != str:
raise TypeError
raise TypeError(f"Expect model type str, but get {type(name)}.")
if name not in MODEL_DICT:
raise KeyError
raise KeyError(f"Do not support {name} model in pyg backend")
return MODEL_DICT.get(name)

+ 13
- 4
autogl/module/train/__init__.py View File

@@ -29,7 +29,12 @@ from .graph_classification_full import GraphClassificationFullTrainer
from .node_classification_full import NodeClassificationFullTrainer
from .link_prediction_full import LinkPredictionTrainer
from .node_classification_het import NodeClassificationHetTrainer
from .node_classification_trainer import *
if DependentBackend.is_pyg():
from .node_classification_trainer import (
NodeClassificationGraphSAINTTrainer,
NodeClassificationLayerDependentImportanceSamplingTrainer,
NodeClassificationNeighborSamplingTrainer
)
from .evaluation import get_feval, Acc, Auc, Logloss, Mrr, MicroF1

__all__ = [
@@ -42,9 +47,6 @@ __all__ = [
"GraphClassificationFullTrainer",
"NodeClassificationFullTrainer",
"NodeClassificationHetTrainer",
"NodeClassificationGraphSAINTTrainer",
"NodeClassificationLayerDependentImportanceSamplingTrainer",
"NodeClassificationNeighborSamplingTrainer",
"LinkPredictionTrainer",
"Acc",
"Auc",
@@ -53,3 +55,10 @@ __all__ = [
"MicroF1",
"get_feval",
]

if DependentBackend.is_pyg():
__all__.extend([
"NodeClassificationGraphSAINTTrainer",
"NodeClassificationLayerDependentImportanceSamplingTrainer",
"NodeClassificationNeighborSamplingTrainer",
])

+ 3
- 2
autogl/module/train/link_prediction_full.py View File

@@ -7,8 +7,6 @@ from .evaluation import Auc, EVALUATE_DICT
from .base import EarlyStopping, BaseLinkPredictionTrainer
from typing import Union, Tuple
from copy import deepcopy
from torch_geometric.utils import negative_sampling
# from ...datasets.utils import negative_sampling
from ...utils import get_logger

from ...backend import DependentBackend
@@ -198,6 +196,9 @@ class LinkPredictionTrainer(BaseLinkPredictionTrainer):
try:
neg_edge_index = data.train_neg_edge_index
except:
from torch_geometric.utils import negative_sampling
# from ...datasets.utils import negative_sampling

neg_edge_index = negative_sampling(
edge_index=data.train_pos_edge_index,
num_nodes=data.num_nodes,


+ 11
- 10
autogl/module/train/node_classification_trainer/__init__.py View File

@@ -1,14 +1,15 @@
from ....backend import DependentBackend

if DependentBackend.is_pyg():
from .node_classification_sampled_trainer import *
from .node_classification_sampled_trainer import (
NodeClassificationGraphSAINTTrainer,
NodeClassificationLayerDependentImportanceSamplingTrainer,
NodeClassificationNeighborSamplingTrainer
)
__all__ = [
"NodeClassificationGraphSAINTTrainer",
"NodeClassificationLayerDependentImportanceSamplingTrainer",
"NodeClassificationNeighborSamplingTrainer"
]
else:
NodeClassificationGraphSAINTTrainer = None
NodeClassificationLayerDependentImportanceSamplingTrainer = None
NodeClassificationNeighborSamplingTrainer = None
pass





__all__ = []

+ 3
- 1
autogl/solver/classifier/graph_classifier.py View File

@@ -124,6 +124,8 @@ class AutoGraphClassifier(BaseClassifier):
else self._default_trainer[i]
)
if isinstance(trainer, str):
if trainer not in TRAINER_DICT:
raise KeyError(f"Does not support trainer {trainer}")
trainer = TRAINER_DICT[trainer]()
if isinstance(model, (tuple, list)):
trainer.encoder = model[0]
@@ -219,7 +221,7 @@ class AutoGraphClassifier(BaseClassifier):

set_seed(seed)

num_classes = max(get_dataset_labels(dataset)) + 1
num_classes = get_dataset_labels(dataset).max().item() + 1

if time_limit < 0:
time_limit = 3600 * 24


+ 2
- 0
autogl/solver/classifier/hetero/node_classifier.py View File

@@ -115,6 +115,8 @@ class AutoHeteroNodeClassifier(BaseClassifier):
else self._default_trainer[i]
)
if isinstance(trainer, str):
if trainer not in TRAINER_DICT:
raise KeyError(f"Does not support trainer {trainer}")
trainer = TRAINER_DICT[trainer]()
if isinstance(model, (tuple, list)):
trainer.encoder = model[0]


+ 2
- 0
autogl/solver/classifier/link_predictor.py View File

@@ -134,6 +134,8 @@ class AutoLinkPredictor(BaseClassifier):
else self._default_trainer[i]
)
if isinstance(trainer, str):
if trainer not in TRAINER_DICT:
raise KeyError(f"Does not support trainer {trainer}")
trainer = TRAINER_DICT[trainer]()
if isinstance(model, (tuple, list)):
trainer.encoder = model[0]


+ 2
- 0
autogl/solver/classifier/node_classifier.py View File

@@ -131,6 +131,8 @@ class AutoNodeClassifier(BaseClassifier):
else self._default_trainer[i]
)
if isinstance(trainer, str):
if trainer not in TRAINER_DICT:
raise KeyError(f"Does not support trainer {trainer}")
trainer = TRAINER_DICT[trainer]()
if isinstance(model, (tuple, list)):
trainer.encoder = model[0]


+ 14
- 3
examples/link_prediction.py View File

@@ -1,6 +1,8 @@
from autogl.datasets import build_dataset_from_name
from autogl.solver.classifier.link_predictor import AutoLinkPredictor
from autogl.module.train.evaluation import Auc
from autogl.datasets.utils import split_edges
from autogl.backend import DependentBackend
import yaml
import random
import torch
@@ -57,6 +59,17 @@ if __name__ == "__main__":

dataset = build_dataset_from_name(args.dataset)

# split the edges for dataset
dataset = split_edges(dataset, 0.8, 0.05)

# add self-loop
if DependentBackend.is_dgl():
import dgl
# add self loop to 0
data = list(dataset[0])
data[0] = dgl.add_self_loop(data[0])
dataset = [data]

configs = yaml.load(open(args.configs, "r").read(), Loader=yaml.FullLoader)
configs["hpo"]["name"] = args.hpo
configs["hpo"]["max_evals"] = args.max_eval
@@ -67,9 +80,7 @@ if __name__ == "__main__":
dataset,
time_limit=3600,
evaluation_method=[Auc],
seed=seed,
train_split=0.85,
val_split=0.05,
seed=seed
)
autoClassifier.get_leaderboard().show()



Loading…
Cancel
Save