import numpy as np

from .. import register_feature
from ..base import BaseFeature


class BaseGenerator(BaseFeature):
    """Common base class for the feature generators defined in this module.

    It only fixes the default constructor arguments that are forwarded to
    ``BaseFeature``; subclasses supply the actual transformation.
    """

    def __init__(self, data_t="np", multigraph=True, **kwargs):
        """Forward the configuration to ``BaseFeature.__init__``.

        Args:
            data_t: Passed through unchanged as ``data_t`` (default ``"np"``).
            multigraph: Passed through unchanged as ``multigraph``
                (default ``True``).
            **kwargs: Any further keyword arguments, passed through unchanged.
        """
        super(BaseGenerator, self).__init__(
            data_t=data_t,
            multigraph=multigraph,
            **kwargs
        )


@register_feature("onehot")
class GeOnehot(BaseGenerator):
    """Generator registered as ``"onehot"`` that appends an identity matrix."""

    def _transform(self, data):
        """Append one identity-matrix row to every row of ``data.x``.

        The identity matrix has ``data.x.shape[0]`` rows and columns and is
        joined to ``data.x`` along axis 1. ``data`` is modified in place and
        the same object is returned.

        NOTE(review): presumably ``data.x`` is a 2-D (rows, features) array
        with one row per node -- confirm against the caller.
        """
        row_count = data.x.shape[0]
        # Dense square matrix: memory grows quadratically with the row count.
        identity_block = np.eye(row_count)
        extended = np.concatenate([data.x, identity_block], axis=1)
        data.x = extended
        return data