Browse Source

revise logics of graph solver

tags/v0.3.1
Frozenmad 4 years ago
parent
commit
646cb6a13d
3 changed files with 20 additions and 20 deletions
  1. +6
    -6
      autogl/solver/classifier/graph_classifier.py
  2. +5
    -5
      autogl/solver/classifier/link_predictor.py
  3. +9
    -9
      autogl/solver/utils.py

+ 6
- 6
autogl/solver/classifier/graph_classifier.py View File

@@ -21,7 +21,7 @@ from ..utils import get_logger
from ...backend import DependentBackend

LOGGER = get_logger("GraphClassifier")
__backend = DependentBackend.get_backend_name()
BACKEND = DependentBackend.get_backend_name()

class AutoGraphClassifier(BaseClassifier):
"""
@@ -277,6 +277,8 @@ class AutoGraphClassifier(BaseClassifier):

set_seed(seed)

num_classes = dataset.num_classes if BACKEND == 'pyg' else dataset.gclasses

if time_limit < 0:
time_limit = 3600 * 24
time_begin = time.time()
@@ -286,8 +288,7 @@ class AutoGraphClassifier(BaseClassifier):
if hasattr(dataset, "metric"):
evaluation_method = [dataset.metric]
else:
num_of_label = dataset.num_classes
if num_of_label == 2:
if num_classes == 2:
evaluation_method = ["auc"]
else:
evaluation_method = ["acc"]
@@ -339,7 +340,6 @@ class AutoGraphClassifier(BaseClassifier):
" node features."
)
num_features = feat.size(-1)
num_classes = dataset.num_classes if __backend == 'pyg' else dataset.num_labels

# initialize graph networks
self._init_graph_module(
@@ -350,9 +350,9 @@ class AutoGraphClassifier(BaseClassifier):
feval=evaluator_list,
device=self.runtime_device,
loss="cross_entropy" if not hasattr(dataset, "loss") else dataset.loss,
num_graph_features=0
num_graph_features=(0
if not hasattr(dataset.data, "gf")
else dataset.data.gf.size(1),
else dataset.data.gf.size(1)) if BACKEND == 'pyg' else 0,
)

# currently disabled


+ 5
- 5
autogl/solver/classifier/link_predictor.py View File

@@ -22,7 +22,7 @@ from ..utils import get_logger
from ...backend import DependentBackend

LOGGER = get_logger("LinkPredictor")
__backend = DependentBackend.get_backend_name()
BACKEND = DependentBackend.get_backend_name()

class AutoLinkPredictor(BaseClassifier):
"""
@@ -283,7 +283,7 @@ class AutoLinkPredictor(BaseClassifier):
if train_split is not None and val_split is not None:
utils.split_edges(dataset, train_split, val_split)
else:
if __backend == 'pyg':
if BACKEND == 'pyg':
assert all(
[
hasattr(graph_data, f"{name}")
@@ -300,7 +300,7 @@ class AutoLinkPredictor(BaseClassifier):
"The dataset has no default train/val split! Please manually pass "
"train and val ratio."
)
elif __backend == 'dgl':
elif BACKEND == 'dgl':
assert hasattr(graph_data, 'edata') and "train_mask" in graph_data.edata and "val_mask" in graph_data.edata, (
"The dataset has no default train/val split! Please manually pass "
"train and val ratio."
@@ -374,7 +374,7 @@ class AutoLinkPredictor(BaseClassifier):

# fit the ensemble model
if self.ensemble_module is not None:
if __backend == 'pyg':
if BACKEND == 'pyg':
pos_edge_index, neg_edge_index = (
self.dataset[0].val_pos_edge_index,
self.dataset[0].val_neg_edge_index,
@@ -382,7 +382,7 @@ class AutoLinkPredictor(BaseClassifier):
E = pos_edge_index.size(1) + neg_edge_index.size(1)
link_labels = torch.zeros(E, dtype=torch.float)
link_labels[: pos_edge_index.size(1)] = 1.0
elif __backend == 'dgl':
elif BACKEND == 'dgl':
val_mask = self.dataset[0].edata["val_mask"]
val_index = torch.nonzero(val_mask, as_tuple=False).squeeze()
link_labels = self.dataset[0].edata['etype'][val_index]


+ 9
- 9
autogl/solver/utils.py View File

@@ -16,7 +16,7 @@ from ..utils import get_logger

LOGGER = get_logger("LeaderBoard")

__backend = DependentBackend.get_backend_name()
BACKEND = DependentBackend.get_backend_name()

class LeaderBoard:
"""
@@ -179,36 +179,36 @@ class LeaderBoard:
)

def get_graph_from_dataset(dataset, graph_id=0):
if __backend == 'pyg': return dataset[graph_id]
if BACKEND == 'pyg': return dataset[graph_id]
return dataset.graph[graph_id]

def get_graph_node_number(graph):
if __backend == 'pyg':
if BACKEND == 'pyg':
size = graph.x.shape[0]
else:
size = graph.num_nodes()
return size

def get_graph_node_features(graph):
if __backend == 'pyg' and hasattr(graph, 'x'):
if BACKEND == 'pyg' and hasattr(graph, 'x'):
return graph.x
elif __backend == 'dgl' and 'feat' in graph.ndata:
elif BACKEND == 'dgl' and 'feat' in graph.ndata:
return graph.ndata['feat']
return None

def get_graph_masks(graph, mask='train'):
if __backend == 'pyg' and hasattr(graph, f'{mask}_mask'):
if BACKEND == 'pyg' and hasattr(graph, f'{mask}_mask'):
return getattr(graph, f'{mask}_mask')
if __backend == 'dgl' and f'{mask}_mask' in graph.ndata:
if BACKEND == 'dgl' and f'{mask}_mask' in graph.ndata:
return graph.ndata[f'{mask}_mask']
return None

def get_graph_labels(graph):
if __backend == 'pyg': return graph.y
if BACKEND == 'pyg': return graph.y
return graph.ndata['label']

def get_dataset_labels(dataset):
if __backend == 'pyg':
if BACKEND == 'pyg':
return dataset.data.y
else:
return torch.LongTensor([d[1] for d in dataset])


Loading…
Cancel
Save