You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

grna_test.py 3.0 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import os
  2. os.environ["AUTOGL_BACKEND"] = "pyg"
  3. import sys
  4. sys.path.append('../')
  5. from autogl.datasets import build_dataset_from_name
  6. from autogl.datasets.utils.conversion._to_pyg_dataset import to_pyg_dataset
  7. from autogl.solver import AutoNodeClassifier
  8. from autogl.module.train import Acc
  9. from autogl.solver.utils import set_seed
  10. import argparse
  11. import torch
  12. from tqdm import tqdm
  13. import numpy as np
  14. import time
  15. from torch_geometric.utils import to_scipy_sparse_matrix
  16. from deeprobust.graph.defense import GCN
  17. from deeprobust.graph.global_attack import Metattack
  18. def metattack(data):
  19. print('Meta-attack...')
  20. adj, features, labels = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes), data.x.numpy(), data.y.numpy()
  21. idx = np.arange(data.num_nodes)
  22. idx_train, idx_val, idx_test = idx[data.train_mask], idx[data.val_mask], idx[data.test_mask]
  23. idx_unlabeled = np.union1d(idx_val, idx_test)
  24. # Setup Surrogate model
  25. surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1,
  26. nhid=16, dropout=0, with_relu=False, with_bias=False, device=device).to(device)
  27. surrogate.fit(features, adj, labels, idx_train, idx_val, patience=30)
  28. # Setup Attack Model
  29. model = Metattack(surrogate, nnodes=adj.shape[0], feature_shape=features.shape,
  30. attack_structure=True, attack_features=False, device=device, lambda_=0).to(device)
  31. # Attack
  32. n_perturbations = int(data.edge_index.size(1)/2 * 0.05)
  33. model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=n_perturbations, ll_constraint=False)
  34. perturbed_adj = model.modified_adj
  35. perturbed_data = data.clone()
  36. perturbed_data.edge_index = torch.LongTensor(perturbed_adj.nonzero().T)
  37. return perturbed_data
  38. def test_from_data(trainer, dataset, args):
  39. for seed in tqdm(range(args.repeat)):
  40. set_seed(seed)
  41. trainer.train(dataset)
  42. acc = trainer.evaluate(dataset, mask='test')
  43. return acc
  44. if __name__ == '__main__':
  45. time0 = time.time()
  46. set_seed(202106)
  47. parser = argparse.ArgumentParser()
  48. parser.add_argument('--config', type=str, default='../configs/nodeclf_nas_grna.yml')
  49. parser.add_argument('--dataset', choices=['cora', 'citeseer', 'pubmed'], default='citeseer', type=str)
  50. parser.add_argument('--repeat', type=int, default=1)
  51. args = parser.parse_args()
  52. device = 'cuda'
  53. dataset = build_dataset_from_name(args.dataset)
  54. print('architecture search')
  55. solver = AutoNodeClassifier.from_config(args.config)
  56. solver.fit(dataset)
  57. solver.get_leaderboard().show()
  58. orig_acc = solver.evaluate(metric="acc")
  59. trainer = solver.graph_model_list[0]
  60. trainer.device = device
  61. ## test searched model on clean data
  62. dataset = to_pyg_dataset(dataset)
  63. acc = test_from_data(trainer, dataset, args)
  64. ## test searched model on perturbed data
  65. data = dataset[0].cpu()
  66. dataset[0] = metattack(data).to(device)
  67. ptb_acc = test_from_data(trainer, dataset, args)