|
- import netlsd
- from .base import BaseGraph
- import numpy as np
- import torch
- from .. import register_feature
-
-
@register_feature("netlsd")
class SgNetLSD(BaseGraph):
    r"""
    Graph-level feature generator based on the NetLSD heat descriptor.

    Notes
    -----
    A graph feature generation method. This is a simple wrapper of NetLSD [#]_:
    ``netlsd.heat`` is evaluated on ``data.G`` and the resulting descriptor is
    appended to ``data.gf`` along dimension 1.

    Parameters
    ----------
    *args, **kwargs
        Stored and forwarded unchanged to ``netlsd.heat`` after the graph
        argument on every call to ``_transform``.

    References
    ----------
    .. [#] A. Tsitsulin, D. Mottin, P. Karras, A. Bronstein, and E. Müller, “NetLSD: Hearing the shape of a graph,”
           Proc. ACM SIGKDD Int. Conf. Knowl. Discov. Data Min., pp. 2347–2356, 2018.
    """

    def __init__(self, *args, **kwargs):
        # NOTE(review): data_t="nx" presumably asks the base class to expose a
        # networkx graph as ``data.G`` -- confirm against BaseGraph.
        super().__init__(data_t="nx")
        # Kept verbatim so they can be replayed into netlsd.heat later.
        self._args = args
        self._kwargs = kwargs

    def _transform(self, data):
        """Append the NetLSD heat descriptor of ``data.G`` to ``data.gf``.

        ``data`` is modified in place and also returned.
        """
        descriptor = netlsd.heat(data.G, *self._args, **self._kwargs)
        # Convert through a single float32 ndarray with a new leading axis
        # instead of wrapping the array in a Python list: building a tensor
        # from a list of ndarrays is a known slow path in PyTorch. The result
        # has the same shape and dtype as torch.FloatTensor([descriptor]).
        row = np.asarray(descriptor, dtype=np.float32)[np.newaxis, ...]
        dsc = torch.from_numpy(row)
        # assumes data.gf is an existing 2-D tensor with one row -- TODO confirm
        data.gf = torch.cat([data.gf, dsc], dim=1)
        return data
|