.. _robust:

Robust Model
============

We provide a series of defense methods that aim to enhance the robustness of GNNs.

Building GNNGuard Module
------------------------

First, load the pre-attacked graph data:

.. code-block:: python

    # PrePtbDataset is DeepRobust's loader for pre-attacked graphs
    from deeprobust.graph.data import PrePtbDataset

    perturbed_data = PrePtbDataset(root='/tmp/', name=dataset,
                                   attack_method='meta', ptb_rate=0.2)
    modified_adj = perturbed_data.adj
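
The loader above only supplies the attacked adjacency matrix. The node features, labels,
and train/validation/test splits still come from the clean graph, which is assumed to have
been loaded beforehand. A minimal sketch of that step, assuming DeepRobust's ``Dataset``
loader and ``'cora'`` as an illustrative dataset name:

.. code-block:: python

    from deeprobust.graph.data import Dataset

    dataset = 'cora'                            # illustrative dataset name
    data = Dataset(root='/tmp/', name=dataset)  # clean graph
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test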

Second, train a victim model (GCN) on the clean and the poisoned graph:

.. code-block:: python

    flag = False
    print('=== testing GNN on original (clean) graph (AutoGL) ===')
    print("acc_test:", test_autogl(adj, features, device, attention=flag))
    print('=== testing GCN on perturbed graph (AutoGL) ===')
    print("acc_test:", test_autogl(modified_adj, features, device, attention=flag))

The details of training the GNN models are handled by ``test_autogl``, which averages
test accuracy over five runs:

.. code-block:: python

    import numpy as np
    from tqdm import tqdm

    def test_autogl(adj, features, device, attention):
        """Train and test a GCN-based defense model on the given graph.

        model_name could be 'GCN', 'GAT', 'GIN' or 'JK'.
        Assumes ``pyg_data``, ``model_hp``, ``labels`` and the index splits
        are defined at module level (see the sketch below).
        """
        accs = []
        for seed in tqdm(range(5)):
            # build the defense model from a fixed hyperparameter dictionary
            gcn = AutoGNNGuard(
                num_features=pyg_data.num_node_features,
                num_classes=pyg_data.num_classes,
                device=args.device,
                init=False
            ).from_hyper_parameter(model_hp).model
            gcn = gcn.to(device)
            # attention=True enables the GNNGuard defense during training
            gcn.fit(features, adj, labels, idx_train, idx_val=idx_val,
                    idx_test=idx_test,
                    attention=attention, verbose=True, train_iters=81)
            gcn.eval()
            acc_test, output = gcn.test(idx_test=idx_test)
            accs.append(acc_test.item())
        mean = np.mean(accs)
        std = np.std(accs)
        return {"mean": mean, "std": std}
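
The helper assumes a few objects are defined beforehand: ``pyg_data``, a PyTorch Geometric
view of the dataset used only to read the feature and class counts, ``args.device``, and
``model_hp``, a dictionary of model hyperparameters passed to ``from_hyper_parameter``.
A minimal sketch of preparing them, assuming DeepRobust's ``Dpr2Pyg`` converter; the
hyperparameter keys below are illustrative assumptions, not a documented schema:

.. code-block:: python

    from deeprobust.graph.data import Dpr2Pyg

    # convert the clean DeepRobust dataset into PyTorch Geometric format
    pyg_data = Dpr2Pyg(data)

    # hypothetical hyperparameter dictionary -- the exact keys depend on the
    # search space AutoGNNGuard registers in your AutoGL version
    model_hp = {
        "num_layers": 2,
        "hidden": [16],
        "dropout": 0.5,
        "act": "relu",
    }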

Third, train the defense model GNNGuard on the poisoned graph. Setting ``flag = True``
turns on the attention mechanism, which activates the GNNGuard defense during training:

.. code-block:: python

    flag = True
    print('=== testing GNN on original (clean) graph (AutoGL) + GNNGuard ===')
    print("acc_test:", test_autogl(adj, features, device, attention=flag))
    print('=== testing GCN on perturbed graph (AutoGL) + GNNGuard ===')
    print("acc_test:", test_autogl(modified_adj, features, device, attention=flag))