|
- import torch
- import typing as _typing
- from typing import Optional, Tuple
-
-
def degree(index, num_nodes: _typing.Optional[int] = None,
           dtype: _typing.Optional[torch.dtype] = None):
    r"""Count how many times each id occurs in a one-dimensional index tensor.

    This is the unweighted degree: every entry of :attr:`index` adds exactly
    one to the bucket it names.

    Args:
        index (LongTensor): One-dimensional index tensor.
        num_nodes (int, optional): Length of the returned tensor. When
            omitted it is inferred as :obj:`index.max() + 1`, or :obj:`0`
            for an empty :attr:`index`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): Data type of the returned
            tensor; :obj:`None` lets :func:`torch.zeros` pick the default.

    :rtype: :class:`Tensor`
    """
    if num_nodes is not None:
        size = num_nodes
    elif isinstance(index, torch.Tensor):
        size = 0 if index.numel() == 0 else int(index.max()) + 1
    else:
        # Non-tensor inputs exposing a ``size(dim)`` method: take the larger
        # of their first two dimensions.
        size = max(index.size(0), index.size(1))

    counts = torch.zeros((size,), dtype=dtype, device=index.device)
    increments = torch.ones((index.size(0),), dtype=counts.dtype,
                            device=counts.device)
    return counts.scatter_add_(0, index, increments)
-
-
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Reshape ``src`` and expand it to the shape of ``other``.

    A one-dimensional ``src`` is treated as running along axis ``dim`` of
    ``other``, so it first receives ``dim`` leading singleton axes. Trailing
    singleton axes are then appended until the ranks agree, and the result is
    expanded to ``other``'s shape.

    Args:
        src: Tensor to broadcast (in this module, usually an index tensor).
        other: Tensor whose shape is the target.
        dim: Axis of ``other`` that ``src`` lines up with; negative values
            count from the end of ``other``.

    Returns:
        ``src`` expanded to ``other.size()``.
    """
    axis = other.dim() + dim if dim < 0 else dim
    if src.dim() == 1:
        # Each ``None`` inserts one leading singleton axis.
        src = src[(None,) * axis]
    missing = other.dim() - src.dim()
    if missing > 0:
        # Append the remaining singleton axes at the end.
        src = src[(Ellipsis,) + (None,) * missing]
    return src.expand_as(other)
-
-
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum the entries of ``src`` that share an index along ``dim``.

    ``index`` is first broadcast against ``src`` with :func:`broadcast`.
    When ``out`` is given, the sums are accumulated into it in place and it is
    returned. Otherwise a zero-filled output is allocated. Its size along
    ``dim`` is ``dim_size`` when given, else ``index.max() + 1``, or ``0`` for
    an empty ``index``.
    """
    index = broadcast(index, src, dim)
    if out is None:
        shape = list(src.size())
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        shape[dim] = dim_size
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
-
-
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of :func:`scatter_sum`; every argument is forwarded unchanged."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)
-
-
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Multiply together the entries of ``src`` that share an index along ``dim``.

    The compiled ``torch_scatter::scatter_mul`` operator is used when it is
    registered with ``torch.ops``. This module does not load that extension
    itself. When the operator is missing, a pure-PyTorch fallback based on
    :meth:`torch.Tensor.scatter_reduce_` is used instead, which requires a
    PyTorch version that provides that method.

    In the fallback, ``index`` is broadcast against ``src`` with
    :func:`broadcast`. When ``out`` is given, the products are multiplied into
    it in place and it is returned. Otherwise a ones-filled output is
    allocated. Its size along ``dim`` is ``dim_size`` when given, else
    ``index.max() + 1``, or ``0`` for an empty ``index``. Output slots that
    receive no entries therefore stay ``1``.
    """
    try:
        compiled = torch.ops.torch_scatter.scatter_mul
    except (AttributeError, RuntimeError):
        # Operator not registered (torch_scatter extension not loaded);
        # depending on the PyTorch version the lookup raises either type.
        compiled = None
    if compiled is not None:
        return compiled(src, index, dim, out, dim_size)

    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        # Multiplicative identity, so untouched slots hold 1.
        out = torch.ones(size, dtype=src.dtype, device=src.device)
    return out.scatter_reduce_(dim, index, src, reduce="prod",
                               include_self=True)
-
-
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average the entries of ``src`` that share an index along ``dim``.

    The sums are computed with :func:`scatter_sum`, accumulating into ``out``
    in place when it is given. They are then divided in place by the number
    of contributing entries. Slots that receive no entries are divided by
    ``1`` and keep their summed value. Floating-point outputs use
    ``true_divide_``; other dtypes use ``floor_divide_``.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    n_slots = out.size(dim)

    # Axis of ``index`` to count along: ``dim`` normalised against ``src``,
    # clipped to the last axis when ``index`` has lower rank.
    count_axis = dim + src.dim() if dim < 0 else dim
    count_axis = min(count_axis, index.dim() - 1)

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_axis, None, n_slots)
    count[count < 1] = 1  # avoid division by zero for empty slots
    count = broadcast(count, out, dim)

    divide = out.true_divide_ if out.is_floating_point() else out.floor_divide_
    divide(count)
    return out
-
-
def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scatter-reduce ``src`` along ``dim`` by minimum.

    Delegates to the ``torch_scatter::scatter_min`` custom operator and
    returns its result pair unchanged; :func:`scatter` uses the first element
    as the reduced values. NOTE(review): the operator must already be
    registered with ``torch.ops`` (e.g. by importing the ``torch_scatter``
    package) — this module does not load it.
    """
    kernel = torch.ops.torch_scatter.scatter_min
    return kernel(src, index, dim, out, dim_size)
