Browse Source

fix dtype inconsistency in fe fittransform pipe

tags/v0.3.1
wondergo2017 5 years ago
parent
commit
2687014def
1 changed files with 16 additions and 21 deletions
  1. +16
    -21
      autogl/module/feature/base.py

+ 16
- 21
autogl/module/feature/base.py View File

@@ -68,7 +68,12 @@ class BaseFeatureAtom:
elif self._data_t == "nx":
if not hasattr(data, "G") or data.G is None:
data.G = to_networkx(data, to_undirected=True)

def _adjust_to_tensor(self,data):
if self._data_t == "tensor":
pass
else:
data_np2tensor(data)
def _preprocess(self, data):
pass

@@ -98,22 +103,17 @@ class BaseFeatureAtom:
if not self._check_dataset(dataset):
return
dataset = copy.deepcopy(dataset)
for p in self._pipe:
_dataset = [x for x in dataset]
if p._subgraph:
with torch.no_grad():
for p in self._pipe:
_dataset = [x for x in dataset]
for i, datai in enumerate(_dataset):
p._adjust_t(datai)
p._preprocess(datai)
p._fit_transform(datai)
p._postprocess(datai)
p._adjust_to_tensor(datai)
_dataset[i] = datai
else:
data = dataset.data
p._adjust_t(data)
p._preprocess(data)
data = p._fit_transform(data)
p._postprocess(data)
dataset = self._rebuild(dataset, _dataset)
dataset = self._rebuild(dataset, _dataset)

def transform(self, dataset, inplace=True):
@@ -122,22 +122,17 @@ class BaseFeatureAtom:
return dataset
if not inplace:
dataset = copy.deepcopy(dataset)
for p in self._pipe:
self._dataset = _dataset = [x for x in dataset]
if p._subgraph:
with torch.no_grad():
for p in self._pipe:
self._dataset = _dataset = [x for x in dataset]
for i, datai in enumerate(_dataset):
p._adjust_t(datai)
p._preprocess(datai)
datai = p._transform(datai)
p._postprocess(datai)
p._adjust_to_tensor(datai)
_dataset[i] = datai
else:
data = dataset.data
p._adjust_t(data)
p._preprocess(data)
data = p._transform(data)
p._postprocess(data)
dataset = self._rebuild(dataset, _dataset)
dataset = self._rebuild(dataset, _dataset)
dataset.data = data_np2tensor(dataset.data)
return dataset



Loading…
Cancel
Save