| @@ -1,88 +1,64 @@ | |||||
| <<<<<<< HEAD | |||||
| .. _nas_cn: | |||||
| .. _robust: | |||||
| 图鲁棒性 | |||||
| ============================ | |||||
| 鲁棒模型 | |||||
| ========================== | |||||
| 图鲁棒性是近年图机器学习领域重要的研究方向,我们在AutoGL中集成了图鲁棒性相关算法,可以方便地与其他模块结合使用。 | |||||
| 我们提供了一系列的防御方法,旨在增强图神经网络的鲁棒性。 | |||||
| 背景知识 | |||||
| ------------ | |||||
| 生成并训练 GNNGuard 模型 | |||||
| ------------------------------ | |||||
| (介绍对抗攻击、鲁棒问题的定义等,可以适当引一些paper) | |||||
| 首先,加载预先攻击的图数据: | |||||
| 在AutoGL中,我们将图鲁棒性的算法分为三类,放在不同的模块中实现。 | |||||
| 鲁棒图特征工程旨在数据预处理阶段生成鲁棒的图特征,增强下游任务的鲁棒性。 | |||||
| 鲁棒图神经网络则是通过模型层面的设计,以在训练过程中确保模型的鲁棒性。 | |||||
| 鲁棒图神经网络架构搜索旨在搜索出一个鲁棒的图神经网络架构。 | |||||
| 下文中将分别介绍这三类图鲁棒性算法。 | |||||
| .. code-block:: python | |||||
| 鲁棒图特征工程 | |||||
| --------------------- | |||||
| perturbed_data = PrePtbDataset(root='/tmp/', name=dataset,attack_method='meta', ptb_rate=0.2) | |||||
| modified_adj = perturbed_data.adj | |||||
| 鲁棒图神经网络 | |||||
| --------------------- | |||||
| 然后,在原图 / 扰动图上训练图神经网络模型: | |||||
| 鲁棒图神经网络架构搜索 | |||||
| --------------------- | |||||
| ======= | |||||
| ========================== | |||||
| 鲁棒模型 | |||||
| ========================== | |||||
| 我们提供了一系列的防御方法,旨在增强图神经网络的鲁棒性。 | |||||
| 生成并训练 GNNGuard 模型 | |||||
| ------------------------------ | |||||
| 首先,加载预先攻击的图数据: | |||||
| .. code-block:: python | |||||
| perturbed_data = PrePtbDataset(root='/tmp/', name=dataset,attack_method='meta', ptb_rate=0.2) | |||||
| modified_adj = perturbed_data.adj | |||||
| 然后,在原图 / 扰动图上训练图神经网络模型: | |||||
| .. code-block:: python | |||||
| flag = False | |||||
| print('=== testing GNN on original(clean) graph (AutoGL) ===') | |||||
| print("acc_test:",test_autogl(adj, features, device, attention=flag)) | |||||
| print('=== testing GCN on perturbed graph (AutoGL) ===') | |||||
| print("acc_test:",test_autogl(modified_adj, features, device, attention=flag)) | |||||
| 训练图神经网络的细节如下: | |||||
| .. code-block:: python | |||||
| def test_autogl(adj, features, device, attention): | |||||
| '' | |||||
| """test on GCN """ | |||||
| """model_name could be 'GCN', 'GAT', 'GIN','JK' """ | |||||
| accs = [] | |||||
| for seed in tqdm(range(5)): | |||||
| # defense model | |||||
| gcn = AutoGNNGuard( | |||||
| num_features=pyg_data.num_node_features, | |||||
| num_classes=pyg_data.num_classes, | |||||
| device=args.device, | |||||
| init=False | |||||
| ).from_hyper_parameter(model_hp).model | |||||
| gcn = gcn.to(device) | |||||
| gcn.fit(features, adj, labels, idx_train, idx_val=idx_val, | |||||
| idx_test=idx_test, | |||||
| attention=attention, verbose=True, train_iters=81) | |||||
| gcn.eval() | |||||
| acc_test, output = gcn.test(idx_test=idx_test) | |||||
| accs.append(acc_test.item()) | |||||
| mean = np.mean(accs) | |||||
| std = np.std(accs) | |||||
| return {"mean": mean, "std": std} | |||||
| 最后,在扰动图上训练防御模型GNNGuard | |||||
| .. code-block:: python | |||||
| flag = True | |||||
| print('=== testing GNN on original(clean) graph (AutoGL) + GNNGuard ===') | |||||
| print("acc_test:",test_autogl(adj, features, device, attention=flag)) | |||||
| print('=== testing GCN on perturbed graph (AutoGL) + GNNGuard ===') | |||||
| print("acc_test:",test_autogl(modified_adj, features, device, attention=flag)) | |||||
| >>>>>>> 200a684ee5167c44f74d7ad704506ecbca7e11d6 | |||||
| .. code-block:: python | |||||
| flag = False | |||||
| print('=== testing GNN on original(clean) graph (AutoGL) ===') | |||||
| print("acc_test:",test_autogl(adj, features, device, attention=flag)) | |||||
| print('=== testing GCN on perturbed graph (AutoGL) ===') | |||||
| print("acc_test:",test_autogl(modified_adj, features, device, attention=flag)) | |||||
| 训练图神经网络的细节如下: | |||||
| .. code-block:: python | |||||
| def test_autogl(adj, features, device, attention): | |||||
| '' | |||||
| """test on GCN """ | |||||
| """model_name could be 'GCN', 'GAT', 'GIN','JK' """ | |||||
| accs = [] | |||||
| for seed in tqdm(range(5)): | |||||
| # defense model | |||||
| gcn = AutoGNNGuard( | |||||
| num_features=pyg_data.num_node_features, | |||||
| num_classes=pyg_data.num_classes, | |||||
| device=args.device, | |||||
| init=False | |||||
| ).from_hyper_parameter(model_hp).model | |||||
| gcn = gcn.to(device) | |||||
| gcn.fit(features, adj, labels, idx_train, idx_val=idx_val, | |||||
| idx_test=idx_test, | |||||
| attention=attention, verbose=True, train_iters=81) | |||||
| gcn.eval() | |||||
| acc_test, output = gcn.test(idx_test=idx_test) | |||||
| accs.append(acc_test.item()) | |||||
| mean = np.mean(accs) | |||||
| std = np.std(accs) | |||||
| return {"mean": mean, "std": std} | |||||
| 最后,在扰动图上训练防御模型GNNGuard | |||||
| .. code-block:: python | |||||
| flag = True | |||||
| print('=== testing GNN on original(clean) graph (AutoGL) + GNNGuard ===') | |||||
| print("acc_test:",test_autogl(adj, features, device, attention=flag)) | |||||
| print('=== testing GCN on perturbed graph (AutoGL) + GNNGuard ===') | |||||
| print("acc_test:",test_autogl(modified_adj, features, device, attention=flag)) | |||||