Browse Source

fix bug

develop/0.4/predevelop
defineZYP 3 years ago
parent
commit
5d7b06f6f0
9 changed files with 313 additions and 66 deletions
  1. +4
    -1
      autogl/datasets/_ogb.py
  2. +8
    -25
      autogl/module/train/ssl/base.py
  3. +5
    -13
      autogl/module/train/ssl/graphcl.py
  4. +17
    -0
      autogl/module/train/ssl/losses.py
  5. +20
    -20
      autogl/module/train/ssl/utils.py
  6. +137
    -0
      autogl/module/train/ssl/views_fn.py
  7. +1
    -1
      autogl/solver/classifier/ssl/ssl_graph_classifier.py
  8. +6
    -6
      test/trainer/pyg/graphcl_ssl.py
  9. +115
    -0
      test/trainer/pyg/graphcl_ssl_full.py

+ 4
- 1
autogl/datasets/_ogb.py View File

@@ -40,7 +40,10 @@ class _OGBNDatasetUtil(_OGBDatasetUtil):
edge_feat = torch.tensor(edge_feat)
edge_index = SparseTensor(row=torch.tensor(edge_index[0]), col=torch.tensor(edge_index[1]), value=edge_feat, sparse_sizes=(num_nodes, num_nodes))
_, _, value = edge_index.coo()
ogbn_data['edge_feat'] = value.cpu().detach().numpy()
if value is not None:
ogbn_data['edge_feat'] = value.cpu().detach().numpy()
else:
ogbn_data['edge_feat'] = edge_feat
edge_index = edge_index.to_symmetric()
row, col, _ = edge_index.coo()
edge_index = np.array([row.cpu().detach().numpy(), col.cpu().detach().numpy()])


+ 8
- 25
autogl/module/train/ssl/base.py View File

@@ -18,7 +18,7 @@ from torch.optim.lr_scheduler import (
ReduceLROnPlateau,
)

from dig.sslgraph.method.contrastive.objectives import NCE_loss, JSE_loss
from .losses import NTXent_loss
from .utils import get_view_by_name

