From 5d7b06f6f03cbd37f662f120fb830ef84eeb4e2d Mon Sep 17 00:00:00 2001
From: defineZYP <953726616@qq.com>
Date: Tue, 20 Dec 2022 16:11:55 +0800
Subject: [PATCH] remove DIG dependency from SSL trainers and fix OGB
 edge-feature handling
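
Replace the dig.sslgraph imports with local implementations: an NT-Xent
contrastive loss (losses.py) and the GraphCL augmentations (views_fn.py).
Drop the now-unused neg_by_crpt path and rename the default loss from
"NCE" to "NT_Xent". Also guard _ogb.py against SparseTensors that carry
no edge values, and add a full GraphCL semi-supervised test script.

With the DIG objectives gone, the trainers call the loss as

    loss = self.loss(zs, tau=self.tau)

where zs holds one embedding tensor per augmented view.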
""" assert (node_level or graph_level) is True - assert not (loss == "NCE" and neg_by_crpt) assert isinstance(encoder, BaseEncoderMaintainer) or isinstance(encoder, str) or encoder is None self.loss = self._get_loss(loss) self.node_level = node_level @@ -141,7 +137,6 @@ class BaseContrastiveTrainer(BaseTrainer): self.last_dim = z_dim if graph_level else z_node_dim self.num_features = num_features self.num_graph_features = num_graph_features - self.neg_by_crpt = neg_by_crpt self.tau = tau self.model_path = model_path if isinstance(device, str): @@ -195,8 +190,8 @@ class BaseContrastiveTrainer(BaseTrainer): if callable(loss): return loss elif isinstance(loss, str): - assert loss in ['JSE', 'NCE'] - return {'JSE': JSE_loss, 'NCE': NCE_loss}[loss] + assert loss in ['NT_Xent'] + return {'NT_Xent': NTXent_loss}[loss] else: raise NotImplementedError("The argument `loss` should be str or callable which returns a loss tensor") @@ -453,7 +448,7 @@ class BaseContrastiveTrainer(BaseTrainer): for view in views: z = self._get_embed(view.to(self.device)) zs.append(self.decoder.decoder(z, view.to(self.device))) - loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau) + loss = self.loss(zs, tau=self.tau) loss.backward() optimizer.step() if self.p_lr_scheduler_type: @@ -474,7 +469,7 @@ class BaseContrastiveTrainer(BaseTrainer): for view in views: z = self._get_embed(view.to(self.device)) zs.append(self.decoder.decoder(z, view.to(self.device))) - loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau) + loss = self.loss(zs, tau=self.tau) epoch_loss += loss.item() last_loss = loss.item() return epoch_loss, last_loss @@ -534,19 +529,7 @@ class BaseContrastiveTrainer(BaseTrainer): return self.encoder.encoder.to(self.device) def _get_embed(self, view): - if self.neg_by_crpt: - view_crpt = self._corrupt_graph(view) - if self.node_level and self.graph_level: - z_g, z_n = self.encoder.encoder(view) - z_g_crpt, z_n_crpt = self.encoder.encoder(view_crpt) - z = (torch.cat([z_g, z_g_crpt], 0), - torch.cat([z_n, z_n_crpt], 0)) - else: - z = self.encoder.encoder(view) - z_crpt = self.encoder.encoder(view_crpt) - z = torch.cat([z, z_crpt], 0) - else: - z = self.encoder.encoder(view) + z = self.encoder.encoder(view) return z def predict(self, dataset, mask="test"): diff --git a/autogl/module/train/ssl/graphcl.py b/autogl/module/train/ssl/graphcl.py index 1a714f9..cd8fa0c 100644 --- a/autogl/module/train/ssl/graphcl.py +++ b/autogl/module/train/ssl/graphcl.py @@ -1,4 +1,3 @@ -# codes in this file are reproduced from with some changes. import os import torch import logging @@ -14,7 +13,6 @@ from typing import Union, Tuple, Sequence, Type, Callable from tqdm import trange from copy import deepcopy -from dig.sslgraph.evaluation.eval_graph import k_fold from .base import BaseContrastiveTrainer @@ -54,7 +52,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): ] = None, aug_ratio: Union[float, Sequence[float]] = 0.2, z_dim: Union[int, None] = 128, - neg_by_crpt: bool = False, tau: int = 0.5, model_path: Union[str, None] = "./models", num_workers: int = 0, @@ -105,10 +102,8 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer): If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence. z_dim: `int` The dimension of graph-level representations - neg_by_crpt: `bool` - The mode to obtain negative samples. Only required when `loss` = "JSE" tau: `int` - The temperature parameter in InfoNCE loss. 
@@ -530,10 +525,8 @@
            If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence.
         z_dim: `int`
             The dimension of graph-level representations
-        neg_by_crpt: `bool`
-            The mode to obtain negative samples. Only required when `loss` = "JSE"
         tau: `int`
-            The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
+            The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
         model_path: `str` or None
             The directory to restore the saved model. If `model_path` = None, the model will not be saved.

@@ -894,7 +887,6 @@
             views_fn=self.views_fn_opt,
             aug_ratio=self.aug_ratio,
             z_dim=self.last_dim,
-            neg_by_crpt=self.neg_by_crpt,
             tau=self.tau,
             model_path=self.model_path,
             num_workers=self.num_workers,
diff --git a/autogl/module/train/ssl/losses.py b/autogl/module/train/ssl/losses.py
new file mode 100644
index 0000000..359512b
--- /dev/null
+++ b/autogl/module/train/ssl/losses.py
@@ -0,0 +1,17 @@
+# NTXent_loss from
+import torch
+import torch.nn as nn
+
+def NTXent_loss(zs, tau=0.5, norm=True):
+    batch_size, _ = zs[0].size()
+    sim_matrix = torch.einsum('ik,jk->ij', zs[0], zs[1])
+    if norm:
+        z1_abs = zs[0].norm(dim=1)
+        z2_abs = zs[1].norm(dim=1)
+        sim_matrix = sim_matrix / torch.einsum('i,j->ij', z1_abs, z2_abs)
+    sim_matrix = torch.exp(sim_matrix/tau)
+    pos_sim = sim_matrix[range(batch_size), range(batch_size)]
+    loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
+    loss = - torch.log(loss).mean()
+
+    return loss
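+# Illustrative usage sketch (not called anywhere; the shapes are assumptions):
+#
+#     z1 = torch.randn(32, 128)   # view-1 embeddings, [batch_size, dim]
+#     z2 = torch.randn(32, 128)   # view-2 embeddings
+#     loss = NTXent_loss([z1, z2], tau=0.5)
+#
+# Row i of zs[0] and row i of zs[1] form the positive pair; the other
+# batch_size - 1 cross-view similarities in row i act as negatives, i.e. the
+# normalized temperature-scaled cross entropy (NT-Xent) objective of GraphCL.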
diff --git a/autogl/module/train/ssl/utils.py b/autogl/module/train/ssl/utils.py
index 594bdf6..3b048be 100644
--- a/autogl/module/train/ssl/utils.py
+++ b/autogl/module/train/ssl/utils.py
@@ -1,8 +1,8 @@
-from dig.sslgraph.method.contrastive.views_fn import (
-    NodeAttrMask,
-    EdgePerturbation,
-    UniformSample,
-    RWSample,
+from .views_fn import (
+    DropNode,
+    PermuteEdge,
+    MaskNode,
+    SubGraph,
     RandomView
 )

@@ -10,28 +10,28 @@ def get_view_by_name(view, aug_ratio):
     if view is None:
         return lambda x: x
     elif view == "dropN":
-        return UniformSample(ratio=aug_ratio)
+        return DropNode(aug_ratio=aug_ratio)
     elif view == "permE":
-        return EdgePerturbation(ratio=aug_ratio)
+        return PermuteEdge(aug_ratio=aug_ratio)
     elif view == "subgraph":
-        return RWSample(ratio=aug_ratio)
+        return SubGraph(aug_ratio=aug_ratio)
     elif view == "maskN":
-        return NodeAttrMask(mask_ratio=aug_ratio)
+        return MaskNode(aug_ratio=aug_ratio)
     elif view == "random2":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     elif view == "random3":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio),
-                      EdgePerturbation(ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio),
+                      PermuteEdge(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     elif view == "random4":
-        canditates = [UniformSample(ratio=aug_ratio),
-                      RWSample(ratio=aug_ratio),
-                      EdgePerturbation(ratio=aug_ratio),
-                      NodeAttrMask(mask_ratio=aug_ratio)]
+        canditates = [DropNode(aug_ratio=aug_ratio),
+                      SubGraph(aug_ratio=aug_ratio),
+                      PermuteEdge(aug_ratio=aug_ratio),
+                      MaskNode(aug_ratio=aug_ratio)]
         return RandomView(candidates=canditates)
     else:
-        raise NotImplementedError(f'The augmentation method must be in ["dropN", "permE", "subgraph", \
-            "maskN", "random2", "random3", "random4"] or None. And {view} is not supported yet.')
\ No newline at end of file
+        raise NotImplementedError(f'{view} is not supported yet. Support: ["dropN", "permE", "subgraph", \
+            "maskN", "random2", "random3", "random4", None]')
diff --git a/autogl/module/train/ssl/views_fn.py b/autogl/module/train/ssl/views_fn.py
new file mode 100644
index 0000000..a4dc7c2
--- /dev/null
+++ b/autogl/module/train/ssl/views_fn.py
@@ -0,0 +1,137 @@
+# pyg augmentation method from
+
+import random
+import torch
+import numpy as np
+from itertools import repeat, product
+from torch_geometric.data import Batch
+
+class BaseAugmentation:
+    def __init__(self, aug_ratio=None):
+        self.aug_ratio = aug_ratio
+
+    def _aug_data(self, data):
+        pass
+
+    def __call__(self, batch):
+        new_data = []
+        for data in batch.to_data_list():
+            new_data.append(self._aug_data(data))
+        return Batch.from_data_list(new_data)
+
+class DropNode(BaseAugmentation):
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        drop_num = int(node_num * self.aug_ratio)
+
+        idx_perm = np.random.permutation(node_num)
+
+        idx_drop = idx_perm[:drop_num]
+        idx_nondrop = idx_perm[drop_num:]
+        idx_nondrop.sort()
+        idx_dict = {idx_nondrop[n]:n for n in list(range(idx_nondrop.shape[0]))}
+
+        edge_index = data.edge_index.numpy()
+        adj = torch.zeros((node_num, node_num))
+        adj[edge_index[0], edge_index[1]] = 1
+        adj = adj[idx_nondrop, :][:, idx_nondrop]
+        edge_index = adj.nonzero().t()
+
+        try:
+            data.edge_index = edge_index
+            data.x = data.x[idx_nondrop]
+        except Exception:  # fall back to the unmodified graph if indexing fails
+            pass
+        return data
+
+class PermuteEdge(BaseAugmentation):
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        permute_num = int(edge_num * self.aug_ratio)
+
+        edge_index = data.edge_index.numpy()
+
+        idx_add = np.random.choice(node_num, (2, permute_num))
+
+        # idx_add = [[idx_add[0, n], idx_add[1, n]] for n in range(permute_num) if not (idx_add[0, n], idx_add[1, n]) in edge_index]
+        # edge_index = [edge_index[n] for n in range(edge_num) if not n in np.random.choice(edge_num, permute_num, replace=False)] + idx_add
+
+        edge_index = np.concatenate((edge_index[:, np.random.choice(edge_num, (edge_num - permute_num), replace=False)], idx_add), axis=1)
+        data.edge_index = torch.tensor(edge_index)
+
+        return data
+
+class SubGraph(BaseAugmentation):
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, _ = data.x.size()
+        _, edge_num = data.edge_index.size()
+        sub_num = int(node_num * self.aug_ratio)
+
+        edge_index = data.edge_index.numpy()
+
+        idx_sub = [np.random.randint(node_num, size=1)[0]]
+        idx_neigh = set([n for n in edge_index[1][edge_index[0]==idx_sub[0]]])
+
+        count = 0
+        while len(idx_sub) <= sub_num:
+            count = count + 1
+            if count > node_num:
+                break
+            if len(idx_neigh) == 0:
+                break
+            sample_node = np.random.choice(list(idx_neigh))
+            if sample_node in idx_sub:
+                continue
+            idx_sub.append(sample_node)
+            # union() returns a new set; rebind it so the frontier actually grows
+            idx_neigh = idx_neigh.union(set([n for n in edge_index[1][edge_index[0]==idx_sub[-1]]]))
+
+        idx_drop = [n for n in range(node_num) if not n in idx_sub]
+        idx_nondrop = idx_sub
+        data.x = data.x[idx_nondrop]
+        idx_dict = {idx_nondrop[n]:n for n in list(range(len(idx_nondrop)))}
+
+        edge_index = data.edge_index.numpy()
+        adj = torch.zeros((node_num, node_num))
+        adj[edge_index[0], edge_index[1]] = 1
+        adj[list(range(node_num)), list(range(node_num))] = 1
+        adj = adj[idx_nondrop, :][:, idx_nondrop]
+        edge_index = adj.nonzero().t()
+
+        # edge_index = [[idx_dict[edge_index[0, n]], idx_dict[edge_index[1, n]]] for n in range(edge_num) if (not edge_index[0, n] in idx_drop) and (not edge_index[1, n] in idx_drop)] + [[n, n] for n in idx_nondrop]
+        data.edge_index = edge_index
+
+        return data
+
+class MaskNode(BaseAugmentation):
+    def __init__(self, aug_ratio):
+        super().__init__(aug_ratio)
+
+    def _aug_data(self, data):
+        node_num, feat_dim = data.x.size()
+        mask_num = int(node_num * self.aug_ratio)
+
+        token = data.x.mean(dim=0)
+        idx_mask = np.random.choice(node_num, mask_num, replace=False)
+        # token is already a tensor; torch.tensor(token) would warn and copy
+        data.x[idx_mask] = token.to(torch.float32)
+
+        return data
+
+class RandomView(BaseAugmentation):
+    def __init__(self, candidates):
+        super().__init__()
+        self.candidates = candidates
+
+    def _aug_data(self, data):
+        view = random.choice(self.candidates)
+        return view._aug_data(data)
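+# Illustrative usage sketch (the dataset choice is an assumption; any PyG
+# dataset of small graphs with node features works):
+#
+#     from torch_geometric.datasets import TUDataset
+#     graphs = list(TUDataset('/tmp/MUTAG', name='MUTAG'))[:8]
+#     view1 = DropNode(aug_ratio=0.2)(Batch.from_data_list(graphs))
+#     view2 = PermuteEdge(aug_ratio=0.2)(Batch.from_data_list(graphs))
+#
+# Each augmentation maps a Batch to an augmented Batch graph by graph via
+# _aug_data; two independently sampled views of one batch supply the positive
+# pairs for the contrastive loss above.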
diff --git a/autogl/solver/classifier/ssl/ssl_graph_classifier.py b/autogl/solver/classifier/ssl/ssl_graph_classifier.py
index 889692c..306dd75 100644
--- a/autogl/solver/classifier/ssl/ssl_graph_classifier.py
+++ b/autogl/solver/classifier/ssl/ssl_graph_classifier.py
@@ -302,7 +302,7 @@ class SSLGraphClassifier(BaseClassifier):
             num_classes=num_classes,
             feval=evaluator_list,
             device=self.runtime_device,
-            loss="NCE" if not hasattr(dataset, "loss") else dataset.loss,
+            loss="NT_Xent" if not hasattr(dataset, "loss") else dataset.loss,
             num_graph_features=(0
                 if not hasattr(dataset[0], "gf")
                 else dataset[0].gf.size(1)) if BACKEND == 'pyg' else
diff --git a/test/trainer/pyg/graphcl_ssl.py b/test/trainer/pyg/graphcl_ssl.py
index ca3c0b1..c4bc1d4 100644
--- a/test/trainer/pyg/graphcl_ssl.py
+++ b/test/trainer/pyg/graphcl_ssl.py
@@ -35,12 +35,12 @@ def test_graph_trainer():
         prediction_head="sumpoolmlp",
         views_fn=["random2", "random2"],
         batch_size=128,
-        p_lr=5.6004725115062315e-05,
-        p_weight_decay=0.00022810837622188083,
-        p_epoch=267,
-        f_epoch=131,
-        f_lr=0.0005362155524564354,
-        f_weight_decay=0.0022069814932058804,
+        p_lr=0.0001,
+        p_weight_decay=0.0002,
+        p_epoch=300,
+        f_epoch=150,
+        f_lr=0.0001,
+        f_weight_decay=0.002,
         p_early_stopping_round=50,
         f_early_stopping_round=50,
         z_dim=128,
diff --git a/test/trainer/pyg/graphcl_ssl_full.py b/test/trainer/pyg/graphcl_ssl_full.py
new file mode 100644
index 0000000..bacc49c
--- /dev/null
+++ b/test/trainer/pyg/graphcl_ssl_full.py
@@ -0,0 +1,115 @@
+import os
+import random
+import torch
+import torch.nn as nn
+import numpy as np
+
+from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
+from autogl.datasets import build_dataset_from_name, utils
+from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
+from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
+from autogl.module.model.decoders import BaseDecoderMaintainer
+from autogl.solver.utils import set_seed
+
+def fixed(**kwargs):
+    return [{
+        'parameterName': k,
+        'type': "FIXED",
+        'value': v
+    } for k, v in kwargs.items()]
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser('ssl pyg trainer')
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--dataset', type=str, choices=['MUTAG', 'NCI1', 'PROTEINS', 'PTC_MR'], default='PROTEINS')
+    parser.add_argument('--dataset_seed', type=int, default=2021)
+    parser.add_argument('--batch_size', type=int, default=32)
+    parser.add_argument('--repeat', type=int, default=50)
+    # parser.add_argument('--model', type=str, choices=['gin', 'gat', 'gcn', 'sage'], default='gin')
+    parser.add_argument('--encoder', type=str, choices=['gin', 'gcn'], default='gcn')
+    parser.add_argument('--p_lr', type=float, default=0.0001)
+    parser.add_argument('--p_weight_decay', type=float, default=0)
+    parser.add_argument('--p_epoch', type=int, default=100)
+    parser.add_argument('--f_lr', type=float, default=0.001)
+    parser.add_argument('--f_weight_decay', type=float, default=0)
+    parser.add_argument('--f_epoch', type=int, default=100)
+    parser.add_argument('--epoch', type=int, default=100)
+
+    args = parser.parse_args()
+
+    # split dataset
+    dataset = build_dataset_from_name(args.dataset)
+    dataset = convert_dataset(dataset)
+    utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)
+
+    accs = [[], [], []]
+
+    encoder_hp = {
+        "num_layers": 5,
+        "hidden": [32, 64, 64, 64],
+        "dropout": 0.5,
+        "act": "elu",
+        "eps": "true"
+    }
+    decoder_hp = {
+        "hidden": 32,
+        "act": "tanh",
+        "dropout": 0.35
+    }
+    prediction_head = {
+        "hidden": 128,
+        "act": "relu",
+        "dropout": 0.4
+    }
+    from tqdm import tqdm
+    for seed in tqdm(range(args.repeat)):
+        set_seed(seed)
+        trainer = GraphCLSemisupervisedTrainer(
+            model=(args.encoder, 'sumpoolmlp'),
+            prediction_head='sumpoolmlp',
+            views_fn=['random2', 'random2'],
+            device=args.device,
+            num_features=dataset[0].x.size(1),
+            num_classes=max([data.y.item() for data in dataset]) + 1,
+            batch_size=args.batch_size,
+            # p_lr=args.p_lr,
+            # p_weight_decay=args.p_weight_decay,
+            # p_epoch=args.p_epoch,
+            # f_lr=args.f_lr,
+            # f_weight_decay=args.f_weight_decay,
+            # f_epoch=args.f_epoch,
+            z_dim=128,
+            init=False
+        )
+        trainer.initialize()
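+        # The nested dict below mirrors the trainer's hyper-parameter layout:
+        # one sub-dict per component ("trainer", "encoder", "decoder",
+        # "prediction_head"). duplicate_from_hyper_parameter builds a fresh
+        # trainer from these values; early stopping is effectively disabled
+        # by setting each round past its epoch budget.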
+        trainer = trainer.duplicate_from_hyper_parameter(
+            {
+                'trainer': {
+                    'batch_size': args.batch_size,
+                    'p_lr': args.p_lr,
+                    'p_weight_decay': args.p_weight_decay,
+                    'p_epoch': args.p_epoch,
+                    'p_early_stopping_round': args.p_epoch + 1,
+                    'f_lr': args.f_lr,
+                    'f_weight_decay': args.f_weight_decay,
+                    'f_epoch': args.f_epoch,
+                    'f_early_stopping_round': args.f_epoch + 1,
+                },
+                "encoder": encoder_hp,
+                "decoder": decoder_hp,
+                "prediction_head": prediction_head
+            }
+        )
+        trainer.train(dataset, False)
+        out = trainer.predict(dataset, 'test').detach().cpu().numpy()
+        train_result = trainer.evaluate(dataset, 'train')
+        valid_result = trainer.evaluate(dataset, 'val')
+        test_result = trainer.evaluate(dataset, 'test')
+        print(f"{train_result[0]} - {valid_result[0]} - {test_result[0]}")
+        accs[0].append(train_result[0])
+        accs[1].append(valid_result[0])
+        accs[2].append(test_result[0])
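+    # rows of accs hold train / val / test accuracy per seed;
+    # report mean ~ std over the repeats for each split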
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[0]), np.std(accs[0])))
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[1]), np.std(accs[1])))
+    print('{:.4f} ~ {:.4f}'.format(np.mean(accs[2]), np.std(accs[2])))
\ No newline at end of file