Browse Source

Update t_robust.rst

develop/0.4/predevelop
caijie1998 GitHub 3 years ago
parent
commit
0dc4e0cfee
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 56 additions and 80 deletions
  1. +56
    -80
      docs/docfile/tutorial_cn/t_robust.rst

+ 56
- 80
docs/docfile/tutorial_cn/t_robust.rst View File

@@ -1,88 +1,64 @@
<<<<<<< HEAD
.. _nas_cn:
.. _robust:
图鲁棒性
============================
鲁棒模型
==========================
图鲁棒性是近年图机器学习领域重要的研究方向,我们在AutoGL中集成了图鲁棒性相关算法,可以方便地与其他模块结合使用
我们提供了一系列的防御方法,旨在增强图神经网络的鲁棒性
背景知识
------------
生成并训练 GNNGuard 模型
------------------------------
(介绍对抗攻击、鲁棒问题的定义等,可以适当引一些paper)
首先,加载预先攻击的图数据:
在AutoGL中,我们将图鲁棒性的算法分为三类,放在不同的模块中实现。
鲁棒图特征工程旨在数据预处理阶段生成鲁棒的图特征,增强下游任务的鲁棒性。
鲁棒图神经网络则是通过模型层面的设计,以在训练过程中确保模型的鲁棒性。
鲁棒图神经网络架构搜索旨在搜索出一个鲁棒的图神经网络架构。
下文中将分别介绍这三类图鲁棒性算法。
.. code-block:: python
鲁棒图特征工程
---------------------
perturbed_data = PrePtbDataset(root='/tmp/', name=dataset,attack_method='meta', ptb_rate=0.2)
modified_adj = perturbed_data.adj
鲁棒图神经网络
---------------------
然后,在原图 / 扰动图上训练图神经网络模型:
鲁棒图神经网络架构搜索
---------------------
=======
==========================
鲁棒模型
==========================

我们提供了一系列的防御方法,旨在增强图神经网络的鲁棒性。

生成并训练 GNNGuard 模型
------------------------------

首先,加载预先攻击的图数据:

.. code-block:: python
perturbed_data = PrePtbDataset(root='/tmp/', name=dataset,attack_method='meta', ptb_rate=0.2)
modified_adj = perturbed_data.adj

然后,在原图 / 扰动图上训练图神经网络模型:

.. code-block:: python
flag = False
print('=== testing GNN on original(clean) graph (AutoGL) ===')
print("acc_test:",test_autogl(adj, features, device, attention=flag))
print('=== testing GCN on perturbed graph (AutoGL) ===')
print("acc_test:",test_autogl(modified_adj, features, device, attention=flag))

训练图神经网络的细节如下:

.. code-block:: python
def test_autogl(adj, features, device, attention):
''
"""test on GCN """
"""model_name could be 'GCN', 'GAT', 'GIN','JK' """
accs = []
for seed in tqdm(range(5)):
# defense model
gcn = AutoGNNGuard(
num_features=pyg_data.num_node_features,
num_classes=pyg_data.num_classes,
device=args.device,
init=False
).from_hyper_parameter(model_hp).model
gcn = gcn.to(device)
gcn.fit(features, adj, labels, idx_train, idx_val=idx_val,
idx_test=idx_test,
attention=attention, verbose=True, train_iters=81)
gcn.eval()
acc_test, output = gcn.test(idx_test=idx_test)
accs.append(acc_test.item())
mean = np.mean(accs)
std = np.std(accs)
return {"mean": mean, "std": std}

最后,在扰动图上训练防御模型GNNGuard

.. code-block:: python
flag = True
print('=== testing GNN on original(clean) graph (AutoGL) + GNNGuard ===')
print("acc_test:",test_autogl(adj, features, device, attention=flag))
print('=== testing GCN on perturbed graph (AutoGL) + GNNGuard ===')
print("acc_test:",test_autogl(modified_adj, features, device, attention=flag))
>>>>>>> 200a684ee5167c44f74d7ad704506ecbca7e11d6
.. code-block:: python
flag = False
print('=== testing GNN on original(clean) graph (AutoGL) ===')
print("acc_test:",test_autogl(adj, features, device, attention=flag))
print('=== testing GCN on perturbed graph (AutoGL) ===')
print("acc_test:",test_autogl(modified_adj, features, device, attention=flag))
训练图神经网络的细节如下:
.. code-block:: python
def test_autogl(adj, features, device, attention):
''
"""test on GCN """
"""model_name could be 'GCN', 'GAT', 'GIN','JK' """
accs = []
for seed in tqdm(range(5)):
# defense model
gcn = AutoGNNGuard(
num_features=pyg_data.num_node_features,
num_classes=pyg_data.num_classes,
device=args.device,
init=False
).from_hyper_parameter(model_hp).model
gcn = gcn.to(device)
gcn.fit(features, adj, labels, idx_train, idx_val=idx_val,
idx_test=idx_test,
attention=attention, verbose=True, train_iters=81)
gcn.eval()
acc_test, output = gcn.test(idx_test=idx_test)
accs.append(acc_test.item())
mean = np.mean(accs)
std = np.std(accs)
return {"mean": mean, "std": std}
最后,在扰动图上训练防御模型GNNGuard
.. code-block:: python
flag = True
print('=== testing GNN on original(clean) graph (AutoGL) + GNNGuard ===')
print("acc_test:",test_autogl(adj, features, device, attention=flag))
print('=== testing GCN on perturbed graph (AutoGL) + GNNGuard ===')
print("acc_test:",test_autogl(modified_adj, features, device, attention=flag))

Loading…
Cancel
Save