- import typing as _typing
- import torch
- import torch.nn.functional
- from torch_geometric.nn.conv import MessagePassing
- from torch_sparse import SparseTensor, matmul
-
- from . import register_model
- from .base import ClassificationModel, ClassificationSupportedSequentialModel
-
-
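- # Building blocks for the GraphSAINT aggregation model, grouped in a
- # namespace class so that they do not leak into the module namespace.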
- class _GraphSAINTAggregationLayers:
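- # Multi-order aggregation layer: combines an order-0 (self) transform with an
- # optional order-1 (neighbor) transform and concatenates their outputs along
- # the feature dimension, followed by an optional dropout.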
- class MultiOrderAggregationLayer(torch.nn.Module):
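- # Order-0 aggregator: a per-node linear transform with optional activation and
- # batch normalization; the graph structure arguments are accepted but unused.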
- class Order0Aggregator(torch.nn.Module):
- def __init__(
- self,
- input_dimension: int,
- output_dimension: int,
- bias: bool = True,
- activation: _typing.Optional[str] = "ReLU",
- batch_norm: bool = True,
- ):
- super().__init__()
- if not type(input_dimension) == type(output_dimension) == int:
- raise TypeError
- if not (input_dimension > 0 and output_dimension > 0):
- raise ValueError
- if not type(bias) == bool:
- raise TypeError
- self.__linear_transform = torch.nn.Linear(
- input_dimension, output_dimension, bias
- )
- self.__linear_transform.reset_parameters()
- if type(activation) == str:
- if activation.lower() == "relu":
- self.__activation = torch.nn.functional.relu
- elif activation.lower() == "elu":
- self.__activation = torch.nn.functional.elu
- elif hasattr(torch.nn.functional, activation) and callable(
- getattr(torch.nn.functional, activation)
- ):
- self.__activation = getattr(torch.nn.functional, activation)
- else:
- self.__activation = lambda x: x
- else:
- self.__activation = lambda x: x
- if type(batch_norm) != bool:
- raise TypeError
- else:
- self.__optional_batch_normalization: _typing.Optional[
- torch.nn.BatchNorm1d
- ] = (
- torch.nn.BatchNorm1d(output_dimension, eps=1e-8)
- if batch_norm
- else None
- )
-
- def forward(
- self,
- x: _typing.Union[
- torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
- ],
- _edge_index: torch.Tensor,
- _edge_weight: _typing.Optional[torch.Tensor] = None,
- _size: _typing.Optional[_typing.Tuple[int, int]] = None,
- ) -> torch.Tensor:
- # A (source, target) feature pair reduces to the target-node features here.
- x = x[1] if isinstance(x, tuple) else x
- __output: torch.Tensor = self.__linear_transform(x)
- if self.__activation is not None and callable(self.__activation):
- __output: torch.Tensor = self.__activation(__output)
- if self.__optional_batch_normalization is not None and isinstance(
- self.__optional_batch_normalization, torch.nn.BatchNorm1d
- ):
- __output: torch.Tensor = self.__optional_batch_normalization(
- __output
- )
- return __output
-
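- # Order-1 aggregator: sum-aggregates features from 1-hop neighbors via message
- # passing, then applies the linear transform, optional activation, and batch
- # normalization.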
- class Order1Aggregator(MessagePassing):
- def __init__(
- self,
- input_dimension: int,
- output_dimension: int,
- bias: bool = True,
- activation: _typing.Optional[str] = "ReLU",
- batch_norm: bool = True,
- ):
- super().__init__(aggr="add")
- if not type(input_dimension) == type(output_dimension) == int:
- raise TypeError
- if not (input_dimension > 0 and output_dimension > 0):
- raise ValueError
- if not type(bias) == bool:
- raise TypeError
- self.__linear_transform = torch.nn.Linear(
- input_dimension, output_dimension, bias
- )
- self.__linear_transform.reset_parameters()
- if type(activation) == str:
- if activation.lower() == "relu":
- self.__activation = torch.nn.functional.relu
- elif activation.lower() == "elu":
- self.__activation = torch.nn.functional.elu
- elif hasattr(torch.nn.functional, activation) and callable(
- getattr(torch.nn.functional, activation)
- ):
- self.__activation = getattr(torch.nn.functional, activation)
- else:
- self.__activation = lambda x: x
- else:
- self.__activation = lambda x: x
- if type(batch_norm) != bool:
- raise TypeError
- else:
- self.__optional_batch_normalization: _typing.Optional[
- torch.nn.BatchNorm1d
- ] = (
- torch.nn.BatchNorm1d(output_dimension, eps=1e-8)
- if batch_norm
- else None
- )
-
- def forward(
- self,
- x: _typing.Union[
- torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
- ],
- _edge_index: torch.Tensor,
- _edge_weight: _typing.Optional[torch.Tensor] = None,
- _size: _typing.Optional[_typing.Tuple[int, int]] = None,
- ) -> torch.Tensor:
-
- if type(x) == torch.Tensor:
- x: _typing.Tuple[torch.Tensor, torch.Tensor] = (x, x)
-
- __output = self.propagate(
- _edge_index, x=x, edge_weight=_edge_weight, size=_size
- )
- __output: torch.Tensor = self.__linear_transform(__output)
- if self.__activation is not None and callable(self.__activation):
- __output: torch.Tensor = self.__activation(__output)
- if self.__optional_batch_normalization is not None and isinstance(
- self.__optional_batch_normalization, torch.nn.BatchNorm1d
- ):
- __output: torch.Tensor = self.__optional_batch_normalization(
- __output
- )
- return __output
-
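- # Scale each neighbor feature by its edge weight when weights are provided.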
- def message(
- self, x_j: torch.Tensor, edge_weight: _typing.Optional[torch.Tensor]
- ) -> torch.Tensor:
- return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j
-
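- # Fused sparse path: aggregate source-node features with a sparse-dense matmul.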
- def message_and_aggregate(
- self,
- adj_t: SparseTensor,
- x: _typing.Union[
- torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]
- ],
- ) -> torch.Tensor:
- return matmul(adj_t, x[0], reduce=self.aggr)
-
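- # Total output width of the layer: (order + 1) blocks of the per-order output
- # dimension, since the per-order results are concatenated.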
- @property
- def integral_output_dimension(self) -> int:
- return (self._order + 1) * self._each_order_output_dimension
-
- def __init__(
- self,
- _input_dimension: int,
- _each_order_output_dimension: int,
- _order: int,
- bias: bool = True,
- activation: _typing.Optional[str] = "ReLU",
- batch_norm: bool = True,
- _dropout: _typing.Optional[float] = ...,
- ):
- super().__init__()
- if not (
- type(_input_dimension) == type(_order) == int
- and type(_each_order_output_dimension) == int
- ):
- raise TypeError
- if _input_dimension <= 0 or _each_order_output_dimension <= 0:
- raise ValueError
- if _order not in (0, 1):
- raise ValueError("Unsupported order number")
- self._input_dimension: int = _input_dimension
- self._each_order_output_dimension: int = _each_order_output_dimension
- self._order: int = _order
- if type(bias) != bool:
- raise TypeError
- self.__order0_transform = self.Order0Aggregator(
- self._input_dimension,
- self._each_order_output_dimension,
- bias,
- activation,
- batch_norm,
- )
- if _order == 1:
- self.__order1_transform = self.Order1Aggregator(
- self._input_dimension,
- self._each_order_output_dimension,
- bias,
- activation,
- batch_norm,
- )
- else:
- self.__order1_transform = None
- if _dropout is not None and type(_dropout) == float:
- if _dropout < 0:
- _dropout = 0
- if _dropout > 1:
- _dropout = 1
- self.__optional_dropout: _typing.Optional[
- torch.nn.Dropout
- ] = torch.nn.Dropout(_dropout)
- else:
- self.__optional_dropout: _typing.Optional[torch.nn.Dropout] = None
-
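- # Apply the order-0 transform (and the order-1 transform when the layer order
- # is 1), concatenate the results on the feature dimension, then apply the
- # optional dropout.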
- def _forward(
- self,
- x: _typing.Union[torch.Tensor, _typing.Tuple[torch.Tensor, torch.Tensor]],
- edge_index: torch.Tensor,
- edge_weight: _typing.Optional[torch.Tensor] = None,
- size: _typing.Optional[_typing.Tuple[int, int]] = None,
- ) -> torch.Tensor:
- if self.__order1_transform is not None and isinstance(
- self.__order1_transform, self.Order1Aggregator
- ):
- __output: torch.Tensor = torch.cat(
- [
- self.__order0_transform(x, edge_index, edge_weight, size),
- self.__order1_transform(x, edge_index, edge_weight, size),
- ],
- dim=1,
- )
- else:
- __output: torch.Tensor = self.__order0_transform(
- x, edge_index, edge_weight, size
- )
- if self.__optional_dropout is not None and isinstance(
- self.__optional_dropout, torch.nn.Dropout
- ):
- __output: torch.Tensor = self.__optional_dropout(__output)
- return __output
-
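- # Entry point for data objects: unpack x, edge_index, and the optional
- # edge_weight before delegating to _forward.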
- def forward(self, data) -> torch.Tensor:
- x: torch.Tensor = getattr(data, "x")
- if type(x) != torch.Tensor:
- raise TypeError
- edge_index: torch.LongTensor = getattr(data, "edge_index")
- if type(edge_index) != torch.Tensor:
- raise TypeError
- edge_weight: _typing.Optional[torch.Tensor] = getattr(
- data, "edge_weight", None
- )
- if edge_weight is not None and type(edge_weight) != torch.Tensor:
- raise TypeError
- return self._forward(x, edge_index, edge_weight)
-
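- # Adapter that lets torch.nn.Dropout sit in the sequential encoding stack by
- # accepting either a plain tensor or a data object carrying an `x` attribute.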
- class WrappedDropout(torch.nn.Module):
- def __init__(self, dropout_module: torch.nn.Dropout):
- super().__init__()
- self.__dropout_module: torch.nn.Dropout = dropout_module
-
- def forward(self, tensor_or_data) -> torch.Tensor:
- if type(tensor_or_data) == torch.Tensor:
- return self.__dropout_module(tensor_or_data)
- elif (
- hasattr(tensor_or_data, "x")
- and type(getattr(tensor_or_data, "x")) == torch.Tensor
- ):
- return self.__dropout_module(getattr(tensor_or_data, "x"))
- else:
- raise TypeError
-
-
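- # End-to-end node classifier: a pre-dropout wrapper followed by a stack of
- # multi-order aggregation layers; decoding optionally L2-normalizes the
- # embeddings and applies a log-softmax linear classifier.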
- class GraphSAINTMultiOrderAggregationModel(ClassificationSupportedSequentialModel):
- def __init__(
- self,
- num_features: int,
- num_classes: int,
- _output_dimension_for_each_order: int,
- _layers_order_list: _typing.Sequence[int],
- _pre_dropout: float,
- _layers_dropout: _typing.Union[float, _typing.Sequence[float]],
- activation: _typing.Optional[str] = "ReLU",
- bias: bool = True,
- batch_norm: bool = True,
- normalize: bool = True,
- ):
- super(GraphSAINTMultiOrderAggregationModel, self).__init__()
- if type(_output_dimension_for_each_order) != int:
- raise TypeError
- if not _output_dimension_for_each_order > 0:
- raise ValueError
- self._layers_order_list: _typing.Sequence[int] = _layers_order_list
-
- if isinstance(_layers_dropout, _typing.Sequence):
- if len(_layers_dropout) != len(_layers_order_list):
- raise ValueError
- else:
- self._layers_dropout: _typing.Sequence[float] = _layers_dropout
- elif type(_layers_dropout) == float:
- if _layers_dropout < 0:
- _layers_dropout = 0
- if _layers_dropout > 1:
- _layers_dropout = 1
- self._layers_dropout: _typing.Sequence[float] = [
- _layers_dropout for _ in _layers_order_list
- ]
- else:
- raise TypeError
- if type(_pre_dropout) != float:
- raise TypeError
- else:
- if _pre_dropout < 0:
- _pre_dropout = 0
- if _pre_dropout > 1:
- _pre_dropout = 1
- self.__sequential_encoding_layers: torch.nn.ModuleList = torch.nn.ModuleList(
- (
- _GraphSAINTAggregationLayers.WrappedDropout(
- torch.nn.Dropout(_pre_dropout)
- ),
- _GraphSAINTAggregationLayers.MultiOrderAggregationLayer(
- num_features,
- _output_dimension_for_each_order,
- _layers_order_list[0],
- bias,
- activation,
- batch_norm,
- self._layers_dropout[0],
- ),
- )
- )
- for _layer_index in range(1, len(_layers_order_list)):
- self.__sequential_encoding_layers.append(
- _GraphSAINTAggregationLayers.MultiOrderAggregationLayer(
- self.__sequential_encoding_layers[-1].integral_output_dimension,
- _output_dimension_for_each_order,
- _layers_order_list[_layer_index],
- bias,
- activation,
- batch_norm,
- self._layers_dropout[_layer_index],
- )
- )
- self.__apply_normalize: bool = normalize
- self.__linear_transform: torch.nn.Linear = torch.nn.Linear(
- self.__sequential_encoding_layers[-1].integral_output_dimension,
- num_classes,
- bias,
- )
- self.__linear_transform.reset_parameters()
-
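- # Map node embeddings to class log-probabilities, optionally L2-normalizing
- # the embeddings first.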
- def cls_decode(self, x: torch.Tensor) -> torch.Tensor:
- if self.__apply_normalize:
- x: torch.Tensor = torch.nn.functional.normalize(x, p=2, dim=1)
- return torch.nn.functional.log_softmax(self.__linear_transform(x), dim=1)
-
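- # Encode by threading the data object through the sequential layers, updating
- # data.x in place after each layer.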
- def cls_encode(self, data) -> torch.Tensor:
- if type(getattr(data, "x")) != torch.Tensor:
- raise TypeError
- if type(getattr(data, "edge_index")) != torch.Tensor:
- raise TypeError
- if (
- getattr(data, "edge_weight", None) is not None
- and type(getattr(data, "edge_weight")) != torch.Tensor
- ):
- raise TypeError
- for encoding_layer in self.__sequential_encoding_layers:
- setattr(data, "x", encoding_layer(data))
- return getattr(data, "x")
-
- @property
- def sequential_encoding_layers(self) -> torch.nn.ModuleList:
- return self.__sequential_encoding_layers
-
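- # A minimal usage sketch; the dimensions and dropout rates below are
- # illustrative assumptions rather than defaults provided by this module:
- #
- #     model = GraphSAINTMultiOrderAggregationModel(
- #         num_features=128, num_classes=7,
- #         _output_dimension_for_each_order=256,
- #         _layers_order_list=(1, 1), _pre_dropout=0.1, _layers_dropout=0.1,
- #     )
- #     log_probs = model.cls_decode(model.cls_encode(data))  # data: a PyG Data object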
-
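- # Registered classification-model wrapper: _initialize builds a
- # GraphSAINTMultiOrderAggregationModel from the stored hyperparameters and
- # moves it to the configured device.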
- @register_model("GraphSAINTAggregationModel")
- class GraphSAINTAggregationModel(ClassificationModel):
- def __init__(
- self,
- num_features: int = ...,
- num_classes: int = ...,
- device: _typing.Union[str, torch.device] = ...,
- init: bool = False,
- **kwargs
- ):
- super(GraphSAINTAggregationModel, self).__init__(
- num_features, num_classes, device=device, init=init, **kwargs
- )
- # TODO: initialize with the default hyperparameter space and hyperparameters
-
- def _initialize(self):
- """ Initialize model """
- self.model = GraphSAINTMultiOrderAggregationModel(
- self.num_features,
- self.num_classes,
- self.hyper_parameter.get("output_dimension_for_each_order"),
- self.hyper_parameter.get("layers_order_list"),
- self.hyper_parameter.get("pre_dropout"),
- self.hyper_parameter.get("layers_dropout"),
- self.hyper_parameter.get("activation", "ReLU"),
- bool(self.hyper_parameter.get("bias", True)),
- bool(self.hyper_parameter.get("batch_norm", True)),
- bool(self.hyper_parameter.get("normalize", True)),
- ).to(self.device)