|
- import numpy as np
- import torch
-
-
def data_is_tensor(data):
    """Report whether the node features ``data.x`` are held as a ``torch.Tensor``.

    Only ``data.x`` is inspected; the other attributes are not checked.
    """
    features = data.x
    return isinstance(features, torch.Tensor)
-
-
def data_is_numpy(data):
    """Report whether the node features ``data.x`` are held as a ``numpy.ndarray``.

    Only ``data.x`` is inspected; the other attributes are not checked.
    """
    features = data.x
    return isinstance(features, np.ndarray)
-
-
def data_tensor2np(data):
    """Convert ``data.x``, ``data.y`` and ``data.edge_index`` to NumPy arrays in place.

    The conversion is triggered only when ``data.x`` is a ``torch.Tensor``
    (the same test as ``data_is_tensor``); otherwise ``data`` is returned
    unchanged. ``data`` is presumably a graph data object (e.g. a PyG
    ``Data``) -- confirm against callers.

    Args:
        data: Object exposing ``x``, ``y`` and ``edge_index`` attributes.

    Returns:
        The same ``data`` object, mutated in place.
    """

    def _to_numpy(value):
        # ``Tensor.numpy()`` raises for CUDA tensors and for tensors that
        # require grad, so detach and move to host memory first.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        # Pass through anything else unchanged (e.g. ``y`` is None on
        # unlabeled graphs, or an attribute is already an ndarray).
        return value

    if isinstance(data.x, torch.Tensor):
        data.x = _to_numpy(data.x)
        data.y = _to_numpy(data.y)
        data.edge_index = _to_numpy(data.edge_index)
    return data
-
-
def data_np2tensor(data):
    """Convert ``data.x``, ``data.y`` and ``data.edge_index`` to torch tensors in place.

    The conversion is triggered only when ``data.x`` is a ``numpy.ndarray``
    (the same test as ``data_is_numpy``). An ndarray is never a tensor, so
    this single check subsumes "not already a tensor". ``data`` is presumably
    a graph data object (e.g. a PyG ``Data``) -- confirm against callers.

    Dtypes produced:
        * ``x``          -> ``torch.float32``
        * ``y``          -> dtype inferred by ``torch.tensor``
        * ``edge_index`` -> ``torch.long``

    Args:
        data: Object exposing ``x``, ``y`` and ``edge_index`` attributes.

    Returns:
        The same ``data`` object, mutated in place.
    """

    def _to_tensor(value, dtype=None):
        # Leave missing values (e.g. ``y`` on unlabeled graphs) and values
        # that are already tensors untouched; re-wrapping a tensor with
        # ``torch.tensor`` would copy it and emit a UserWarning.
        if value is None or isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=dtype)

    if isinstance(data.x, np.ndarray):
        # Equivalent to the legacy ``torch.FloatTensor(x)`` constructor:
        # produces a float32 copy.
        data.x = torch.tensor(data.x, dtype=torch.float32)
        data.y = _to_tensor(data.y)
        data.edge_index = _to_tensor(data.edge_index, dtype=torch.long)
    return data
-
-
- # from .base import BaseFeatureAtom
- # class DataTensor2Np(BaseFeatureAtom):
- # def __call__(self,data):
- # return data_tensor2np(data)
- # class DataNp2Tensor(BaseFeatureAtom):
- # def __call__(self,data):
- # return data_np2tensor(data)
|