You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

_structure_engineer.py 911 B

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from .. import _data_preprocessor
class StructureEngineer(_data_preprocessor.DataPreprocessor):
    """Structure-engineering preprocessor (as the name suggests).

    Subclasses ``_data_preprocessor.DataPreprocessor``. The visible body
    adds no attributes or methods; it is only the ``...`` placeholder.

    NOTE(review): the ``...`` may be content elided by the page this source
    was copied from. Confirm against the original file before relying on
    this class having no behavior of its own.
    """
    ...
  4. import torch
  5. from ....utils import get_logger
# Module-level logger, created through the package's ``get_logger`` helper
# with the name "Structure". It is not referenced by the functions visible
# in this chunk of the file.
LOGGER = get_logger("Structure")
  7. def get_feature(data):
  8. """return features : numpy.ndarray
  9. """
  10. for fk in 'x feat'.split():
  11. if fk in data.nodes.data:
  12. features=data.nodes.data[fk].numpy()
  13. return features
  14. def get_edges(data):
  15. return data.edges.connections
  16. def set_edges(data,adj):
  17. data.data["edge_index"]=adj
  18. def to_adjacency_matrix(adj):
  19. """
  20. adj : torch.Tensor [2,E]
  21. return Tensor [N,N]
  22. """
  23. num_nodes=adj.max().item()+1
  24. mat = torch.zeros((num_nodes,num_nodes), dtype=bool)
  25. mat[tuple(adj)]=1
  26. return mat
  27. def to_adjacency_list(adj):
  28. """
  29. adj : Tensor [N,N]
  30. return Tensor [2,E]
  31. """
  32. adj = torch.stack(adj.nonzero(as_tuple=True)).long() # edge list
  33. return adj