|
|
|
@@ -11,7 +11,7 @@ from .selectors import SeGBDT |
|
|
|
from . import register_feature |
|
|
|
|
|
|
|
from ...utils import get_logger |
|
|
|
|
|
|
|
import torch |
|
|
|
LOGGER = get_logger("Feature") |
|
|
|
|
|
|
|
|
|
|
|
@@ -28,10 +28,13 @@ class Onlyconst(BaseFeatureEngineer): |
|
|
|
r"""it is a dummy feature engineer , which directly returns identical data""" |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super(Onlyconst, self).__init__(multigraph=True, *args, **kwargs) |
|
|
|
super(Onlyconst, self).__init__(data_t='tensor',multigraph=True, *args, **kwargs) |
|
|
|
|
|
|
|
def _transform(self, data): |
|
|
|
data.x = np.ones((data.x.shape[0], 1)) |
|
|
|
if 'x' in data: |
|
|
|
data.x = torch.ones((data.x.shape[0], 1)) |
|
|
|
else: |
|
|
|
data.x= torch.ones((torch.unique(data.edge_index).shape[0],1)) |
|
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|