|
|
|
@@ -18,9 +18,10 @@ from mindspore.common.tensor import Tensor |
|
|
|
from mindspore._checkparam import Validator |
|
|
|
from ...cell import Cell |
|
|
|
from ...layer.activation import get_activation |
|
|
|
from ..distribution.normal import Normal |
|
|
|
from .layer_distribution import NormalPrior, NormalPosterior |
|
|
|
|
|
|
|
__all__ = ['DenseReparam'] |
|
|
|
__all__ = ['DenseReparam', 'DenseLocalReparam'] |
|
|
|
|
|
|
|
|
|
|
|
class _DenseVariational(Cell): |
|
|
|
@@ -122,17 +123,17 @@ class _DenseVariational(Cell): |
|
|
|
return self.bias_add(inputs, bias_posterior_tensor) |
|
|
|
|
|
|
|
def compute_kl_loss(self): |
|
|
|
"""Compute kl loss.""" |
|
|
|
weight_post_mean = self.weight_posterior("mean") |
|
|
|
weight_post_sd = self.weight_posterior("sd") |
|
|
|
"""Compute kl loss""" |
|
|
|
weight_args_list = self.weight_posterior("get_dist_args") |
|
|
|
weight_type = self.weight_posterior("get_dist_type") |
|
|
|
|
|
|
|
kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd) |
|
|
|
kl = self.weight_prior("kl_loss", weight_type, *weight_args_list) |
|
|
|
kl_loss = self.sum(kl) |
|
|
|
if self.has_bias: |
|
|
|
bias_post_mean = self.bias_posterior("mean") |
|
|
|
bias_post_sd = self.bias_posterior("sd") |
|
|
|
bias_args_list = self.bias_posterior("get_dist_args") |
|
|
|
bias_type = self.bias_posterior("get_dist_type") |
|
|
|
|
|
|
|
kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd) |
|
|
|
kl = self.bias_prior("kl_loss", bias_type, *bias_args_list) |
|
|
|
kl = self.sum(kl) |
|
|
|
kl_loss += kl |
|
|
|
return kl_loss |
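In the Normal/Normal case this module supports, the `kl_loss` call above evaluates the closed-form KL divergence between the Gaussian posterior and prior element by element, before `self.sum` reduces it to a scalar. A minimal NumPy sketch of that formula, for illustration only (the names below are hypothetical and not part of this module):

import numpy as np

def gaussian_kl(mu_q, sd_q, mu_p, sd_p):
    # Elementwise KL(N(mu_q, sd_q^2) || N(mu_p, sd_p^2)).
    return (np.log(sd_p / sd_q)
            + (sd_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sd_p ** 2)
            - 0.5)

# Posterior parameters against a standard-normal prior; summing over all
# weight elements plays the role of self.sum(kl) above.
mu_q = np.random.randn(3, 4).astype(np.float32)
sd_q = np.abs(np.random.randn(3, 4)).astype(np.float32) + 0.1
kl_total = gaussian_kl(mu_q, sd_q, 0.0, 1.0).sum()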
|
|
|
@@ -187,6 +188,9 @@ class DenseReparam(_DenseVariational): |
|
|
|
Outputs: |
|
|
|
Tensor, the shape of the tensor is :math:`(N, out\_channels)`. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = DenseReparam(3, 4) |
|
|
|
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) |
|
|
|
@@ -220,3 +224,95 @@ class DenseReparam(_DenseVariational): |
|
|
|
weight_posterior_tensor = self.weight_posterior("sample") |
|
|
|
outputs = self.matmul(inputs, weight_posterior_tensor) |
|
|
|
return outputs |
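`DenseReparam` above applies the reparameterization trick in weight space: a full weight matrix is drawn once per forward pass and then used in an ordinary matmul. A NumPy sketch of that sampling step, with hypothetical names, for illustration only:

import numpy as np

x = np.random.randn(2, 3).astype(np.float32)      # a batch of inputs
w_mu = np.random.randn(3, 4).astype(np.float32)   # posterior mean of the weight
w_sd = np.full((3, 4), 0.1, dtype=np.float32)     # posterior std of the weight

# Reparameterized sample: w = mu + sd * eps with eps ~ N(0, 1), the same role
# as self.weight_posterior("sample") above.
eps = np.random.randn(3, 4).astype(np.float32)
w_sample = w_mu + w_sd * eps

outputs = x @ w_sample                            # shape (2, 4)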
|
|
|
|
|
|
|
|
|
|
|
class DenseLocalReparam(_DenseVariational): |
|
|
|
r""" |
|
|
|
Dense variational layers with Local Reparameterization. |
|
|
|
|
|
|
|
For more details, refer to the paper `Variational Dropout and the Local Reparameterization |
|
|
|
Trick <https://arxiv.org/abs/1506.02557>`_. |
|
|
|
|
|
|
|
Applies a dense-connected layer to the input. This layer implements the operation as:
|
|
|
|
|
|
|
.. math:: |
|
|
|
\text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}), |
|
|
|
|
|
|
|
where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{weight}` is a weight matrix with the same
data type as the inputs created by the layer, sampled from the posterior
distribution of the weight, and :math:`\text{bias}` is a bias vector with the
same data type as the inputs created by the layer (only if has_bias is True),
sampled from the posterior distribution of the bias.
|
|
|
|
|
|
|
Args: |
|
|
|
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
|
|
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
|
|
activation (str, Cell): An activation function applied to the output of the layer. The type of `activation`
|
|
|
can be a string (eg. 'relu') or a Cell (eg. nn.ReLU()). Note that if the type of activation is Cell, it must |
|
|
|
be instantiated beforehand. Default: None. |
|
|
|
weight_prior_fn: The prior distribution for weight.
It must return a mindspore distribution instance.
Default: NormalPrior (which creates an instance of a standard
normal distribution). The current version only supports normal distribution.
|
|
|
weight_posterior_fn: The posterior distribution for sampling weight. |
|
|
|
It must be a function handle which returns a mindspore |
|
|
|
distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape). |
|
|
|
The current version only supports normal distribution. |
|
|
|
bias_prior_fn: The prior distribution for the bias vector. It must return
a mindspore distribution instance. Default: NormalPrior (which creates an
instance of a standard normal distribution). The current version
only supports normal distribution.
|
|
|
bias_posterior_fn: The posterior distribution for sampling bias vector. |
|
|
|
It must be a function handle which returns a mindspore |
|
|
|
distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape). |
|
|
|
The current version only supports normal distribution. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`. |
|
|
|
|
|
|
|
Outputs: |
|
|
|
Tensor, the shape of the tensor is :math:`(N, out\_channels)`. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Examples: |
|
|
|
>>> net = DenseLocalReparam(3, 4) |
|
|
|
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) |
|
|
|
>>> output = net(input).shape |
|
|
|
>>> print(output) |
|
|
|
(2, 4) |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
activation=None, |
|
|
|
has_bias=True, |
|
|
|
weight_prior_fn=NormalPrior, |
|
|
|
weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape), |
|
|
|
bias_prior_fn=NormalPrior, |
|
|
|
bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)): |
|
|
|
super(DenseLocalReparam, self).__init__( |
|
|
|
in_channels, |
|
|
|
out_channels, |
|
|
|
activation=activation, |
|
|
|
has_bias=has_bias, |
|
|
|
weight_prior_fn=weight_prior_fn, |
|
|
|
weight_posterior_fn=weight_posterior_fn, |
|
|
|
bias_prior_fn=bias_prior_fn, |
|
|
|
bias_posterior_fn=bias_posterior_fn |
|
|
|
) |
|
|
|
self.sqrt = P.Sqrt() |
|
|
|
self.square = P.Square() |
|
|
|
self.normal = Normal() |
|
|
|
|
|
|
|
def _apply_variational_weight(self, inputs): |
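"""Sample the layer's pre-activations from the normal distribution induced by the weight posterior."""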
|
|
|
mean = self.matmul(inputs, self.weight_posterior("mean")) |
|
|
|
std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd")))) |
|
|
|
weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std) |
|
|
|
return weight_posterior_affine_tensor |
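Because a linear map of independent Gaussian weights is itself Gaussian, `_apply_variational_weight` above never materializes a weight sample: it draws the pre-activations directly from a normal distribution with per-output mean :math:`\sum_i x_i \mu_{ij}` and variance :math:`\sum_i x_i^2 \sigma_{ij}^2`. This is the local reparameterization trick, which reduces the variance of the gradient estimator and gives each example in a batch its own independent noise. A NumPy sketch with hypothetical names, for illustration only:

import numpy as np

x = np.random.randn(2, 3).astype(np.float32)
w_mu = np.random.randn(3, 4).astype(np.float32)   # posterior mean of the weight
w_sd = np.full((3, 4), 0.1, dtype=np.float32)     # posterior std of the weight

mean = x @ w_mu                                   # matmul(inputs, mean)
std = np.sqrt((x ** 2) @ (w_sd ** 2))             # sqrt(matmul(square(x), square(sd)))

# Sample the pre-activation directly; plays the role of
# self.normal("sample", mean=mean, sd=std) above.
eps = np.random.randn(*mean.shape).astype(np.float32)
outputs = mean + std * eps                        # shape (2, 4)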