From e03697a0322ab95628a02f8093da3bfe91e08069 Mon Sep 17 00:00:00 2001 From: Generall Date: Mon, 19 Dec 2022 21:26:36 +0800 Subject: [PATCH] update gnnguard --- autogl/module/model/pyg/robust/gnnguard.py | 3 +- autogl/module/model/pyg/robust/nn/__init__.py | 7 + autogl/module/model/pyg/robust/nn/gcn_conv.py | 196 ++++++++++++++++++ autogl/module/model/pyg/robust/nn/inits.py | 56 +++++ 4 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 autogl/module/model/pyg/robust/nn/__init__.py create mode 100644 autogl/module/model/pyg/robust/nn/gcn_conv.py create mode 100644 autogl/module/model/pyg/robust/nn/inits.py diff --git a/autogl/module/model/pyg/robust/gnnguard.py b/autogl/module/model/pyg/robust/gnnguard.py index 4eb3201..1318256 100644 --- a/autogl/module/model/pyg/robust/gnnguard.py +++ b/autogl/module/model/pyg/robust/gnnguard.py @@ -19,7 +19,8 @@ import autogl.data from .. import register_model from . import utils from ..gcn import GCN -from torch_geometric.nn.conv import GCNConv +#from torch_geometric.nn.conv import GCNConv +from .nn import GCNConv from ..base import BaseAutoModel from .....utils import get_logger from .utils import accuracy diff --git a/autogl/module/model/pyg/robust/nn/__init__.py b/autogl/module/model/pyg/robust/nn/__init__.py new file mode 100644 index 0000000..a1a7bd5 --- /dev/null +++ b/autogl/module/model/pyg/robust/nn/__init__.py @@ -0,0 +1,7 @@ +from .gcn_conv import GCNConv + +__all__ = [ + 'GCNConv' +] + +classes = __all__ diff --git a/autogl/module/model/pyg/robust/nn/gcn_conv.py b/autogl/module/model/pyg/robust/nn/gcn_conv.py new file mode 100644 index 0000000..d40e9ef --- /dev/null +++ b/autogl/module/model/pyg/robust/nn/gcn_conv.py @@ -0,0 +1,196 @@ +import torch +from torch.nn import Parameter +from torch_scatter import scatter_add +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.utils import add_remaining_self_loops, to_undirected + +from .inits import glorot, zeros + +@torch.jit._overload +def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, + add_self_loops=True, dtype=None): + # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor # noqa + pass + + +@torch.jit._overload +def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, + add_self_loops=True, dtype=None): + # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor # noqa + pass + + +def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, + add_self_loops=True, dtype=None): + + fill_value = 2. if improved else 1. + + if isinstance(edge_index, SparseTensor): + adj_t = edge_index + if not adj_t.has_value(): + adj_t = adj_t.fill_value(1., dtype=dtype) + if add_self_loops: + adj_t = fill_diag(adj_t, fill_value) + deg = sum(adj_t, dim=1) + deg_inv_sqrt = deg.pow_(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.) + adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1)) + adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1)) + return adj_t + + else: + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + if edge_weight is None: + edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, + device=edge_index.device) + + if add_self_loops: + edge_index, tmp_edge_weight = add_remaining_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + assert tmp_edge_weight is not None + edge_weight = tmp_edge_weight + + row, col = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow_(-0.5) + deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) + return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + + +class GCNConv(MessagePassing): + r"""The graph convolutional operator from the `"Semi-supervised + Classification with Graph Convolutional Networks" + `_ paper + + .. math:: + \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + improved (bool, optional): If set to :obj:`True`, the layer computes + :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. + (default: :obj:`False`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the + cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + normalize (bool, optional): Whether to add self-loops and apply + symmetric normalization. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`torch_geometric.nn.conv.MessagePassing`. + """ + + def __init__(self, in_channels, out_channels, improved=False, cached=False, + bias=True, add_self_loops: bool = True, normalize=True, **kwargs): + super(GCNConv, self).__init__(aggr='add', **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.improved = improved + self.cached = cached + self.normalize = normalize + + self.weight = Parameter(torch.Tensor(in_channels, out_channels)) + + if bias: + self.bias = Parameter(torch.tensor(out_channels, dtype=torch.float32)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + glorot(self.weight) + zeros(self.bias) + self.cached_result = None + self.cached_num_edges = None + + # 原来的版本 + # @staticmethod + # def norm(edge_index, num_nodes, edge_weight=None, improved=False, + # dtype=None): + # if edge_weight is None: + # edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, + # device=edge_index.device) + + # fill_value = 1 if not improved else 2 + # # """Here I removed the self-loop because the self-loop already added in the att_coef function""" + # # edge_index, edge_weight = add_remaining_self_loops( + # # edge_index, edge_weight, fill_value, num_nodes) + + # row, col = edge_index + # deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + # deg_inv_sqrt = deg.pow(-0.5) + # deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + # return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + # for GNNGuard + @staticmethod + def norm(edge_index, num_nodes, edge_weight=None, improved=False, + dtype=None): + if edge_weight is None: + edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, + device=edge_index.device) + edge_weight = edge_weight.to(edge_index.device) + fill_value = 1 if not improved else 2 + # """Here I removed the self-loop because the self-loop already added in the att_coef function""" + # edge_index, edge_weight = add_remaining_self_loops( + # edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index # for GNNGuard + # row, col = edge_index[0], edge_index[1] + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + + def forward(self, x, edge_index, edge_weight=None): + """""" + x = torch.matmul(x, self.weight) + + if self.cached and self.cached_result is not None: + if edge_index.size(1) != self.cached_num_edges: + raise RuntimeError( + 'Cached {} number of edges, but found {}. Please ' + 'disable the caching behavior of this layer by removing ' + 'the `cached=True` argument in its constructor.'.format( + self.cached_num_edges, edge_index.size(1))) + # edge_index = to_undirected(edge_index, x.size(0)) # add non-direct edges + if not self.cached or self.cached_result is None: + self.cached_num_edges = edge_index.size(1) + if self.normalize: + edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, self.improved, x.dtype) + else: + norm = edge_weight + self.cached_result = edge_index, norm + + edge_index, norm = self.cached_result + + return self.propagate(edge_index, x=x, norm=norm) + + def message(self, x_j, norm): + return norm.view(-1, 1) * x_j + + def update(self, aggr_out): + if self.bias is not None: + aggr_out = aggr_out + self.bias + return aggr_out + + def __repr__(self): + return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, + self.out_channels) \ No newline at end of file diff --git a/autogl/module/model/pyg/robust/nn/inits.py b/autogl/module/model/pyg/robust/nn/inits.py new file mode 100644 index 0000000..9a22541 --- /dev/null +++ b/autogl/module/model/pyg/robust/nn/inits.py @@ -0,0 +1,56 @@ +import math + +import torch + + +def uniform(size, tensor): + if tensor is not None: + bound = 1.0 / math.sqrt(size) + tensor.data.uniform_(-bound, bound) + + +def kaiming_uniform(tensor, fan, a): + if tensor is not None: + bound = math.sqrt(6 / ((1 + a**2) * fan)) + tensor.data.uniform_(-bound, bound) + + +def glorot(tensor): + if tensor is not None: + stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) + tensor.data.uniform_(-stdv, stdv) + + +def glorot_orthogonal(tensor, scale): + if tensor is not None: + torch.nn.init.orthogonal_(tensor.data) + scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) + tensor.data *= scale.sqrt() + + +def zeros(tensor): + if tensor is not None: + tensor.data.fill_(0) + + +def ones(tensor): + if tensor is not None: + tensor.data.fill_(1) + + +def normal(tensor, mean, std): + if tensor is not None: + tensor.data.normal_(mean, std) + + +def reset(nn): + def _reset(item): + if hasattr(item, 'reset_parameters'): + item.reset_parameters() + + if nn is not None: + if hasattr(nn, 'children') and len(list(nn.children())) > 0: + for item in nn.children(): + _reset(item) + else: + _reset(nn)