diff --git a/examples/grna_test.py b/examples/grna_test.py index 4eec524..d8edb31 100644 --- a/examples/grna_test.py +++ b/examples/grna_test.py @@ -33,10 +33,11 @@ def metattack(data): attack_structure=True, attack_features=False, device='cpu', lambda_=0).to('cpu') # Attack n_perturbations = int(data.edge_index.size(1)/2 * 0.05) + n_perturbations = 1 model.attack(features, adj, labels, idx_train, idx_unlabeled, n_perturbations=n_perturbations, ll_constraint=False) perturbed_adj = model.modified_adj perturbed_data = data.clone() - perturbed_data.edge_index = torch.LongTensor(perturbed_adj.nonzero()) + perturbed_data.edge_index = torch.LongTensor(perturbed_adj.nonzero().T) return perturbed_data @@ -75,6 +76,4 @@ if __name__ == '__main__': data = dataset[0].cpu() dataset[0] = metattack(data).to(device) ptb_acc = test_from_data(trainer, dataset, args) - - - + \ No newline at end of file