You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
12345678910111213141516171819
  1. from ..base import BaseFeature
  2. import numpy as np
  3. import torch
  4. from .. import register_feature
  5. @register_feature("graph")
  6. class BaseGraph(BaseFeature):
  7. def __init__(self, data_t="np", multigraph=True, **kwargs):
  8. super(BaseGraph, self).__init__(
  9. data_t=data_t, multigraph=multigraph, subgraph=True, **kwargs
  10. )
  11. def _preprocess(self, data):
  12. if not hasattr(data, "gf") or data["gf"] is None:
  13. data.gf = torch.FloatTensor([[]])
  14. def _postprocess(self, data):
  15. pass