| @@ -68,7 +68,12 @@ class BaseFeatureAtom: | |||
| elif self._data_t == "nx": | |||
| if not hasattr(data, "G") or data.G is None: | |||
| data.G = to_networkx(data, to_undirected=True) | |||
| def _adjust_to_tensor(self,data): | |||
| if self._data_t == "tensor": | |||
| pass | |||
| else: | |||
| data_np2tensor(data) | |||
| def _preprocess(self, data): | |||
| pass | |||
| @@ -98,22 +103,17 @@ class BaseFeatureAtom: | |||
| if not self._check_dataset(dataset): | |||
| return | |||
| dataset = copy.deepcopy(dataset) | |||
| for p in self._pipe: | |||
| _dataset = [x for x in dataset] | |||
| if p._subgraph: | |||
| with torch.no_grad(): | |||
| for p in self._pipe: | |||
| _dataset = [x for x in dataset] | |||
| for i, datai in enumerate(_dataset): | |||
| p._adjust_t(datai) | |||
| p._preprocess(datai) | |||
| p._fit_transform(datai) | |||
| p._postprocess(datai) | |||
| p._adjust_to_tensor(datai) | |||
| _dataset[i] = datai | |||
| else: | |||
| data = dataset.data | |||
| p._adjust_t(data) | |||
| p._preprocess(data) | |||
| data = p._fit_transform(data) | |||
| p._postprocess(data) | |||
| dataset = self._rebuild(dataset, _dataset) | |||
| dataset = self._rebuild(dataset, _dataset) | |||
| def transform(self, dataset, inplace=True): | |||
| @@ -122,22 +122,17 @@ class BaseFeatureAtom: | |||
| return dataset | |||
| if not inplace: | |||
| dataset = copy.deepcopy(dataset) | |||
| for p in self._pipe: | |||
| self._dataset = _dataset = [x for x in dataset] | |||
| if p._subgraph: | |||
| with torch.no_grad(): | |||
| for p in self._pipe: | |||
| self._dataset = _dataset = [x for x in dataset] | |||
| for i, datai in enumerate(_dataset): | |||
| p._adjust_t(datai) | |||
| p._preprocess(datai) | |||
| datai = p._transform(datai) | |||
| p._postprocess(datai) | |||
| p._adjust_to_tensor(datai) | |||
| _dataset[i] = datai | |||
| else: | |||
| data = dataset.data | |||
| p._adjust_t(data) | |||
| p._preprocess(data) | |||
| data = p._transform(data) | |||
| p._postprocess(data) | |||
| dataset = self._rebuild(dataset, _dataset) | |||
| dataset = self._rebuild(dataset, _dataset) | |||
| dataset.data = data_np2tensor(dataset.data) | |||
| return dataset | |||