|
123456789101112131415161718192021222324252627282930313233343536373839404142 |
- from .. import _data_preprocessor
class StructureEngineer(_data_preprocessor.DataPreprocessor):
    """Data preprocessor subclass that adds no behavior of its own.

    Everything is inherited from ``_data_preprocessor.DataPreprocessor``;
    the class body is intentionally empty.
    """
-
- import torch
- from ....utils import get_logger
- LOGGER = get_logger("Structure")
-
def get_feature(data):
    """Return the node features stored on ``data`` as a ``numpy.ndarray``.

    The node storage ``data.nodes.data`` is probed for the keys ``'x'`` and
    ``'feat'``; the stored value is converted with its ``.numpy()`` method.

    NOTE(review): indentation was lost in the reviewed source. This assumes
    the ``return`` followed the loop, so ``'feat'`` takes precedence when
    both keys are present -- confirm against the original file.

    Parameters
    ----------
    data : object exposing ``data.nodes.data`` as a mapping of feature name
        to a tensor-like value with a ``.numpy()`` method.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    KeyError
        If neither ``'x'`` nor ``'feat'`` is present (previously this
        surfaced as an opaque ``UnboundLocalError``).
    """
    node_data = data.nodes.data
    candidate_keys = ("x", "feat")
    present = [fk for fk in candidate_keys if fk in node_data]
    if not present:
        raise KeyError(
            "no node feature found; expected one of %r in data.nodes.data"
            % (candidate_keys,)
        )
    # The last matching key wins, mirroring the original loop that kept
    # overwriting ``features`` while iterating 'x' then 'feat'.
    return node_data[present[-1]].numpy()
-
def get_edges(data):
    """Return the edge connections stored at ``data.edges.connections``."""
    edge_view = data.edges
    return edge_view.connections
-
def set_edges(data, adj):
    """Store ``adj`` under the ``"edge_index"`` key of ``data.data``.

    The storage is modified in place; nothing is returned.
    """
    storage = data.data
    storage["edge_index"] = adj
-
def to_adjacency_matrix(adj, num_nodes=None):
    """Convert an edge list into a dense boolean adjacency matrix.

    Parameters
    ----------
    adj : torch.Tensor [2, E]
        Integer edge list; row 0 indexes the first matrix axis and row 1
        the second.
    num_nodes : int, optional
        Side length of the resulting matrix. When omitted it is inferred as
        ``adj.max() + 1`` (or ``0`` for an empty edge list), so nodes with
        an index above the largest one in ``adj`` are not represented.

    Returns
    -------
    torch.Tensor [N, N]
        Boolean matrix with ``True`` at every ``(adj[0, k], adj[1, k])``.
    """
    has_edges = adj.numel() > 0
    if num_nodes is None:
        # ``max()`` raises on an empty tensor, so an empty edge list maps
        # to an empty 0 x 0 matrix instead.
        num_nodes = adj.max().item() + 1 if has_edges else 0
    mat = torch.zeros((num_nodes, num_nodes), dtype=torch.bool)
    if has_edges:
        # tuple(adj) unpacks the rows into one index tensor per matrix axis.
        mat[tuple(adj)] = True
    return mat
-
def to_adjacency_list(adj):
    """Convert a dense adjacency matrix into an edge list.

    adj : Tensor [N,N]
    return Tensor [2,E] holding the indices of the non-zero entries,
    one row per matrix axis, as int64.
    """
    # One index tensor per axis, giving the positions of non-zero entries.
    index_per_axis = adj.nonzero(as_tuple=True)
    edge_list = torch.stack(index_per_axis)
    return edge_list.long()
-
-
-
|