import copy
import logging
import torch
import typing as _typing
from autogl.data import Dataset

LOGGER = logging.getLogger("FeatureEngineer")


class _BaseFeatureEngineer:
    """Abstract interface shared by every feature engineer.

    Subclasses provide ``fit``, ``transform`` and ``__and__`` (used to chain
    two engineers with the ``&`` operator).
    """

    def __and__(self, other):
        raise NotImplementedError

    def fit_transform(self, dataset: Dataset, inplace=True) -> Dataset:
        """Fit on ``dataset`` and then transform it.

        Parameters
        ----------
        dataset:
            The dataset to fit on and transform.
        inplace:
            When true, ``dataset`` itself is passed to ``fit`` and
            ``transform``. When false, a deep copy is made first and the
            caller's ``dataset`` is left untouched.

        Returns
        -------
        The dataset returned by ``transform``.
        """
        if not inplace:
            # ``fit`` implementations may write into the dataset they receive,
            # so the copy has to be taken *before* fitting; copying only in
            # ``transform`` would still leak fit-time changes to the caller.
            dataset = copy.deepcopy(dataset)
        dataset = self.fit(dataset)
        # ``dataset`` is now either the caller's object (inplace requested)
        # or our private copy, so no second copy is needed.
        return self.transform(dataset, inplace=True)

    def fit(self, dataset: Dataset) -> Dataset:
        raise NotImplementedError

    def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset:
        raise NotImplementedError


class _ComposedFeatureEngineer(_BaseFeatureEngineer):
    """A pipeline of feature engineers applied one after another.

    Nested ``_ComposedFeatureEngineer`` instances are flattened at
    construction time, so ``fe_components`` never contains another pipeline.
    """

    def __init__(self, feature_engineers: _typing.Iterable[_BaseFeatureEngineer]):
        self.__fe_components: _typing.List[_BaseFeatureEngineer] = []
        for fe in feature_engineers:
            if isinstance(fe, _ComposedFeatureEngineer):
                # Splice the inner pipeline's steps in, keeping their order.
                self.__fe_components.extend(fe.fe_components)
            else:
                self.__fe_components.append(fe)

    @property
    def fe_components(self) -> _typing.Iterable[_BaseFeatureEngineer]:
        """The flattened sequence of component engineers, in order."""
        return self.__fe_components

    def __and__(self, other: _BaseFeatureEngineer):
        """Return a new pipeline consisting of this one followed by ``other``."""
        return _ComposedFeatureEngineer((self, other))

    def fit(self, dataset) -> Dataset:
        """Call ``fit`` on every component in order, threading the dataset through."""
        for fe in self.fe_components:
            dataset = fe.fit(dataset)
        return dataset

    def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset:
        """Call ``transform`` on every component in order.

        When ``inplace`` is false the dataset is deep-copied exactly once and
        every component then works in place on that copy, instead of each
        component making its own full copy.
        """
        if not inplace:
            dataset = copy.deepcopy(dataset)
        for fe in self.fe_components:
            dataset = fe.transform(dataset, inplace=True)
        return dataset


class BaseFeature(_BaseFeatureEngineer):
    """Base class for per-sample feature engineers.

    Each element of a dataset is run through the hook chain
    ``_preprocess -> [_fit ->] _transform -> _postprocess``. All four hooks
    are identity functions here; subclasses override the ones they need.
    """

    def __init__(self, multi_graph: bool = True, subgraph=False):
        # ``subgraph`` is accepted for interface compatibility but is not
        # used by this class.
        self._multi_graph: bool = multi_graph

    def __and__(self, other):
        """Return a pipeline consisting of this engineer followed by ``other``."""
        return _ComposedFeatureEngineer((self, other))

    def _preprocess(self, data: _typing.Any) -> _typing.Any:
        return data

    def _fit(self, data: _typing.Any) -> _typing.Any:
        return data

    def _transform(self, data: _typing.Any) -> _typing.Any:
        return data

    def _postprocess(self, data: _typing.Any) -> _typing.Any:
        return data

    def fit(self, dataset: Dataset) -> Dataset:
        """Fit on every element of ``dataset``.

        Note that this also runs ``_transform`` and ``_postprocess`` and
        writes the result back into ``dataset[i]``, so the dataset passed in
        is modified. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            for i, data in enumerate(dataset):
                dataset[i] = self._postprocess(
                    self._transform(self._fit(self._preprocess(data)))
                )
        return dataset

    def transform(self, dataset: Dataset, inplace: bool = True) -> Dataset:
        """Transform every element of ``dataset``.

        When ``inplace`` is false a deep copy of ``dataset`` is transformed
        and returned; otherwise ``dataset`` itself is overwritten element by
        element. Runs under ``torch.no_grad()``.
        """
        if not inplace:
            dataset = copy.deepcopy(dataset)
        with torch.no_grad():
            for i, data in enumerate(dataset):
                dataset[i] = self._postprocess(
                    self._transform(self._preprocess(data))
                )
        return dataset


class BaseFeatureEngineer(BaseFeature):
    """Subclass of ``BaseFeature`` that adds no behaviour of its own."""