-
-
def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scatter-reduce ``src`` along ``dim`` by maximum.

    Delegates to the ``torch_scatter::scatter_max`` custom operator and
    returns its result pair unchanged; :func:`scatter` uses the first element
    as the reduced values. NOTE(review): the operator must already be
    registered with ``torch.ops`` (e.g. by importing the ``torch_scatter``
    package) — this module does not load it.
    """
    kernel = torch.ops.torch_scatter.scatter_max
    return kernel(src, index, dim, out, dim_size)
-
-
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    """Per-group standard deviation of ``src`` along ``dim``.

    Groups are defined by ``index``, which is broadcast against ``src``.
    The computation proceeds in three steps:

    1. Compute each group's mean.
    2. Sum the squared deviations with :func:`scatter_sum`, into ``out`` in
       place when it is given.
    3. Divide by the group size, or the group size minus one when
       ``unbiased``, in either case clamped to at least 1. Add ``1e-6`` and
       take the square root.

    Note: a passed ``out`` ends up holding the summed squared deviations;
    the standard deviation itself is returned as a new tensor.
    """
    if out is not None:
        dim_size = out.size(dim)

    axis = dim if dim >= 0 else src.dim() + dim
    # Count along ``axis`` of ``index``, or its last axis if it has lower rank.
    count_axis = min(axis, index.dim() - 1)

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_axis, dim_size=dim_size)

    index = broadcast(index, src, axis)
    total = scatter_sum(src, index, axis, dim_size=dim_size)
    count = broadcast(count, total, axis).clamp(1)
    mean = total.div(count)

    deviation = src - mean.gather(axis, index)
    out = scatter_sum(deviation * deviation, index, axis, out, dim_size)

    if unbiased:
        count = count.sub(1).clamp_(1)
    return out.div(count + 1e-6).sqrt()
-
-
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"` or its alias
        :obj:`"add"`, :obj:`"mul"`, :obj:`"mean"`, :obj:`"min"` or
        :obj:`"max"`). (default: :obj:`"sum"`)

    :raises ValueError: If :attr:`reduce` is not one of the supported names.

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        # The min/max kernels return a pair; only the reduced values are
        # exposed through this generic entry point.
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError(
            f"Invalid reduce operation {reduce!r}; expected one of "
            "'sum', 'add', 'mul', 'mean', 'min' or 'max'.")
|