from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
@@ -49,7 +49,7 @@ class BaseContrastiveTrainer(BaseTrainer):
feval: _typing.Union[
_typing.Sequence[str], _typing.Sequence[_typing.Type[Evaluation]]
] = (Acc,),
loss: Union[str, Callable] = "NCE",
loss: Union[str, Callable] = "NT_Xent",
f_loss: Union[str, Callable] = "nll_loss",
views_fn: _typing.Union[
_typing.Sequence[_typing.Callable], None
@@ -59,7 +59,6 @@ class BaseContrastiveTrainer(BaseTrainer):
node_level: bool = False,
z_dim: _typing.Union[int, None] = None,
z_node_dim: _typing.Union[int, None] = None,
neg_by_crpt: bool = False,
tau: int = 0.5,
p_optim: Union[torch.optim.Optimizer, str] = "Adam",
p_lr: float = 0.0001,
@@ -110,16 +109,13 @@ class BaseContrastiveTrainer(BaseTrainer):
The dimension of graph-level representations
z_node_dim: `int`, Optional
The dimension of node-level representations
neg_by_crpt: `bool`, Optional
The mode to obtain negative samples
tau: `int`, Optional
The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
model_path: `str` or None, Optional
The directory to restore the saved model.
If `model_path` = None, the model will not be saved.
"""
assert (node_level or graph_level) is True
assert not (loss == "NCE" and neg_by_crpt)
assert isinstance(encoder, BaseEncoderMaintainer) or isinstance(encoder, str) or encoder is None
self.loss = self._get_loss(loss)
self.node_level = node_level
@@ -141,7 +137,6 @@ class BaseContrastiveTrainer(BaseTrainer):
self.last_dim = z_dim if graph_level else z_node_dim
self.num_features = num_features
self.num_graph_features = num_graph_features
self.neg_by_crpt = neg_by_crpt
self.tau = tau
self.model_path = model_path
if isinstance(device, str):
@@ -195,8 +190,8 @@ class BaseContrastiveTrainer(BaseTrainer):
if callable(loss):
return loss
elif isinstance(loss, str):
assert loss in ['JSE', 'NCE']
return {'JSE': JSE_loss, 'NCE': NCE_loss}[loss]
assert loss in ['NT_Xent']
return {'NT_Xent': NTXent_loss}[loss]
else:
raise NotImplementedError("The argument `loss` should be str or callable which returns a loss tensor")

@@ -453,7 +448,7 @@ class BaseContrastiveTrainer(BaseTrainer):
for view in views:
z = self._get_embed(view.to(self.device))
zs.append(self.decoder.decoder(z, view.to(self.device)))
loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau)
loss = self.loss(zs, tau=self.tau)
loss.backward()
optimizer.step()
if self.p_lr_scheduler_type:
@@ -474,7 +469,7 @@ class BaseContrastiveTrainer(BaseTrainer):
for view in views:
z = self._get_embed(view.to(self.device))
zs.append(self.decoder.decoder(z, view.to(self.device)))
loss = self.loss(zs, neg_by_crpt=self.neg_by_crpt, tau=self.tau)
loss = self.loss(zs, tau=self.tau)
epoch_loss += loss.item()
last_loss = loss.item()
return epoch_loss, last_loss
@@ -534,19 +529,7 @@ class BaseContrastiveTrainer(BaseTrainer):
return self.encoder.encoder.to(self.device)

def _get_embed(self, view):
if self.neg_by_crpt:
view_crpt = self._corrupt_graph(view)
if self.node_level and self.graph_level:
z_g, z_n = self.encoder.encoder(view)
z_g_crpt, z_n_crpt = self.encoder.encoder(view_crpt)
z = (torch.cat([z_g, z_g_crpt], 0),
torch.cat([z_n, z_n_crpt], 0))
else:
z = self.encoder.encoder(view)
z_crpt = self.encoder.encoder(view_crpt)
z = torch.cat([z, z_crpt], 0)
else:
z = self.encoder.encoder(view)
z = self.encoder.encoder(view)
return z

def predict(self, dataset, mask="test"):


+ 5
- 13
autogl/module/train/ssl/graphcl.py View File

@@ -1,4 +1,3 @@
# codes in this file are reproduced from <https://github.com/divelab/DIG> with some changes.
import os
import torch
import logging
@@ -14,7 +13,6 @@ from typing import Union, Tuple, Sequence, Type, Callable

from tqdm import trange
from copy import deepcopy
from dig.sslgraph.evaluation.eval_graph import k_fold

from .base import BaseContrastiveTrainer

@@ -54,7 +52,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
] = None,
aug_ratio: Union[float, Sequence[float]] = 0.2,
z_dim: Union[int, None] = 128,
neg_by_crpt: bool = False,
tau: int = 0.5,
model_path: Union[str, None] = "./models",
num_workers: int = 0,
@@ -105,10 +102,8 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence.
z_dim: `int`
The dimension of graph-level representations
neg_by_crpt: `bool`
The mode to obtain negative samples. Only required when `loss` = "JSE"
tau: `int`
The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
model_path: `str` or None
The directory to restore the saved model.
If `model_path` = None, the model will not be saved.
@@ -165,9 +160,10 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
feval=feval,
z_dim=z_dim,
z_node_dim=None,
neg_by_crpt=neg_by_crpt,
tau=tau,
model_path=model_path
model_path=model_path,
*args,
**kwargs
)
self.views_fn = views_fn
self.aug_ratio = aug_ratio
@@ -438,7 +434,6 @@ class GraphCLSemisupervisedTrainer(BaseContrastiveTrainer):
views_fn=self.views_fn_opt,
aug_ratio=self.aug_ratio,
z_dim=self.last_dim,
neg_by_crpt=self.neg_by_crpt,
tau=self.tau,
model_path=self.model_path,
num_workers=self.num_workers,
@@ -530,10 +525,8 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
If aug_ratio is set as a list of float, the value of this list and views_fn one to one correspondence.
z_dim: `int`
The dimension of graph-level representations
neg_by_crpt: `bool`
The mode to obtain negative samples. Only required when `loss` = "JSE"
tau: `int`
The temperature parameter in InfoNCE loss. Only used when `loss` = "NCE"
The temperature parameter in NT_Xent loss. Only used when `loss` = "NT_Xent"
model_path: `str` or None
The directory to restore the saved model.
If `model_path` = None, the model will not be saved.
@@ -894,7 +887,6 @@ class GraphCLUnsupervisedTrainer(BaseContrastiveTrainer):
views_fn=self.views_fn_opt,
aug_ratio=self.aug_ratio,
z_dim=self.last_dim,
neg_by_crpt=self.neg_by_crpt,
tau=self.tau,
model_path=self.model_path,
num_workers=self.num_workers,


+ 17
- 0
autogl/module/train/ssl/losses.py View File

@@ -0,0 +1,17 @@
# NTXent_loss from <https://github.com/Shen-Lab/GraphCL/>
import torch
import torch.nn as nn

def NTXent_loss(zs, tau=0.5, norm=True):
batch_size, _ = zs[0].size()
sim_matrix = torch.einsum('ik,jk->ij', zs[0], zs[1])
if norm:
z1_abs = zs[0].norm(dim=1)
z2_abs = zs[1].norm(dim=1)
sim_matrix = sim_matrix / torch.einsum('i,j->ij', z1_abs, z2_abs)
sim_matrix = torch.exp(sim_matrix/tau)
pos_sim = sim_matrix[range(batch_size), range(batch_size)]
loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
loss = - torch.log(loss).mean()
return loss

+ 20
- 20
autogl/module/train/ssl/utils.py View File

@@ -1,8 +1,8 @@
from dig.sslgraph.method.contrastive.views_fn import (
NodeAttrMask,
EdgePerturbation,
UniformSample,
RWSample,
from .views_fn import (
DropNode,
PermuteEdge,
MaskNode,
SubGraph,
RandomView
)

@@ -10,28 +10,28 @@ def get_view_by_name(view, aug_ratio):
if view is None:
return lambda x: x
elif view == "dropN":
return UniformSample(ratio=aug_ratio)
return DropNode(aug_ratio=aug_ratio)
elif view == "permE":
return EdgePerturbation(ratio=aug_ratio)
return PermuteEdge(aug_ratio=aug_ratio)
elif view == "subgraph":
return RWSample(ratio=aug_ratio)
return SubGraph(aug_ratio=aug_ratio)
elif view == "maskN":
return NodeAttrMask(mask_ratio=aug_ratio)
return MaskNode(aug_ratio=aug_ratio)
elif view == "random2":
canditates = [UniformSample(ratio=aug_ratio),
RWSample(ratio=aug_ratio)]
canditates = [DropNode(aug_ratio=aug_ratio),
SubGraph(aug_ratio=aug_ratio)]
return RandomView(candidates=canditates)
elif view == "random3":
canditates = [UniformSample(ratio=aug_ratio),
RWSample(ratio=aug_ratio),
EdgePerturbation(ratio=aug_ratio)]
canditates = [DropNode(aug_ratio=aug_ratio),
SubGraph(aug_ratio=aug_ratio),
PermuteEdge(aug_ratio=aug_ratio)]
return RandomView(candidates=canditates)
elif view == "random4":
canditates = [UniformSample(ratio=aug_ratio),
RWSample(ratio=aug_ratio),
EdgePerturbation(ratio=aug_ratio),
NodeAttrMask(mask_ratio=aug_ratio)]
canditates = [DropNode(aug_ratio=aug_ratio),
SubGraph(aug_ratio=aug_ratio),
PermuteEdge(aug_ratio=aug_ratio),
MaskNode(aug_ratio=aug_ratio)]
return RandomView(candidates=canditates)
else:
raise NotImplementedError(f'The augmentation method must be in ["dropN", "permE", "subgraph", \
"maskN", "random2", "random3", "random4"] or None. And {view} is not supported yet.')
raise NotImplementedError(f'{view} is not supported yet. Support: ["dropN", "permE", "subgraph", \
"maskN", "random2", "random3", "random4", None]')

+ 137
- 0
autogl/module/train/ssl/views_fn.py View File

@@ -0,0 +1,137 @@
# pyg augmentation method from <https://github.com/Shen-Lab/GraphCL/>

import random
import torch
import numpy as np
from itertools import repeat, product
from torch_geometric.data import Batch

class BaseAugmentation:
    """Common interface for per-graph augmentations applied over a PyG Batch.

    Subclasses override ``_aug_data`` to transform a single graph; calling the
    augmentation on a ``Batch`` applies it graph-by-graph and re-collates.
    """

    def __init__(self, aug_ratio=None):
        # Fraction of nodes/edges affected by the augmentation;
        # interpretation is left to each subclass.
        self.aug_ratio = aug_ratio

    def _aug_data(self, data):
        """Augment a single graph. Base implementation is a no-op placeholder."""
        pass

    def __call__(self, batch):
        # Augment every graph independently, then rebuild the batch.
        augmented = [self._aug_data(graph) for graph in batch.to_data_list()]
        return Batch.from_data_list(augmented)

class DropNode(BaseAugmentation):
    """Drop a random fraction ``aug_ratio`` of nodes and their incident edges.

    Kept nodes are relabelled consecutively; edges are re-indexed by slicing
    the dense adjacency matrix to the kept nodes.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        """Return ``data`` with ``aug_ratio`` of its nodes removed in place."""
        node_num, _ = data.x.size()
        drop_num = int(node_num * self.aug_ratio)

        # Choose which nodes to keep; sorted so relabelling preserves order.
        idx_nondrop = np.random.permutation(node_num)[drop_num:]
        idx_nondrop.sort()

        # Rebuild edge_index via the dense adjacency restricted to kept nodes;
        # this filters incident edges and relabels endpoints in one step.
        edge_index = data.edge_index.numpy()
        adj = torch.zeros((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 1
        adj = adj[idx_nondrop, :][:, idx_nondrop]
        new_edge_index = adj.nonzero().t()

        # Replaces the original bare `except:` (which silently swallowed all
        # errors) with an explicit guard against the degenerate empty graph.
        if idx_nondrop.size > 0:
            data.edge_index = new_edge_index
            data.x = data.x[idx_nondrop]
        return data

class PermuteEdge(BaseAugmentation):
    """Replace a random fraction ``aug_ratio`` of edges with random new edges.

    Node features are untouched; only ``edge_index`` changes.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        """Return ``data`` with ``aug_ratio`` of its edges rewired in place."""
        node_num, _ = data.x.size()
        _, edge_num = data.edge_index.size()
        permute_num = int(edge_num * self.aug_ratio)

        edge_index = data.edge_index.numpy()

        # Random endpoints for the replacement edges (duplicates of existing
        # edges are possible; matches the reference GraphCL behaviour).
        idx_add = np.random.choice(node_num, (2, permute_num))

        # Keep a random subset of the original edges, then append the new ones.
        keep = np.random.choice(edge_num, edge_num - permute_num, replace=False)
        edge_index = np.concatenate((edge_index[:, keep], idx_add), axis=1)
        data.edge_index = torch.tensor(edge_index)

        return data

class SubGraph(BaseAugmentation):
    """Keep a connected subgraph covering ~``aug_ratio`` of the nodes.

    Grows a node set from a random seed by repeatedly sampling from the
    frontier of neighbours, then restricts features and edges to that set.
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        """Return ``data`` restricted in place to a sampled subgraph."""
        node_num, _ = data.x.size()
        sub_num = int(node_num * self.aug_ratio)

        edge_index = data.edge_index.numpy()

        # BFS-like expansion from a random seed node.
        idx_sub = [np.random.randint(node_num, size=1)[0]]
        idx_neigh = set(edge_index[1][edge_index[0] == idx_sub[0]])

        count = 0
        while len(idx_sub) <= sub_num:
            count += 1
            if count > node_num:
                break
            if len(idx_neigh) == 0:
                break
            sample_node = np.random.choice(list(idx_neigh))
            if sample_node in idx_sub:
                continue
            idx_sub.append(sample_node)
            # BUG FIX: set.union returns a new set and the original discarded
            # it, so the frontier never grew beyond the seed's neighbours.
            # Update in place so newly added nodes extend the frontier.
            idx_neigh |= set(edge_index[1][edge_index[0] == idx_sub[-1]])

        idx_nondrop = idx_sub
        data.x = data.x[idx_nondrop]

        # Re-index edges among kept nodes via the dense adjacency;
        # self-loops are added so isolated kept nodes stay connected.
        adj = torch.zeros((node_num, node_num))
        adj[edge_index[0], edge_index[1]] = 1
        adj[list(range(node_num)), list(range(node_num))] = 1
        adj = adj[idx_nondrop, :][:, idx_nondrop]
        data.edge_index = adj.nonzero().t()

        return data

class MaskNode(BaseAugmentation):
    """Mask a random fraction ``aug_ratio`` of node feature rows.

    Masked rows are replaced by the graph's mean feature vector
    (the "mask token").
    """

    def __init__(self, aug_ratio):
        super().__init__(aug_ratio)

    def _aug_data(self, data):
        """Return ``data`` with ``aug_ratio`` of its feature rows masked in place."""
        node_num, _ = data.x.size()
        mask_num = int(node_num * self.aug_ratio)

        # Mask token: mean feature vector over all nodes of this graph.
        token = data.x.mean(dim=0)
        idx_mask = np.random.choice(node_num, mask_num, replace=False)
        # `token` is already a tensor — assigning it directly avoids the
        # torch.tensor(tensor) copy-construct warning in the original.
        data.x[idx_mask] = token.to(torch.float32)

        return data

class RandomView(BaseAugmentation):
    """Apply one augmentation chosen uniformly at random per graph.

    ``candidates`` is a sequence of augmentation objects exposing ``_aug_data``.
    """

    def __init__(self, candidates):
        super().__init__()
        self.candidates = candidates

    def _aug_data(self, data):
        # Pick a fresh candidate for every graph so a batch mixes views.
        chosen = random.choice(self.candidates)
        return chosen._aug_data(data)

+ 1
- 1
autogl/solver/classifier/ssl/ssl_graph_classifier.py View File

@@ -302,7 +302,7 @@ class SSLGraphClassifier(BaseClassifier):
num_classes=num_classes,
feval=evaluator_list,
device=self.runtime_device,
loss="NCE" if not hasattr(dataset, "loss") else dataset.loss,
loss="NT_Xent" if not hasattr(dataset, "loss") else dataset.loss,
num_graph_features=(0
if not hasattr(dataset[0], "gf")
else dataset[0].gf.size(1)) if BACKEND == 'pyg' else


+ 6
- 6
test/trainer/pyg/graphcl_ssl.py View File

@@ -35,12 +35,12 @@ def test_graph_trainer():
prediction_head="sumpoolmlp",
views_fn=["random2", "random2"],
batch_size=128,
p_lr=5.6004725115062315e-05,
p_weight_decay=0.00022810837622188083,
p_epoch=267,
f_epoch=131,
f_lr=0.0005362155524564354,
f_weight_decay=0.0022069814932058804,
p_lr=0.0001,
p_weight_decay=0.0002,
p_epoch=300,
f_epoch=150,
f_lr=0.0001,
f_weight_decay=0.002,
p_early_stopping_round=50,
f_early_stopping_round=50,
z_dim=128,


+ 115
- 0
test/trainer/pyg/graphcl_ssl_full.py View File

@@ -0,0 +1,115 @@
import os
import random
import torch
import torch.nn as nn
import numpy as np

from autogl.module.train.ssl import GraphCLSemisupervisedTrainer
from autogl.datasets import build_dataset_from_name, utils
from autogl.datasets.utils.conversion import to_pyg_dataset as convert_dataset
from autogl.module.model.encoders.base_encoder import AutoHomogeneousEncoderMaintainer
from autogl.module.model.decoders import BaseDecoderMaintainer
from autogl.solver.utils import set_seed

def fixed(**kwargs):
    """Wrap keyword arguments as a list of HPO 'FIXED' parameter descriptors.

    Each ``name=value`` pair becomes ``{'parameterName': name,
    'type': 'FIXED', 'value': value}``, preserving keyword order.
    """
    descriptors = []
    for name, value in kwargs.items():
        descriptors.append({
            'parameterName': name,
            'type': "FIXED",
            'value': value,
        })
    return descriptors

if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser('ssl pyg trainer')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--dataset', type=str, choices=['MUTAG', 'NCI1', 'PROTEINS', 'PTC_MR'], default='PROTEINS')
parser.add_argument('--dataset_seed', type=int, default=2021)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--repeat', type=int, default=50)
# parser.add_argument('--model', type=str, choices=['gin', 'gat', 'gcn', 'sage'], default='gin')
parser.add_argument('--encoder', type=str, choices=['gin', 'gcn'], default='gcn')
parser.add_argument('--p_lr', type=float, default=0.0001)
parser.add_argument('--p_weight_decay', type=float, default=0)
parser.add_argument('--p_epoch', type=int, default=100)
parser.add_argument('--f_lr', type=float, default=0.001)
parser.add_argument('--f_weight_decay', type=float, default=0)
parser.add_argument('--f_epoch', type=int, default=100)
parser.add_argument('--epoch', type=int, default=100)

args=parser.parse_args()

# split dataset
dataset = build_dataset_from_name(args.dataset)
dataset = convert_dataset(dataset)
utils.graph_random_splits(dataset, train_ratio=0.1, val_ratio=0.1, seed=2022)

accs = [[],[],[]]

encoder_hp = {
"num_layers": 5,
"hidden": [32, 64, 64, 64],
"dropout": 0.5,
"act": "elu",
"eps": "true"
}
decoder_hp = {
"hidden": 32,
"act": "tanh",
"dropout": 0.35
}
prediction_head = {
"hidden": 128,
"act": "relu",
"dropout": 0.4
}
from tqdm import tqdm
for seed in tqdm(range(args.repeat)):
set_seed(seed)
trainer = GraphCLSemisupervisedTrainer(
model=(args.encoder, 'sumpoolmlp'),
prediction_head='sumpoolmlp',
views_fn=['random2', 'random2'],
device=args.device,
num_features=dataset[0].x.size(1),
num_classes=max([data.y.item() for data in dataset]) + 1,
batch_size=args.batch_size,
# p_lr=args.p_lr,
# p_weight_decay=args.p_weight_decay,
# p_epoch=args.p_epoch,
# f_lr=args.f_lr,
# f_weight_decay=args.f_weight_decay,
# f_epoch=args.f_epoch,
z_dim=128,
init=False
)
trainer.initialize()
trainer = trainer.duplicate_from_hyper_parameter(
{
'trainer': {
'batch_size': args.batch_size,
'p_lr': args.p_lr,
'p_weight_decay': args.p_weight_decay,
'p_epoch': args.p_epoch,
'p_early_stopping_round': args.p_epoch + 1,
'f_lr': args.f_lr,
'f_weight_decay': args.f_weight_decay,
'f_epoch': args.f_epoch,
'f_early_stopping_round': args.f_epoch + 1,
},
"encoder": encoder_hp,
"decoder": decoder_hp,
"prediction_head": prediction_head
}
)
trainer.train(dataset, False)
out = trainer.predict(dataset, 'test').detach().cpu().numpy()
train_result = trainer.evaluate(dataset, 'train')
valid_result = trainer.evaluate(dataset, 'val')
test_result = trainer.evaluate(dataset, 'test')
print(f"{train_result[0]} - {valid_result[0]} - {test_result[0]}")
accs[0].append(train_result[0])
accs[1].append(valid_result[0])
accs[2].append(test_result[0])
print('{:.4f} ~ {:.4f}'.format(np.mean(accs[0]), np.std(accs[0])))
print('{:.4f} ~ {:.4f}'.format(np.mean(accs[1]), np.std(accs[1])))
print('{:.4f} ~ {:.4f}'.format(np.mean(accs[2]), np.std(accs[2])))

Loading…
Cancel
Save