From 2687014defd9ea1f6a0ff48235e29f862c024eef Mon Sep 17 00:00:00 2001 From: wondergo2017 Date: Thu, 25 Feb 2021 13:40:40 +0000 Subject: [PATCH] fix dtype inconsistency in fe fittransform pipe --- autogl/module/feature/base.py | 37 +++++++++++++++-------------------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/autogl/module/feature/base.py b/autogl/module/feature/base.py index 53beab2..1c7b52b 100644 --- a/autogl/module/feature/base.py +++ b/autogl/module/feature/base.py @@ -68,7 +68,12 @@ class BaseFeatureAtom: elif self._data_t == "nx": if not hasattr(data, "G") or data.G is None: data.G = to_networkx(data, to_undirected=True) - + def _adjust_to_tensor(self,data): + if self._data_t == "tensor": + pass + else: + data_np2tensor(data) + def _preprocess(self, data): pass @@ -98,22 +103,17 @@ class BaseFeatureAtom: if not self._check_dataset(dataset): return dataset = copy.deepcopy(dataset) - for p in self._pipe: - _dataset = [x for x in dataset] - if p._subgraph: + with torch.no_grad(): + for p in self._pipe: + _dataset = [x for x in dataset] for i, datai in enumerate(_dataset): p._adjust_t(datai) p._preprocess(datai) p._fit_transform(datai) p._postprocess(datai) + p._adjust_to_tensor(datai) _dataset[i] = datai - else: - data = dataset.data - p._adjust_t(data) - p._preprocess(data) - data = p._fit_transform(data) - p._postprocess(data) - dataset = self._rebuild(dataset, _dataset) + dataset = self._rebuild(dataset, _dataset) def transform(self, dataset, inplace=True): @@ -122,22 +122,17 @@ class BaseFeatureAtom: return dataset if not inplace: dataset = copy.deepcopy(dataset) - for p in self._pipe: - self._dataset = _dataset = [x for x in dataset] - if p._subgraph: + with torch.no_grad(): + for p in self._pipe: + self._dataset = _dataset = [x for x in dataset] for i, datai in enumerate(_dataset): p._adjust_t(datai) p._preprocess(datai) datai = p._transform(datai) p._postprocess(datai) + p._adjust_to_tensor(datai) _dataset[i] = datai - else: - data = dataset.data - p._adjust_t(data) - p._preprocess(data) - data = p._transform(data) - p._postprocess(data) - dataset = self._rebuild(dataset, _dataset) + dataset = self._rebuild(dataset, _dataset) dataset.data = data_np2tensor(dataset.data) return dataset