You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 884 B

123456789101112131415161718192021222324252627282930313233343536
  1. import numpy as np
  2. import torch
  3. def data_is_tensor(data):
  4. return isinstance(data.x, torch.Tensor)
  5. def data_is_numpy(data):
  6. return isinstance(data.x, np.ndarray)
  7. def data_tensor2np(data):
  8. if data_is_tensor(data):
  9. data.x = data.x.numpy()
  10. data.y = data.y.numpy()
  11. data.edge_index = data.edge_index.numpy()
  12. return data
  13. def data_np2tensor(data):
  14. if not data_is_tensor(data):
  15. if data_is_numpy(data):
  16. data.x = torch.FloatTensor(data.x)
  17. data.y = torch.tensor(data.y)
  18. data.edge_index = torch.tensor(data.edge_index, dtype=torch.long)
  19. return data
  20. # from .base import BaseFeatureAtom
  21. # class DataTensor2Np(BaseFeatureAtom):
  22. # def __call__(self,data):
  23. # return data_tensor2np(data)
  24. # class DataNp2Tensor(BaseFeatureAtom):
  25. # def __call__(self,data):
  26. # return data_np2tensor(data)