diff --git a/autogl/module/feature/auto_feature.py b/autogl/module/feature/auto_feature.py index 4b7fb3b..3d7f8c0 100644 --- a/autogl/module/feature/auto_feature.py +++ b/autogl/module/feature/auto_feature.py @@ -11,7 +11,7 @@ from .selectors import SeGBDT from . import register_feature from ...utils import get_logger - +import torch LOGGER = get_logger("Feature") @@ -28,10 +28,13 @@ class Onlyconst(BaseFeatureEngineer): r"""it is a dummy feature engineer , which directly returns identical data""" def __init__(self, *args, **kwargs): - super(Onlyconst, self).__init__(multigraph=True, *args, **kwargs) + super(Onlyconst, self).__init__(data_t='tensor',multigraph=True, *args, **kwargs) def _transform(self, data): - data.x = np.ones((data.x.shape[0], 1)) + if 'x' in data: + data.x = torch.ones((data.x.shape[0], 1)) + else: + data.x= torch.ones((torch.unique(data.edge_index).shape[0],1)) return data