@@ -40,7 +40,10 @@ class _OGBNDatasetUtil(_OGBDatasetUtil):
         edge_feat = torch.tensor(edge_feat)
         edge_index = SparseTensor(row=torch.tensor(edge_index[0]), col=torch.tensor(edge_index[1]), value=edge_feat, sparse_sizes=(num_nodes, num_nodes))
         _, _, value = edge_index.coo()
-        ogbn_data['edge_feat'] = value.cpu().detach().numpy()
+        if value is not None:
+            ogbn_data['edge_feat'] = value.cpu().detach().numpy()
+        else:
+            ogbn_data['edge_feat'] = edge_feat
         edge_index = edge_index.to_symmetric()
         row, col, _ = edge_index.coo()
         edge_index = np.array([row.cpu().detach().numpy(), col.cpu().detach().numpy()])
@@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import (
     ReduceLROnPlateau,
 )
-from dig.sslgraph.method.contrastive.objectives import NCE_loss, JSE_loss
+from .losses import NTXent_loss
 from .utils import get_view_by_name
 from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
@@ -49,7 +49,7 @@ class BaseContrastiveTrainer(BaseTrainer):
         feval: _typing.Union[
             _typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
         ] = (Acc,),
-        loss: Union[str, Callable] = "NCE",
+        loss: Union[str, Callable] = "NT_Xent",
         f_loss: Union[str, Callable] = "nll_loss",
         views_fn: _typing.Union[
             _typing.Sequence[_typing.Callable], None
@@ -59,7 +59,6 @@ class BaseContrastiveTrainer(BaseTrainer):
         node_level: bool = False,
         z_dim: _typing.Union[int, None] = None,
         z_node_dim: _typing.Union[int, None] = None,
-        neg_by_crpt: bool = False,
         tau: int = 0.5,
         p_optim: Union[torch.optim.Optimizer, str] = "Adam",
         p_lr: float = 0.0001,
@@ -110,16 +109,13 @@ class BaseContrastiveTrainer(BaseTrainer):
             The dimension of graph-level representations
         z_node_dim: `int`, Optional
             The dimension of node-level representations
-        neg_by_crpt: `bool`, Optional
-            The mode to obtain negative samples
         tau: `int`, Optional
-            The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
+            The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
         model_path: `str` or None, Optional
             The directory to restore the saved model.
             If `model_path` = None, the model will not be saved.
         """
         assert (node_level or graph_level) is True
-        assert not (loss == "NCE" and neg_by_crpt)
         assert isinstance(encoder, BaseEncoderMaintainer) or isinstance(encoder, str) or encoder is None
         self.loss = self._get_loss(loss)
         self.node_level = node_level
@@ -141,7 +137,6 @@ class BaseContrastiveTrainer(BaseTrainer):
         self.last_dim = z_dim if graph_level else z_node_dim
         self.num_features = num_features
         self.num_graph_features = num_graph_features
-        self.neg_by_crpt = neg_by_crpt
         self.tau = tau
         self.model_path = model_path
         if isinstance(device, str):
@@ -195,8 +190,8 @@ class BaseContrastiveTrainer(BaseTrainer):
         if callable(loss):
             return loss
         elif isinstance(loss, str):
-            assert loss in ['JSE', 'NCE']
-            return {'JSE': JSE_loss, 'NCE': NCE_loss}[loss]
+            assert loss in ['NT_Xent']
+            return {'NT_Xent': NTXent_loss}[loss]
         else:
             raise NotImplementedError("The argument `loss` should be str or callable which returns a loss tensor")
@@ -453,7 +448,7 @@ class BaseContrastiveTrainer(BaseTrainer):
             for view in views:
                 z = self._get_embed(view.to(self.device))
                 zs.append(self.decoder.decoder(z, view.to(self.device)))
-            loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau)
+            loss = self.loss(zs, tau=self.tau)
             loss.backward()
             optimizer.step()
             if self.p_lr_scheduler_type:
@@ -474,7 +469,7 @@ class BaseContrastiveTrainer(BaseTrainer):
             for view in views:
                 z = self._get_embed(view.to(self.device))
                 zs.append(self.decoder.decoder(z, view.to(self.device)))
-            loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau)
+            loss = self.loss(zs, tau=self.tau)
             epoch_loss += loss.item()
             last_loss = loss.item()
         return epoch_loss, last_loss
@@ -534,19 +529,7 @@ class BaseContrastiveTrainer(BaseTrainer):
         return self.encoder.encoder.to(self.device)

     def _get_embed(self, view):
-        if self.neg_by_crpt:
-            view_crpt = self._corrupt_graph(view)
-            if self.node_level and self.graph_level:
-                z_g, z_n = self.encoder.encoder(view)
-                z_g_crpt, z_n_crpt = self.encoder.encoder(view_crpt)
-                z = (torch.cat([z_g, z_g_crpt], 0),
-                     torch.cat([z_n, z_n_crpt], 0))
-            else:
-                z = self.encoder.encoder(view)
-                z_crpt = self.encoder.encoder(view_crpt)
-                z = torch.cat([z, z_crpt], 0)
-        else:
-            z = self.encoder.encoder(view)
+        z = self.encoder.encoder(view)
         return z

     def predict(self, dataset, mask="test"):
@@ -1,4 +1,3 @@
-# codes in this file are reproduced from <https://github.com/divelab/DIG> with some changes.
 import os
 import torch
 import logging
@@ -14,7 +13,6 @@ from typing import Union, Tuple, Sequence, Type, Callable
 from tqdm import trange
 from copy import deepcopy
-from dig.sslgraph.evaluation.eval_graph import k_fold
 from .base import BaseContrastiveTrainer
@@ -54,7 +52,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
         ] = None,
         aug_ratio: Union[float, Sequence[float]] = 0.2,
         z_dim: Union[int, None] = 128,
-        neg_by_crpt: bool = False,
         tau: int = 0.5,
         model_path: Union[str, None] = "./models",
         num_workers: int = 0,
@@ -105,10 +102,8 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
             If aug_ratio is set as a list of float, its values correspond one-to-one with views_fn.
         z_dim: `int`
             The dimension of graph-level representations
-        neg_by_crpt: `bool`
-            The mode to obtain negative samples. Only required when `loss` = "JSE"
         tau: `int`
-            The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
+            The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
         model_path: `str` or None
             The directory to restore the saved model.
             If `model_path` = None, the model will not be saved.
@@ -165,9 +160,10 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
             feval=feval,
             z_dim=z_dim,
             z_node_dim=None,
-            neg_by_crpt=neg_by_crpt,
             tau=tau,
-            model_path=model_path
+            model_path=model_path,
+            *args,
+            **kwargs
         )
         self.views_fn = views_fn
         self.aug_ratio = aug_ratio
@@ -438,7 +434,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
             views_fn=self.views_fn_opt,
             aug_ratio=self.aug_ratio,
             z_dim=self.last_dim,
-            neg_by_crpt=self.neg_by_crpt,
             tau=self.tau,
             model_path=self.model_path,
             num_workers=self.num_workers,
@@ -530,10 +525,8 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
             If aug_ratio is set as a list of float, its values correspond one-to-one with views_fn.
         z_dim: `int`
             The dimension of graph-level representations
-        neg_by_crpt: `bool`
-            The mode to obtain negative samples. Only required when `loss` = "JSE"
         tau: `int`
-            The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
+            The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
         model_path: `str` or None
             The directory to restore the saved model.
             If `model_path` = None, the model will not be saved.
@@ -894,7 +887,6 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
             views_fn=self.views_fn_opt,
             aug_ratio=self.aug_ratio,
             z_dim=self.last_dim,
-            neg_by_crpt=self.neg_by_crpt,
             tau=self.tau,
             model_path=self.model_path,
             num_workers=self.num_workers,
@@ -0,0 +1,17 @@
+# NTXent_loss from <https://github.com/Shen-Lab/GraphCL/>
+import torch
+import torch.nn as nn
+
+
+def NTXent_loss(zs, tau=0.5, norm=True):
+    batch_size, _ = zs[0].size()
+    sim_matrix = torch.einsum('ik,jk->ij', zs[0], zs[1])
+    if norm:
+        z1_abs = zs[0].norm(dim=1)
+        z2_abs = zs[1].norm(dim=1)
+        sim_matrix = sim_matrix / torch.einsum('i,j->ij', z1_abs, z2_abs)
+    sim_matrix = torch.exp(sim_matrix / tau)
+    pos_sim = sim_matrix[range(batch_size), range(batch_size)]
+    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
+    loss = - torch.log(loss).mean()
+    return loss
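A quick sanity check of the new loss (an illustrative sketch, not part of the patch; the import path is assumed from the file layout above): NTXent_loss takes a list of two view embeddings of shape (batch_size, z_dim) and returns a scalar tensor.

    import torch
    from autogl.module.train.ssl.losses import NTXent_loss  # assumed path of the new losses.py

    z1 = torch.randn(32, 128)              # embeddings of view 1
    z2 = torch.randn(32, 128)              # embeddings of view 2; row i is the positive pair of z1[i]
    loss = NTXent_loss([z1, z2], tau=0.5)  # scalar tensor, smaller when matching rows are most similar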
@@ -1,8 +1,8 @@
-from dig.sslgraph.method.contrastive.views_fn import (
-    NodeAttrMask,
-    EdgePerturbation,
-    UniformSample,
-    RWSample,
+from .views_fn import (
+    DropNode,
+    PermuteEdge,
+    MaskNode,
+    SubGraph,
     RandomView
 )
@@ -10,28 +10,28 @@ def get_view_by_name(view, aug_ratio):
     if view is None:
         return lambda x: x
     elif view == "dropN":
-        return UniformSample(ratio=aug_ratio)
+        return DropNode(aug_ratio=aug_ratio)
     elif view == "permE":
-        return EdgePerturbation(ratio=aug_ratio)
+        return PermuteEdge(aug_ratio=aug_ratio)
     elif view == "subgraph":
-        return RWSample(ratio=aug_ratio)
+        return SubGraph(aug_ratio=aug_ratio)
     elif view == "maskN":
-        return NodeAttrMask(mask_ratio=aug_ratio)
+        return MaskNode(aug_ratio=aug_ratio)
     elif view == "random2":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     elif view == "random3":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio),
-                      EdgePerturbation(ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio),
+                      PermuteEdge(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     elif view == "random4":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio),
-                      EdgePerturbation(ratio=aug_ratio),
-                      NodeAttrMask(mask_ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio),
+                      PermuteEdge(aug_ratio=aug_ratio),
+                      MaskNode(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     else:
-        raise NotImplementedError(f'The augmentation method must be in ["dropN", "permE", "subgraph", \
-            "maskN", "random2", "random3", "random4"] or None. And {view} is not supported yet.')
+        raise NotImplementedError(f'{view} is not supported yet. Support: ["dropN", "permE", "subgraph", \
+            "maskN", "random2", "random3", "random4", None]')
@@ -0,0 +1,137 @@
+# pyg augmentation methods reproduced from <https://github.com/Shen-Lab/GraphCL/>
+import random
+import torch
+import numpy as np
+from itertools import repeat, product
+from torch_geometric.data import Batch
+
+
+class BaseAugmentation:
+    # Applies `_aug_data` to every graph in a PyG Batch and re-collates the result.
+    def __init__(self, aug_ratio=None):
+        self.aug_ratio = aug_ratio
+
+    def _aug_data(self, data):
+        pass
+
+    def __call__(self, batch):
+        new_data = []
+        for data in batch.to_data_list():
+            new_data.append(self._aug_data(data))
+        return Batch.from_data_list(new_data)
+
+
+class DropNode(BaseAugmentation):
+    # Drops a fraction `aug_ratio` of the nodes and keeps the induced subgraph.
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        drop_num = int(node_num * self.aug_ratio)
+
+        idx_perm = np.random.permutation(node_num)
+        idx_drop = idx_perm[:drop_num]
+        idx_nondrop = idx_perm[drop_num:]
+        idx_nondrop.sort()
+        idx_dict = {idx_nondrop[n]: n for n in list(range(idx_nondrop.shape[0]))}
+
+        edge_index = data.edge_index.numpy()
+        adj = torch.zeros((node_num, node_num))
+        adj[edge_index[0], edge_index[1]] = 1
+        adj = adj[idx_nondrop, :][:, idx_nondrop]
+        edge_index = adj.nonzero().t()
+
+        try:
+            data.edge_index = edge_index
+            data.x = data.x[idx_nondrop]
+        except Exception:
+            pass  # keep the original graph if re-indexing fails
+        return data
+
+
+class PermuteEdge(BaseAugmentation):
+    # Replaces a fraction `aug_ratio` of the edges with randomly sampled ones.
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        permute_num = int(edge_num * self.aug_ratio)
+
+        edge_index = data.edge_index.numpy()
+        idx_add = np.random.choice(node_num, (2, permute_num))
+        # idx_add = [[idx_add[0, n], idx_add[1, n]] for n in range(permute_num) if not (idx_add[0, n], idx_add[1, n]) in edge_index]
+        # edge_index = [edge_index[n] for n in range(edge_num) if not n in np.random.choice(edge_num, permute_num, replace=False)] + idx_add
+        edge_index = np.concatenate((edge_index[:, np.random.choice(edge_num, (edge_num - permute_num), replace=False)], idx_add), axis=1)
+        data.edge_index = torch.tensor(edge_index)
+        return data
+
+
+class SubGraph(BaseAugmentation):
+    # Keeps a randomly grown subgraph covering roughly `aug_ratio` of the nodes.
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        sub_num = int(node_num * self.aug_ratio)
+
+        edge_index = data.edge_index.numpy()
+        idx_sub = [np.random.randint(node_num, size=1)[0]]
+        idx_neigh = set([n for n in edge_index[1][edge_index[0] == idx_sub[0]]])
+
+        count = 0
+        while len(idx_sub) <= sub_num:
+            count = count + 1
+            if count > node_num:
+                break
+            if len(idx_neigh) == 0:
+                break
+            sample_node = np.random.choice(list(idx_neigh))
+            if sample_node in idx_sub:
+                continue
+            idx_sub.append(sample_node)
+            # grow the frontier with the neighbors of the newly added node
+            idx_neigh = idx_neigh.union(set([n for n in edge_index[1][edge_index[0] == idx_sub[-1]]]))
+
+        idx_drop = [n for n in range(node_num) if not n in idx_sub]
+        idx_nondrop = idx_sub
+        data.x = data.x[idx_nondrop]
+        idx_dict = {idx_nondrop[n]: n for n in list(range(len(idx_nondrop)))}
+
+        edge_index = data.edge_index.numpy()
+        adj = torch.zeros((node_num, node_num))
+        adj[edge_index[0], edge_index[1]] = 1
+        adj[list(range(node_num)), list(range(node_num))] = 1  # add self-loops before slicing
+        adj = adj[idx_nondrop, :][:, idx_nondrop]
+        edge_index = adj.nonzero().t()
+        # edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] + [[n, n] for n in idx_nondrop]
+        data.edge_index = edge_index
+        return data
+
+
+class MaskNode(BaseAugmentation):
+    # Overwrites the features of a fraction `aug_ratio` of the nodes with the mean feature vector.
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, feat_dim = data.x.size()
+        mask_num = int(node_num * self.aug_ratio)
+
+        token = data.x.mean(dim=0)
+        idx_mask = np.random.choice(node_num, mask_num, replace=False)
+        data.x[idx_mask] = token.to(torch.float32)
+        return data
+
+
+class RandomView(BaseAugmentation):
+    # Picks one augmentation from `candidates` at random for every graph.
+    def __init__(self, candidates):
+        super().__init__()
+        self.candidates = candidates
+
+    def _aug_data(self, data):
+        view = random.choice(self.candidates)
+        return view._aug_data(data)
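An isolated smoke test of the new augmentations (a sketch with a made-up toy graph, not part of the patch):

    import torch
    from torch_geometric.data import Data, Batch

    toy = Data(x=torch.randn(6, 3),
               edge_index=torch.tensor([[0, 1, 2, 3, 4, 5],
                                        [1, 2, 3, 4, 5, 0]]))
    batch = Batch.from_data_list([toy, toy])
    aug = PermuteEdge(aug_ratio=0.2)  # or DropNode / SubGraph / MaskNode
    out = aug(batch)                  # same two graphs, with ~20% of edges replaced at random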
@@ -302,7 +302,7 @@ class SSLGraphClassifier(BaseClassifier):
             num_classes=num_classes,
             feval=evaluator_list,
             device=self.runtime_device,
-            loss="NCE" if not hasattr(dataset, "loss") else dataset.loss,
+            loss="NT_Xent" if not hasattr(dataset, "loss") else dataset.loss,
             num_graph_features=(0
                 if not hasattr(dataset[0], "gf")
                 else dataset[0].gf.size(1)) if BACKEND == 'pyg' else
@@ -35,12 +35,12 @@ def test_graph_trainer():
         prediction_head="sumpoolmlp",
         views_fn=["random2", "random2"],
         batch_size=128,
-        p_lr=5.6004725115062315e-05,
-        p_weight_decay=0.00022810837622188083,
-        p_epoch=267,
-        f_epoch=131,
-        f_lr=0.0005362155524564354,
-        f_weight_decay=0.0022069814932058804,
+        p_lr=0.0001,
+        p_weight_decay=0.0002,
+        p_epoch=300,
+        f_epoch=150,
+        f_lr=0.0001,
+        f_weight_decay=0.002,
         p_early_stopping_round=50,
         f_early_stopping_round=50,
         z_dim=128,
@@ -0,0 +1,115 @@
+import os
+import random
+import torch
+import torch.nn as nn
+import numpy as np
+from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
+from autogl.datasets import build_dataset_from_name, utils
+from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
+from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
+from autogl.module.model.decoders import BaseDecoderMaintainer
+from autogl.solver.utils import set_seed
+
+
+def fixed(**kwargs):
+    return [{
+        'parameterName': k,
+        'type': "FIXED",
+        'value': v
+    } for k, v in kwargs.items()]
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser('ssl pyg trainer')
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--dataset', type=str, choices=['MUTAG', 'NCI1', 'PROTEINS', 'PTC_MR'], default='PROTEINS')
+    parser.add_argument('--dataset_seed', type=int, default=2021)
+    parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--repeat', type=int, default=50)
+    # parser.add_argument('--model', type=str, choices=['gin', 'gat', 'gcn', 'sage'], default='gin')
+    parser.add_argument('--encoder', type=str, choices=['gin', 'gcn'], default='gcn')
+    parser.add_argument('--p_lr', type=float, default=0.0001)
+    parser.add_argument('--p_weight_decay', type=float, default=0)
+    parser.add_argument('--p_epoch', type=int, default=100)
+    parser.add_argument('--f_lr', type=float, default=0.001)
+    parser.add_argument('--f_weight_decay', type=float, default=0)
+    parser.add_argument('--f_epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=100)
+    args = parser.parse_args()
+
+    # split dataset
+    dataset = build_dataset_from_name(args.dataset)
+    dataset = convert_dataset(dataset)
+    utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)
+
+    accs = [[], [], []]
+    encoder_hp = {
+        "num_layers": 5,
+        "hidden": [32, 64, 64, 64],
+        "dropout": 0.5,
+        "act": "elu",
+        "eps": "true"
+    }
+    decoder_hp = {
+        "hidden": 32,
+        "act": "tanh",
+        "dropout": 0.35
+    }
+    prediction_head = {
+        "hidden": 128,
+        "act": "relu",
+        "dropout": 0.4
+    }
+
+    from tqdm import tqdm
+    for seed in tqdm(range(args.repeat)):
+        set_seed(seed)
+        trainer = GraphCLSemisupervisedTrainer(
+            model=(args.encoder, 'sumpoolmlp'),
+            prediction_head='sumpoolmlp',
+            views_fn=['random2', 'random2'],
+            device=args.device,
+            num_features=dataset[0].x.size(1),
+            num_classes=max([data.y.item() for data in dataset]) + 1,
+            batch_size=args.batch_size,
+            # p_lr=args.p_lr,
+            # p_weight_decay=args.p_weight_decay,
+            # p_epoch=args.p_epoch,
+            # f_lr=args.f_lr,
+            # f_weight_decay=args.f_weight_decay,
+            # f_epoch=args.f_epoch,
+            z_dim=128,
+            init=False
+        )
+        trainer.initialize()
+        trainer = trainer.duplicate_from_hyper_parameter(
+            {
+                'trainer': {
+                    'batch_size': args.batch_size,
+                    'p_lr': args.p_lr,
+                    'p_weight_decay': args.p_weight_decay,
+                    'p_epoch': args.p_epoch,
+                    'p_early_stopping_round': args.p_epoch + 1,
+                    'f_lr': args.f_lr,
+                    'f_weight_decay': args.f_weight_decay,
+                    'f_epoch': args.f_epoch,
+                    'f_early_stopping_round': args.f_epoch + 1,
+                },
+                "encoder": encoder_hp,
+                "decoder": decoder_hp,
+                "prediction_head": prediction_head
+            }
+        )
+        trainer.train(dataset, False)
+        out = trainer.predict(dataset, 'test').detach().cpu().numpy()
+        train_result = trainer.evaluate(dataset, 'train')
+        valid_result = trainer.evaluate(dataset, 'val')
+        test_result = trainer.evaluate(dataset, 'test')
+        print(f"{train_result[0]} - {valid_result[0]} - {test_result[0]}")
+        accs[0].append(train_result[0])
+        accs[1].append(valid_result[0])
+        accs[2].append(test_result[0])
+
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[0]), np.std(accs[0])))
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[1]), np.std(accs[1])))
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[2]), np.std(accs[2])))