|
- from operator import xor
- from .utils import data_is_tensor, data_tensor2np, data_np2tensor
- import numpy as np
- import copy
- from torch_geometric.utils.convert import to_networkx
- from torch.utils.data import Dataset
- import torch
- from ...utils import get_logger
-
# Module-level logger for the feature pipeline, registered under the name "Feature".
LOGGER = get_logger("Feature")
-
-
class BaseFeature:
    r"""Base class of every feature transform and of feature pipelines.

    Any feature function object should inherit ``BaseFeature``. It provides
    basic transformations and a composing operation for feature engineering.
    The basic transformations are adjusting the data type (tensor or numpy)
    and adding attributes needed by later transforms. A subclass overrides
    ``_transform`` to implement its feature transformation. It also overrides
    ``_fit`` when it learns state from the data. For specific needs, you may
    override ``_preprocess`` and ``_postprocess`` as well, to run extra
    processing before and after ``_transform``.

    Parameters
    ----------
    pipe : list, optional
        stages (``BaseFeature`` objects) run in order by ``fit`` and
        ``transform``; defaults to ``[self]``.
    data_t : str
        data type this transform needs: 'tensor' for ``torch.Tensor``, 'np'
        for ``numpy.ndarray`` and 'nx' for ``networkx``. When ``data_t`` is
        'nx', an undirected ``networkx.Graph`` is attached to the data as
        ``data.G``. It is built with ``to_networkx(data, to_undirected=True)``,
        unless ``data.G`` is already set.
    multigraph : bool
        whether this transform supports a dataset with multiple graphs.
    subgraph : bool
        whether this transform extracts subgraph features.
    """

    def __init__(self, pipe=None, data_t="tensor", multigraph=True, subgraph=False):
        r""""""
        if pipe is None:
            pipe = [self]
        self._pipe = pipe
        self._data_t = data_t
        self._multigraph = multigraph
        self._subgraph = subgraph

    def __and__(self, o):
        r"""Chain two features into a new pipeline, enabling syntax like
        ``SeFilterConstant() & GeEigen() & ...``.

        The returned ``BaseFeature`` runs the stages of ``self`` followed by
        the stages of ``o``. Both operands' pipelines are left unmodified.
        """
        return BaseFeature(self._pipe + o._pipe)

    def _rebuild(self, dataset, datalist):
        r"""Collate ``datalist`` back into ``dataset`` in place and return it.

        ``dataset.data`` is updated attribute-wise and ``dataset.slices`` is
        merged with the freshly collated slices.
        """
        # NOTE(review): presumably clears an index subset left by slicing so
        # the dataset covers the whole rebuilt list — confirm against the
        # torch_geometric version in use.
        dataset.__indices__ = None
        data, slices = dataset.collate(datalist)
        dataset.data.__dict__.update(data.__dict__)
        dataset.slices.update(slices)
        return dataset

    def _adjust_t(self, data):
        r"""Adjust ``data`` to the data type required by this transform.

        The converters are called for their side effect; their return value
        is ignored here.
        """
        if self._data_t == "tensor":
            data_np2tensor(data)
        elif self._data_t == "np":
            data_tensor2np(data)
        elif self._data_t == "nx":
            # Build the networkx view only once and cache it on the data.
            if not hasattr(data, "G") or data.G is None:
                data.G = to_networkx(data, to_undirected=True)

    def _adjust_to_tensor(self, data):
        r"""Convert ``data`` back to tensors after a non-tensor transform."""
        if self._data_t == "tensor":
            pass
        else:
            data_np2tensor(data)

    def _preprocess(self, data):
        r"""Hook run before ``_transform``; a no-op by default."""
        pass

    def _postprocess(self, data):
        r"""Hook run after ``_transform``; a no-op by default."""
        pass

    def _check_dataset(self, dataset):
        r"""Return whether every stage of the pipeline can handle ``dataset``.

        A dataset with more than one element is rejected, with a warning, at
        the first stage declared with ``multigraph=False``.
        """
        if len(dataset) > 1:
            for p in self._pipe:
                if not p._multigraph:
                    # Lazy %-formatting: the class name fills the placeholder.
                    LOGGER.warning(
                        "%s does not support multigraph", p.__class__.__name__
                    )
                    return False
        return True

    def _fit(self, data):
        r"""Learn state from one data item; a no-op by default."""
        pass

    def _transform(self, data):
        r"""Transform one data item and return it; the identity by default."""
        return data

    def _fit_transform(self, data):
        r"""Fit on ``data`` and return the result of ``_transform(data)``."""
        self._fit(data)
        return self._transform(data)

    def fit(self, dataset):
        r"""Fit every stage of the pipeline, in order, on a deep copy of ``dataset``.

        Each stage is fitted on the output of the previous stages. The caller's
        ``dataset`` object is not modified. Nothing is fitted when
        ``_check_dataset`` rejects the dataset.
        """
        if not self._check_dataset(dataset):
            return
        dataset = copy.deepcopy(dataset)
        with torch.no_grad():
            for p in self._pipe:
                _dataset = list(dataset)
                for i, datai in enumerate(_dataset):
                    p._adjust_t(datai)
                    p._preprocess(datai)
                    # Keep the stage's output, as ``transform`` does, so later
                    # stages are fitted on transformed data. Fall back to the
                    # input for stages that work in place and return None.
                    transformed = p._fit_transform(datai)
                    if transformed is not None:
                        datai = transformed
                    p._postprocess(datai)
                    p._adjust_to_tensor(datai)
                    _dataset[i] = datai
                dataset = self._rebuild(dataset, _dataset)

    def transform(self, dataset, inplace=True):
        r"""Run every stage of the pipeline on ``dataset`` and return the result.

        Parameters
        ----------
        dataset
            dataset to transform. It must support ``len``, iteration, and the
            ``collate``/``data``/``slices`` interface used by ``_rebuild``.
        inplace : bool
            when ``False``, a deep copy is transformed and ``dataset`` is left
            untouched.

        Returns
        -------
        The transformed dataset. If ``_check_dataset`` rejects the dataset,
        ``dataset`` itself is returned unchanged.
        """
        if not self._check_dataset(dataset):
            return dataset
        if not inplace:
            dataset = copy.deepcopy(dataset)
        with torch.no_grad():
            for p in self._pipe:
                # The per-stage item list is also kept on ``self._dataset``.
                self._dataset = _dataset = list(dataset)
                for i, datai in enumerate(_dataset):
                    p._adjust_t(datai)
                    p._preprocess(datai)
                    datai = p._transform(datai)
                    p._postprocess(datai)
                    p._adjust_to_tensor(datai)
                    _dataset[i] = datai
                dataset = self._rebuild(dataset, _dataset)
        dataset.data = data_np2tensor(dataset.data)
        return dataset

    def fit_transform(self, dataset, inplace=True):
        r"""Fit on ``dataset``, then transform it (in place unless ``inplace=False``)."""
        self.fit(dataset)
        return self.transform(dataset, inplace=inplace)

    @staticmethod
    def compose(trans_list):
        r"""Chain a list of ``BaseFeature`` objects into one pipeline.

        The result starts with an identity ``BaseFeature()`` stage, followed
        by the stages of each element of ``trans_list`` in order.
        """
        res = BaseFeature()
        for tran in trans_list:
            res = res & tran
        return res
