| @@ -21,7 +21,7 @@ from ..utils import get_logger | |||
| from ...backend import DependentBackend | |||
| LOGGER = get_logger("GraphClassifier") | |||
| __backend = DependentBackend.get_backend_name() | |||
| BACKEND = DependentBackend.get_backend_name() | |||
| class AutoGraphClassifier(BaseClassifier): | |||
| """ | |||
| @@ -277,6 +277,8 @@ class AutoGraphClassifier(BaseClassifier): | |||
| set_seed(seed) | |||
| num_classes = dataset.num_classes if BACKEND == 'pyg' else dataset.gclasses | |||
| if time_limit < 0: | |||
| time_limit = 3600 * 24 | |||
| time_begin = time.time() | |||
| @@ -286,8 +288,7 @@ class AutoGraphClassifier(BaseClassifier): | |||
| if hasattr(dataset, "metric"): | |||
| evaluation_method = [dataset.metric] | |||
| else: | |||
| num_of_label = dataset.num_classes | |||
| if num_of_label == 2: | |||
| if num_classes == 2: | |||
| evaluation_method = ["auc"] | |||
| else: | |||
| evaluation_method = ["acc"] | |||
| @@ -339,7 +340,6 @@ class AutoGraphClassifier(BaseClassifier): | |||
| " node features." | |||
| ) | |||
| num_features = feat.size(-1) | |||
| num_classes = dataset.num_classes if __backend == 'pyg' else dataset.num_labels | |||
| # initialize graph networks | |||
| self._init_graph_module( | |||
| @@ -350,9 +350,9 @@ class AutoGraphClassifier(BaseClassifier): | |||
| feval=evaluator_list, | |||
| device=self.runtime_device, | |||
| loss="cross_entropy" if not hasattr(dataset, "loss") else dataset.loss, | |||
| num_graph_features=0 | |||
| num_graph_features=(0 | |||
| if not hasattr(dataset.data, "gf") | |||
| else dataset.data.gf.size(1), | |||
| else dataset.data.gf.size(1)) if BACKEND == 'pyg' else 0, | |||
| ) | |||
| # currently disabled | |||
| @@ -22,7 +22,7 @@ from ..utils import get_logger | |||
| from ...backend import DependentBackend | |||
| LOGGER = get_logger("LinkPredictor") | |||
| __backend = DependentBackend.get_backend_name() | |||
| BACKEND = DependentBackend.get_backend_name() | |||
| class AutoLinkPredictor(BaseClassifier): | |||
| """ | |||
| @@ -283,7 +283,7 @@ class AutoLinkPredictor(BaseClassifier): | |||
| if train_split is not None and val_split is not None: | |||
| utils.split_edges(dataset, train_split, val_split) | |||
| else: | |||
| if __backend == 'pyg': | |||
| if BACKEND == 'pyg': | |||
| assert all( | |||
| [ | |||
| hasattr(graph_data, f"{name}") | |||
| @@ -300,7 +300,7 @@ class AutoLinkPredictor(BaseClassifier): | |||
| "The dataset has no default train/val split! Please manually pass " | |||
| "train and val ratio." | |||
| ) | |||
| elif __backend == 'dgl': | |||
| elif BACKEND == 'dgl': | |||
| assert hasattr(graph_data, 'edata') and "train_mask" in graph_data.edata and "val_mask" in graph_data.edata, ( | |||
| "The dataset has no default train/val split! Please manually pass " | |||
| "train and val ratio." | |||
| @@ -374,7 +374,7 @@ class AutoLinkPredictor(BaseClassifier): | |||
| # fit the ensemble model | |||
| if self.ensemble_module is not None: | |||
| if __backend == 'pyg': | |||
| if BACKEND == 'pyg': | |||
| pos_edge_index, neg_edge_index = ( | |||
| self.dataset[0].val_pos_edge_index, | |||
| self.dataset[0].val_neg_edge_index, | |||
| @@ -382,7 +382,7 @@ class AutoLinkPredictor(BaseClassifier): | |||
| E = pos_edge_index.size(1) + neg_edge_index.size(1) | |||
| link_labels = torch.zeros(E, dtype=torch.float) | |||
| link_labels[: pos_edge_index.size(1)] = 1.0 | |||
| elif __backend == 'dgl': | |||
| elif BACKEND == 'dgl': | |||
| val_mask = self.dataset[0].edata["val_mask"] | |||
| val_index = torch.nonzero(val_mask, as_tuple=False).squeeze() | |||
| link_labels = self.dataset[0].edata['etype'][val_index] | |||
| @@ -16,7 +16,7 @@ from ..utils import get_logger | |||
| LOGGER = get_logger("LeaderBoard") | |||
| __backend = DependentBackend.get_backend_name() | |||
| BACKEND = DependentBackend.get_backend_name() | |||
| class LeaderBoard: | |||
| """ | |||
| @@ -179,36 +179,36 @@ class LeaderBoard: | |||
| ) | |||
| def get_graph_from_dataset(dataset, graph_id=0): | |||
| if __backend == 'pyg': return dataset[graph_id] | |||
| if BACKEND == 'pyg': return dataset[graph_id] | |||
| return dataset.graph[graph_id] | |||
| def get_graph_node_number(graph): | |||
| if __backend == 'pyg': | |||
| if BACKEND == 'pyg': | |||
| size = graph.x.shape[0] | |||
| else: | |||
| size = graph.num_nodes() | |||
| return size | |||
| def get_graph_node_features(graph): | |||
| if __backend == 'pyg' and hasattr(graph, 'x'): | |||
| if BACKEND == 'pyg' and hasattr(graph, 'x'): | |||
| return graph.x | |||
| elif __backend == 'dgl' and 'feat' in graph.ndata: | |||
| elif BACKEND == 'dgl' and 'feat' in graph.ndata: | |||
| return graph.ndata['feat'] | |||
| return None | |||
| def get_graph_masks(graph, mask='train'): | |||
| if __backend == 'pyg' and hasattr(graph, f'{mask}_mask'): | |||
| if BACKEND == 'pyg' and hasattr(graph, f'{mask}_mask'): | |||
| return getattr(graph, f'{mask}_mask') | |||
| if __backend == 'dgl' and f'{mask}_mask' in graph.ndata: | |||
| if BACKEND == 'dgl' and f'{mask}_mask' in graph.ndata: | |||
| return graph.ndata[f'{mask}_mask'] | |||
| return None | |||
| def get_graph_labels(graph): | |||
| if __backend == 'pyg': return graph.y | |||
| if BACKEND == 'pyg': return graph.y | |||
| return graph.ndata['label'] | |||
| def get_dataset_labels(dataset): | |||
| if __backend == 'pyg': | |||
| if BACKEND == 'pyg': | |||
| return dataset.data.y | |||
| else: | |||
| return torch.LongTensor([d[1] for d in dataset]) | |||