From 646cb6a13da2c36d0d3f17c442fc89e5e1dfbea8 Mon Sep 17 00:00:00 2001 From: Frozenmad Date: Tue, 19 Oct 2021 02:19:02 +0000 Subject: [PATCH] revise logics of graph solver --- autogl/solver/classifier/graph_classifier.py | 12 ++++++------ autogl/solver/classifier/link_predictor.py | 10 +++++----- autogl/solver/utils.py | 18 +++++++++--------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/autogl/solver/classifier/graph_classifier.py b/autogl/solver/classifier/graph_classifier.py index 8d17857..b88b738 100644 --- a/autogl/solver/classifier/graph_classifier.py +++ b/autogl/solver/classifier/graph_classifier.py @@ -21,7 +21,7 @@ from ..utils import get_logger from ...backend import DependentBackend LOGGER = get_logger("GraphClassifier") -__backend = DependentBackend.get_backend_name() +BACKEND = DependentBackend.get_backend_name() class AutoGraphClassifier(BaseClassifier): """ @@ -277,6 +277,8 @@ class AutoGraphClassifier(BaseClassifier): set_seed(seed) + num_classes = dataset.num_classes if BACKEND == 'pyg' else dataset.gclasses + if time_limit < 0: time_limit = 3600 * 24 time_begin = time.time() @@ -286,8 +288,7 @@ class AutoGraphClassifier(BaseClassifier): if hasattr(dataset, "metric"): evaluation_method = [dataset.metric] else: - num_of_label = dataset.num_classes - if num_of_label == 2: + if num_classes == 2: evaluation_method = ["auc"] else: evaluation_method = ["acc"] @@ -339,7 +340,6 @@ class AutoGraphClassifier(BaseClassifier): " node features." ) num_features = feat.size(-1) - num_classes = dataset.num_classes if __backend == 'pyg' else dataset.num_labels # initialize graph networks self._init_graph_module( @@ -350,9 +350,9 @@ class AutoGraphClassifier(BaseClassifier): feval=evaluator_list, device=self.runtime_device, loss="cross_entropy" if not hasattr(dataset, "loss") else dataset.loss, - num_graph_features=0 + num_graph_features=(0 if not hasattr(dataset.data, "gf") - else dataset.data.gf.size(1), + else dataset.data.gf.size(1)) if BACKEND == 'pyg' else 0, ) # currently disabled diff --git a/autogl/solver/classifier/link_predictor.py b/autogl/solver/classifier/link_predictor.py index 4a4a2f1..4b18680 100644 --- a/autogl/solver/classifier/link_predictor.py +++ b/autogl/solver/classifier/link_predictor.py @@ -22,7 +22,7 @@ from ..utils import get_logger from ...backend import DependentBackend LOGGER = get_logger("LinkPredictor") -__backend = DependentBackend.get_backend_name() +BACKEND = DependentBackend.get_backend_name() class AutoLinkPredictor(BaseClassifier): """ @@ -283,7 +283,7 @@ class AutoLinkPredictor(BaseClassifier): if train_split is not None and val_split is not None: utils.split_edges(dataset, train_split, val_split) else: - if __backend == 'pyg': + if BACKEND == 'pyg': assert all( [ hasattr(graph_data, f"{name}") @@ -300,7 +300,7 @@ class AutoLinkPredictor(BaseClassifier): "The dataset has no default train/val split! Please manually pass " "train and val ratio." ) - elif __backend == 'dgl': + elif BACKEND == 'dgl': assert hasattr(graph_data, 'edata') and "train_mask" in graph_data.edata and "val_mask" in graph_data.edata, ( "The dataset has no default train/val split! Please manually pass " "train and val ratio." @@ -374,7 +374,7 @@ class AutoLinkPredictor(BaseClassifier): # fit the ensemble model if self.ensemble_module is not None: - if __backend == 'pyg': + if BACKEND == 'pyg': pos_edge_index, neg_edge_index = ( self.dataset[0].val_pos_edge_index, self.dataset[0].val_neg_edge_index, @@ -382,7 +382,7 @@ class AutoLinkPredictor(BaseClassifier): E = pos_edge_index.size(1) + neg_edge_index.size(1) link_labels = torch.zeros(E, dtype=torch.float) link_labels[: pos_edge_index.size(1)] = 1.0 - elif __backend == 'dgl': + elif BACKEND == 'dgl': val_mask = self.dataset[0].edata["val_mask"] val_index = torch.nonzero(val_mask, as_tuple=False).squeeze() link_labels = self.dataset[0].edata['etype'][val_index] diff --git a/autogl/solver/utils.py b/autogl/solver/utils.py index 51a1555..dd019a5 100644 --- a/autogl/solver/utils.py +++ b/autogl/solver/utils.py @@ -16,7 +16,7 @@ from ..utils import get_logger LOGGER = get_logger("LeaderBoard") -__backend = DependentBackend.get_backend_name() +BACKEND = DependentBackend.get_backend_name() class LeaderBoard: """ @@ -179,36 +179,36 @@ class LeaderBoard: ) def get_graph_from_dataset(dataset, graph_id=0): - if __backend == 'pyg': return dataset[graph_id] + if BACKEND == 'pyg': return dataset[graph_id] return dataset.graph[graph_id] def get_graph_node_number(graph): - if __backend == 'pyg': + if BACKEND == 'pyg': size = graph.x.shape[0] else: size = graph.num_nodes() return size def get_graph_node_features(graph): - if __backend == 'pyg' and hasattr(graph, 'x'): + if BACKEND == 'pyg' and hasattr(graph, 'x'): return graph.x - elif __backend == 'dgl' and 'feat' in graph.ndata: + elif BACKEND == 'dgl' and 'feat' in graph.ndata: return graph.ndata['feat'] return None def get_graph_masks(graph, mask='train'): - if __backend == 'pyg' and hasattr(graph, f'{mask}_mask'): + if BACKEND == 'pyg' and hasattr(graph, f'{mask}_mask'): return getattr(graph, f'{mask}_mask') - if __backend == 'dgl' and f'{mask}_mask' in graph.ndata: + if BACKEND == 'dgl' and f'{mask}_mask' in graph.ndata: return graph.ndata[f'{mask}_mask'] return None def get_graph_labels(graph): - if __backend == 'pyg': return graph.y + if BACKEND == 'pyg': return graph.y return graph.ndata['label'] def get_dataset_labels(dataset): - if __backend == 'pyg': + if BACKEND == 'pyg': return dataset.data.y else: return torch.LongTensor([d[1] for d in dataset])