-
-
class BaseFeatureEngineer(BaseFeature):
    r"""Base class for feature engineers.

    Compared to ``BaseFeature``, the defaults switch to numpy data
    (``data_t="np"``) and single-graph datasets (``multigraph=False``).
    Any extra positional and keyword arguments are forwarded to
    ``BaseFeature.__init__``. They are also kept on the instance as
    ``self.args`` and ``self.kwargs``.
    """

    def __init__(self, data_t="np", multigraph=False, *args, **kwargs):
        # NOTE(review): extra positional arguments bind to BaseFeature's
        # leading parameters (``pipe`` first). Keywords BaseFeature does not
        # declare raise TypeError there. Confirm this forwarding is intended.
        super().__init__(*args, data_t=data_t, multigraph=multigraph, **kwargs)
        self.args, self.kwargs = args, kwargs
-
-
class TransformWrapper(BaseFeature):
    r"""Wrap a callable transform class ``cls`` as a tensor-based ``BaseFeature`` stage.

    ``cls(*args, **kwargs)`` is instantiated lazily: either by calling the
    wrapper, or on the first ``_transform``. The resulting object is then
    applied to each data item.
    """

    def __init__(self, cls, *args, **kwargs):
        # ``args``/``kwargs`` belong to ``cls`` only. Forwarding them to
        # BaseFeature.__init__ made a positional argument become ``pipe``,
        # and made any keyword BaseFeature does not declare raise TypeError.
        super(TransformWrapper, self).__init__(data_t="tensor")
        self._cls = cls
        self._func = None
        self._args = args
        self._kwargs = kwargs

    def __call__(self):
        r"""Instantiate the wrapped transform once and return ``self``."""
        if self._func is None:
            self._func = self._cls(*self._args, **self._kwargs)
        return self

    def _transform(self, data=None):
        r"""Apply the wrapped transform to ``data`` and return its result.

        The transform is instantiated on first use if the wrapper was never
        called, instead of failing on ``None``.
        """
        if self._func is None:
            self()
        return self._func(data)
|