|
|
|
@@ -16,16 +16,12 @@ class BaseFeatureSelector(BaseFeatureEngineer): |
|
|
|
) -> GeneralStaticGraph: |
|
|
|
if ( |
|
|
|
'x' in static_graph.nodes.data and |
|
|
|
self._selection not in (Ellipsis, None) and |
|
|
|
isinstance(self._selection, torch.Tensor) and |
|
|
|
torch.is_tensor(self._selection) and self._selection.dim() == 1 |
|
|
|
isinstance(self._selection, (torch.Tensor, np.ndarray)) |
|
|
|
): |
|
|
|
static_graph.nodes.data['x'] = static_graph.nodes.data['x'][:, self._selection] |
|
|
|
if ( |
|
|
|
'feat' in static_graph.nodes.data and |
|
|
|
self._selection not in (Ellipsis, None) and |
|
|
|
isinstance(self._selection, torch.Tensor) and |
|
|
|
torch.is_tensor(self._selection) and self._selection.dim() == 1 |
|
|
|
isinstance(self._selection, (torch.Tensor, np.ndarray)) |
|
|
|
): |
|
|
|
static_graph.nodes.data['feat'] = static_graph.nodes.data['feat'][:, self._selection] |
|
|
|
return static_graph |
|
|
|
|