You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

_netlsd.py 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import netlsd
  2. import networkx
  3. import torch
  4. import typing
  5. from autogl.data.graph import GeneralStaticGraph
  6. from autogl.data.graph.utils import conversion
  7. from .._feature_engineer import FeatureEngineer
  8. from ..._data_preprocessor_registry import DataPreprocessorUniversalRegistry
  9. @DataPreprocessorUniversalRegistry.register_data_preprocessor("NetLSD".lower())
  10. class NetLSD(FeatureEngineer):
  11. r"""
  12. Notes
  13. -----
  14. a graph feature generation method. This is a simple wrapper of NetLSD [#]_.
  15. References
  16. ----------
  17. .. [#] A. Tsitsulin, D. Mottin, P. Karras, A. Bronstein, and E. Müller, “NetLSD: Hearing the shape of a graph,”
  18. Proc. ACM SIGKDD Int. Conf. Knowl. Discov. Data Min., pp. 2347–2356, 2018.
  19. """
  20. def __init__(self, *args, **kwargs):
  21. self.__args = args
  22. self.__kwargs = kwargs
  23. super(NetLSD, self).__init__()
  24. def __extract(self, nx_g: networkx.Graph) -> torch.Tensor:
  25. return torch.tensor(netlsd.heat(nx_g, *self.__args, **self.__kwargs)).view(-1)
  26. def __transform_homogeneous_static_graph(
  27. self, homogeneous_static_graph: GeneralStaticGraph
  28. ) -> GeneralStaticGraph:
  29. if not (
  30. homogeneous_static_graph.nodes.is_homogeneous and
  31. homogeneous_static_graph.edges.is_homogeneous
  32. ):
  33. raise ValueError("Provided static graph must be homogeneous")
  34. dsc: torch.Tensor = self.__extract(
  35. conversion.HomogeneousStaticGraphToNetworkX(to_undirected=True).__call__(
  36. homogeneous_static_graph, to_undirected=True
  37. )
  38. )
  39. if 'gf' in homogeneous_static_graph.data:
  40. gf = homogeneous_static_graph.data['gf'].view(-1)
  41. homogeneous_static_graph.data['gf'] = torch.cat([gf, dsc])
  42. else:
  43. homogeneous_static_graph.data['gf'] = dsc
  44. return homogeneous_static_graph
  45. @classmethod
  46. def __edge_index_to_nx_graph(cls, edge_index: torch.Tensor) -> networkx.Graph:
  47. g: networkx.Graph = networkx.Graph()
  48. for u, v in edge_index.t().tolist():
  49. if u == v:
  50. continue
  51. else:
  52. g.add_edge(u, v)
  53. return g
  54. def __transform_data(self, data):
  55. if not (
  56. hasattr(data, "edge_index") and
  57. torch.is_tensor(data.edge_index) and
  58. isinstance(data.edge_index, torch.Tensor) and
  59. data.edge_index.dim() == data.edge_index.size(0) == 2 and
  60. data.edge_index.dtype == torch.long
  61. ):
  62. raise TypeError("Unsupported provided data")
  63. dsc: torch.Tensor = self.__extract(self.__edge_index_to_nx_graph(data.edge_index))
  64. if hasattr(data, 'gf') and isinstance(data.gf, torch.Tensor):
  65. gf = data.gf.view(-1)
  66. data.gf = torch.cat([gf, dsc])
  67. else:
  68. data.gf = dsc
  69. return data
  70. def _transform(
  71. self, data: typing.Union[GeneralStaticGraph, typing.Any]
  72. ) -> typing.Union[GeneralStaticGraph, typing.Any]:
  73. if isinstance(data, GeneralStaticGraph):
  74. return self.__transform_homogeneous_static_graph(data)
  75. else:
  76. return self.__transform_data(data)