You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_basic.py 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import torch
  2. import typing as _typing
  3. import autogl
  4. from autogl.data.graph import GeneralStaticGraph
  5. from .._base_feature_engineer import BaseFeatureEngineer
  6. from .._feature_engineer_registry import FeatureEngineerUniversalRegistry
  class BaseFeatureGenerator(BaseFeatureEngineer):
      """Base class for feature engineers that generate node features.

      Subclasses implement ``_extract_nodes_feature``. This class pulls node
      features, labels and edge weights out of the input graph, invokes the
      subclass, and merges the generated features back into the graph.
      """

      def __init__(self, override_features: bool = False):
          """Create the generator.

          :param override_features: on the ``GeneralStaticGraph`` path, when
              ``True`` the generated features replace the existing node
              features; when ``False`` (the default) they are concatenated
              after the existing features along the last dimension.
          :raises TypeError: if ``override_features`` is not a ``bool``.
          """
          super(BaseFeatureGenerator, self).__init__()
          if not isinstance(override_features, bool):
              raise TypeError(
                  "override_features must be a bool, got "
                  f"{type(override_features).__name__}"
              )
          self._override_features: bool = override_features

      def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor:
          """Compute the generated node features for ``data``.

          Must be overridden by subclasses. When called from the static-graph
          path, ``data`` carries ``edge_index``, ``x``, ``y`` and
          ``edge_weight`` attributes (``x``/``y`` may be ``None``).

          :raises NotImplementedError: always, in the base class.
          """
          raise NotImplementedError

      @staticmethod
      def __first_present(
          mapping: _typing.Any, keys: _typing.Sequence[str]
      ) -> _typing.Tuple[_typing.Optional[str], _typing.Optional[torch.Tensor]]:
          """Return ``(key, mapping[key])`` for the first key in ``keys`` that
          ``mapping`` contains, or ``(None, None)`` when none is present."""
          for key in keys:
              if key in mapping:
                  return key, mapping[key]
          return None, None

      def __transform_homogeneous_static_graph(
          self, homogeneous_static_graph: GeneralStaticGraph
      ) -> GeneralStaticGraph:
          """Generate node features for a homogeneous static graph, in place.

          Node features are read from ``'x'`` (preferred) or ``'feat'``, and
          labels from ``'y'`` (preferred) or ``'label'``. A 1-D ``'edge_weight'``
          edge attribute is used as-is; otherwise unit weights are supplied.
          The generated features are written back under the existing feature
          key (replacing or appending, per ``override_features``). If the graph
          had no features, they are stored under ``'x'`` for the PyG backend or
          ``'feat'`` for the DGL backend.

          :returns: the same graph object, mutated.
          :raises ValueError: if the graph's nodes or edges are not homogeneous.
          """
          if not (
              homogeneous_static_graph.nodes.is_homogeneous and
              homogeneous_static_graph.edges.is_homogeneous
          ):
              raise ValueError("Provided static graph must be homogeneous")

          node_data = homogeneous_static_graph.nodes.data
          feature_key, features = self.__first_present(node_data, ('x', 'feat'))
          _, label = self.__first_present(node_data, ('y', 'label'))

          connections: torch.Tensor = homogeneous_static_graph.edges.connections
          edge_data = homogeneous_static_graph.edges.data
          if 'edge_weight' in edge_data and edge_data['edge_weight'].dim() == 1:
              edge_weight: torch.Tensor = edge_data['edge_weight']
          else:
              # Unit weights on the same device as the connectivity. A bare
              # torch.ones(...) would always be allocated on the CPU.
              edge_weight = torch.ones(
                  connections.size(1), device=connections.device
              )

          data = autogl.data.Data(edge_index=connections, x=features, y=label)
          setattr(data, "edge_weight", edge_weight)
          extracted_features: torch.Tensor = self._extract_nodes_feature(data)

          if feature_key is None:
              # No pre-existing node features: store the generated ones under
              # the conventional key of the active backend.
              if autogl.backend.DependentBackend.is_pyg():
                  homogeneous_static_graph.nodes.data['x'] = extracted_features
              elif autogl.backend.DependentBackend.is_dgl():
                  homogeneous_static_graph.nodes.data['feat'] = extracted_features
              # NOTE(review): with any other backend the generated features are
              # discarded without notice -- confirm this is intended.
              return homogeneous_static_graph

          # Treat 1-D existing features as a single feature column so that they
          # can be concatenated with the generated 2-D features.
          nodes_features: torch.Tensor = (
              features.view(-1, 1) if features.dim() == 1 else features
          )
          assert extracted_features.size(0) == nodes_features.size(0), (
              "generated features must have one row per node"
          )
          assert extracted_features.dim() == nodes_features.dim() == 2, (
              "generated and existing node features must both be 2-D"
          )
          extracted_features = extracted_features.to(nodes_features.device)
          homogeneous_static_graph.nodes.data[feature_key] = (
              extracted_features
              if self._override_features
              else torch.cat([nodes_features, extracted_features], dim=-1)
          )
          return homogeneous_static_graph

      def _transform(
          self, data: _typing.Union[GeneralStaticGraph, _typing.Any]
      ) -> _typing.Union[GeneralStaticGraph, _typing.Any]:
          """Generate node features for ``data`` and return it.

          ``GeneralStaticGraph`` inputs go through the homogeneous static-graph
          path. Any other object is passed straight to
          ``_extract_nodes_feature``, and its ``x`` attribute is replaced by
          the result.
          """
          if isinstance(data, GeneralStaticGraph):
              return self.__transform_homogeneous_static_graph(data)
          # NOTE(review): this path always overwrites ``data.x`` and ignores
          # ``override_features``, unlike the static-graph path -- confirm
          # this is intended.
          data.x = self._extract_nodes_feature(data)
          return data
  @FeatureEngineerUniversalRegistry.register_feature_engineer("OneHot".lower())
  class OneHotFeatureGenerator(BaseFeatureGenerator):
      """Generate one-hot node-identity features (an identity matrix).

      Registered under the name ``"onehot"``. Node ``i`` receives a vector of
      length ``num_nodes`` with a single ``1`` at position ``i``.
      """

      def _extract_nodes_feature(self, data: autogl.data.Data) -> torch.Tensor:
        """Return a ``num_nodes x num_nodes`` identity matrix.

        ``num_nodes`` comes from the first dimension of ``data.x`` when it is
        a tensor. Otherwise it is inferred as ``max(edge_index) + 1``, which
        does not count nodes whose index exceeds every index in
        ``edge_index``. A graph with neither node features nor edges yields
        an empty ``0 x 0`` tensor.

        The matrix is created on the device of the tensor the node count was
        taken from.
        """
        node_features = getattr(data, "x", None)
        if isinstance(node_features, torch.Tensor):
            num_nodes: int = node_features.size(0)
            device = node_features.device
        else:
            edge_index: torch.Tensor = data.edge_index
            # ``max()`` raises on an empty tensor, so handle the edgeless
            # case explicitly.
            num_nodes = (
                int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0
            )
            device = edge_index.device
        return torch.eye(num_nodes, device=device)