You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_structure_engineer.py 3.9 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from .. import _data_preprocessor
  2. class StructureEngineer(_data_preprocessor.DataPreprocessor):
  3. ...
  4. import torch
  5. from ....utils import get_logger
  6. LOGGER = get_logger("Structure")
  7. from torch_geometric.utils import to_dense_adj
  8. def get_feature(data):
  9. """return features : numpy.ndarray
  10. """
  11. for fk in 'x feat'.split():
  12. if fk in data.nodes.data:
  13. features=data.nodes.data[fk].numpy()
  14. return features
  15. def get_edges(data):
  16. return data.edges.connections
  17. def set_edges(data,adj):
  18. data.data["edge_index"]=adj
  19. def to_adjacency_matrix(adj):
  20. """
  21. adj : torch.Tensor [2,E]
  22. return Tensor [N,N]
  23. """
  24. adj = to_dense_adj(adj)[0].long() # adjacency matrix
  25. return adj
  26. def to_adjacency_list(adj):
  27. """
  28. adj : Tensor [N,N]
  29. return Tensor [2,E]
  30. """
  31. adj = torch.stack(adj.nonzero(as_tuple=True)).long() # edge list
  32. return adj
  33. from .._data_preprocessor_registry import DataPreprocessorUniversalRegistry
  34. from deeprobust.graph.defense.gcn_preprocess import GCNJaccard as Jaccard
  35. @DataPreprocessorUniversalRegistry.register_data_preprocessor("gcnjaccard")
  36. class GCNJaccard(StructureEngineer):
  37. """
  38. GCNJaccard preprocesses input graph via droppining dissimilar
  39. edges. See more details in
  40. Adversarial Examples on Graph Data: Deep Insights into Attack and Defense,
  41. https://arxiv.org/pdf/1903.01610.pdf.
  42. """
  43. def __init__(self, threshold=0.01, *args, **kwargs):
  44. """ drop dissimilar edges with similarity smaller than given threshold
  45. Parameters
  46. ----------
  47. threshold : float
  48. similarity threshold for dropping edges. If two connected nodes with similarity smaller than threshold, the edge between them will be removed.
  49. """
  50. super(GCNJaccard, self).__init__(*args, **kwargs)
  51. self.engine=Jaccard(2,2,2)
  52. self.engine.threshold=threshold
  53. def _transform(self,data):
  54. features = get_feature(data)
  55. adj = get_edges(data) # edge list
  56. LOGGER.info(f'before modified: {adj.shape}')
  57. adj = to_adjacency_matrix(adj).numpy() # adjacency matrix
  58. modified_adj = self.engine.drop_dissimilar_edges(features, adj).toarray() # adjacency matrix
  59. modified_adj = to_adjacency_list(torch.Tensor(modified_adj)) # edge list
  60. LOGGER.info(f'after modified: {modified_adj.shape}' )
  61. set_edges(data,modified_adj)
  62. return data
  63. from deeprobust.graph.defense.gcn_preprocess import GCNSVD as SVD
  64. @DataPreprocessorUniversalRegistry.register_data_preprocessor("gcnsvd")
  65. class GCNSVD(StructureEngineer):
  66. """GCNSVD uses Truncated SVD as preprocessing.See more details in All You Need Is Low (Rank): Defending
  67. Against Adversarial Attacks on Graphs,
  68. https://dl.acm.org/doi/abs/10.1145/3336191.3371789.
  69. """
  70. def __init__(self, k=50, threshold=0.05, *args, **kwargs):
  71. """perform rank-k approximation of adjacency matrix via
  72. truncated SVD
  73. Parameters
  74. ----------
  75. k : int
  76. number of singular values and vectors to compute.
  77. threshold : float
  78. edges with scores larger than threshold will be kept.
  79. """
  80. super(GCNSVD, self).__init__(*args, **kwargs)
  81. self.engine=SVD(2,2,2)
  82. self.k=k
  83. self.threshold=threshold
  84. def _transform(self,data):
  85. adj = get_edges(data) # edge list
  86. LOGGER.info(f'before modified: {adj.shape}')
  87. adj = to_adjacency_matrix(adj).numpy() # adjacency matrix
  88. modified_adj = self.engine.truncatedSVD(adj,self.k) # adjacency matrix
  89. modified_adj = (modified_adj> self.threshold).astype(int)
  90. modified_adj = to_adjacency_list(torch.Tensor(modified_adj)) # edge list
  91. LOGGER.info(f'after modified: {modified_adj.shape}' )
  92. set_edges(data,modified_adj)
  93. return data