From: @bingyaweng
Reviewed-by: @sunnybeike, @zichun_ye
Signed-off-by: @sunnybeike
Tag: v1.1.0
@@ -18,7 +18,7 @@
 """
 from . import conv_variational, dense_variational, layer_distribution, bnn_cell_wrapper
 from .conv_variational import ConvReparam
-from .dense_variational import DenseReparam
+from .dense_variational import DenseReparam, DenseLocalReparam
 from .layer_distribution import NormalPrior, NormalPosterior
 from .bnn_cell_wrapper import WithBNNLossCell
@@ -37,6 +37,9 @@ class WithBNNLossCell(Cell):
     Outputs:
         Tensor, a scalar tensor with shape :math:`()`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = Net()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
@@ -157,18 +157,16 @@ class _ConvVariational(_Conv):
     def compute_kl_loss(self):
         """Compute kl loss"""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")

-        kl = self.weight_prior("kl_loss", "Normal",
-                               weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")

-            kl = self.bias_prior("kl_loss", "Normal",
-                                 bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss
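
Note on the hunk above: the posterior is now asked for its own distribution type and arguments (`get_dist_type` / `get_dist_args`) instead of hard-coding `"Normal"` with an explicit mean and sd. For the Normal prior/posterior pair this release supports (per the layer docstrings), the `kl_loss` being summed is the closed-form KL divergence between two Gaussians. A minimal NumPy sketch of that formula follows; it is an illustration only, not the library code.

    import numpy as np

    def normal_kl(mean_q, sd_q, mean_p=0.0, sd_p=1.0):
        # Element-wise KL( N(mean_q, sd_q^2) || N(mean_p, sd_p^2) ); summing it over the
        # weight tensor corresponds to the self.sum(kl) call in compute_kl_loss.
        return (np.log(sd_p / sd_q)
                + (sd_q ** 2 + (mean_q - mean_p) ** 2) / (2.0 * sd_p ** 2)
                - 0.5)

    # Posterior parameters for a toy 2x3 weight matrix against a standard-normal prior.
    post_mean = np.full((2, 3), 0.1)
    post_sd = np.full((2, 3), 0.2)
    print(normal_kl(post_mean, post_sd).sum())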
@@ -249,6 +247,9 @@ class ConvReparam(_ConvVariational):
     Outputs:
         Tensor, with the shape being :math:`(N, C_{out}, H_{out}, W_{out})`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = ConvReparam(120, 240, 4, has_bias=False)
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
@@ -18,9 +18,10 @@ from mindspore.common.tensor import Tensor
 from mindspore._checkparam import Validator
 from ...cell import Cell
 from ...layer.activation import get_activation
+from ..distribution.normal import Normal
 from .layer_distribution import NormalPrior, NormalPosterior

-__all__ = ['DenseReparam']
+__all__ = ['DenseReparam', 'DenseLocalReparam']


 class _DenseVariational(Cell):
@@ -122,17 +123,17 @@ class _DenseVariational(Cell):
         return self.bias_add(inputs, bias_posterior_tensor)

     def compute_kl_loss(self):
-        """Compute kl loss."""
-        weight_post_mean = self.weight_posterior("mean")
-        weight_post_sd = self.weight_posterior("sd")
+        """Compute kl loss"""
+        weight_args_list = self.weight_posterior("get_dist_args")
+        weight_type = self.weight_posterior("get_dist_type")

-        kl = self.weight_prior("kl_loss", "Normal", weight_post_mean, weight_post_sd)
+        kl = self.weight_prior("kl_loss", weight_type, *weight_args_list)
         kl_loss = self.sum(kl)
         if self.has_bias:
-            bias_post_mean = self.bias_posterior("mean")
-            bias_post_sd = self.bias_posterior("sd")
+            bias_args_list = self.bias_posterior("get_dist_args")
+            bias_type = self.bias_posterior("get_dist_type")

-            kl = self.bias_prior("kl_loss", "Normal", bias_post_mean, bias_post_sd)
+            kl = self.bias_prior("kl_loss", bias_type, *bias_args_list)
             kl = self.sum(kl)
             kl_loss += kl
         return kl_loss
@@ -187,6 +188,9 @@ class DenseReparam(_DenseVariational):
     Outputs:
         Tensor, the shape of the tensor is :math:`(N, out\_channels)`.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> net = DenseReparam(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
@@ -220,3 +224,95 @@ class DenseReparam(_DenseVariational):
         weight_posterior_tensor = self.weight_posterior("sample")
         outputs = self.matmul(inputs, weight_posterior_tensor)
         return outputs
+
+
+class DenseLocalReparam(_DenseVariational):
+    r"""
+    Dense variational layers with Local Reparameterization.
+
+    For more details, refer to the paper `Variational Dropout and the Local Reparameterization
+    Trick <https://arxiv.org/abs/1506.02557>`_.
+
+    Applies dense-connected layer to the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = \text{activation}(\text{inputs} * \text{weight} + \text{bias}),
+
+    where :math:`\text{activation}` is the activation function passed as the activation
+    argument (if passed in), :math:`\text{weight}` is a weight matrix with the same data
+    type as the inputs created by the layer, sampled from the posterior distribution of
+    the weight, and :math:`\text{bias}` is a bias vector with the same data type as the
+    inputs created by the layer (only if has_bias is True), sampled from the posterior
+    distribution of :math:`\text{bias}`.
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+        activation (str, Cell): The activation function applied to the output of the layer. The type of `activation`
+            can be a string (e.g. 'relu') or a Cell (e.g. nn.ReLU()). Note that if the type of activation is Cell, it
+            must be instantiated beforehand. Default: None.
+        weight_prior_fn: The prior distribution for weight.
+            It must return a mindspore distribution instance.
+            Default: NormalPrior (which creates an instance of standard
+            normal distribution). The current version only supports normal distribution.
+        weight_posterior_fn: The posterior distribution for sampling weight.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+        bias_prior_fn: The prior distribution for bias vector. It must return
+            a mindspore distribution. Default: NormalPrior (which creates an
+            instance of standard normal distribution). The current version
+            only supports normal distribution.
+        bias_posterior_fn: The posterior distribution for sampling bias vector.
+            It must be a function handle which returns a mindspore
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
+
+    Inputs:
+        - **input** (Tensor) - The shape of the tensor is :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor, the shape of the tensor is :math:`(N, out\_channels)`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
+    Examples:
+        >>> net = DenseLocalReparam(3, 4)
+        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
+        >>> output = net(input).shape
+        >>> print(output)
+        (2, 4)
+    """
+
+    def __init__(
+            self,
+            in_channels,
+            out_channels,
+            activation=None,
+            has_bias=True,
+            weight_prior_fn=NormalPrior,
+            weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape),
+            bias_prior_fn=NormalPrior,
+            bias_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape)):
+        super(DenseLocalReparam, self).__init__(
+            in_channels,
+            out_channels,
+            activation=activation,
+            has_bias=has_bias,
+            weight_prior_fn=weight_prior_fn,
+            weight_posterior_fn=weight_posterior_fn,
+            bias_prior_fn=bias_prior_fn,
+            bias_posterior_fn=bias_posterior_fn
+        )
+        self.sqrt = P.Sqrt()
+        self.square = P.Square()
+        self.normal = Normal()
+
+    def _apply_variational_weight(self, inputs):
+        mean = self.matmul(inputs, self.weight_posterior("mean"))
+        std = self.sqrt(self.matmul(self.square(inputs), self.square(self.weight_posterior("sd"))))
+        weight_posterior_affine_tensor = self.normal("sample", mean=mean, sd=std)
+        return weight_posterior_affine_tensor
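
The `_apply_variational_weight` override above is the local reparameterization trick: instead of sampling a weight matrix and then multiplying, the layer samples the pre-activations directly, whose mean and variance follow from the weight posterior. A minimal NumPy sketch of the same computation, with assumed toy shapes, purely for illustration:

    import numpy as np

    rng = np.random.default_rng(0)
    batch, in_ch, out_ch = 2, 3, 4

    x = rng.normal(size=(batch, in_ch))
    w_mean = rng.normal(size=(in_ch, out_ch))        # posterior mean of the weights
    w_sd = np.abs(rng.normal(size=(in_ch, out_ch)))  # posterior std of the weights

    # Mean and std of the pre-activations implied by the weight posterior:
    act_mean = x @ w_mean
    act_std = np.sqrt((x ** 2) @ (w_sd ** 2))

    # Sample once per activation instead of once per weight (lower-variance gradients):
    pre_activation = rng.normal(act_mean, act_std)
    print(pre_activation.shape)  # (2, 4)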
@@ -36,6 +36,9 @@ class NormalPrior(Cell):
     Returns:
         Cell, a normal distribution.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, dtype=mstype.float32, mean=0, std=0.1):
         super(NormalPrior, self).__init__()
@@ -62,6 +65,9 @@ class NormalPosterior(Cell):
     Returns:
         Cell, a normal distribution.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self,
                  name,
@@ -49,6 +49,9 @@ class ConditionalVAE(Cell):
     Outputs:
         - **output** (tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, encoder, decoder, hidden_size, latent_size, num_classes):
@@ -44,6 +44,9 @@ class VAE(Cell):
     Outputs:
         - **output** (Tuple) - (recon_x(Tensor), x(Tensor), mu(Tensor), std(Tensor)).

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, encoder, decoder, hidden_size, latent_size):
@@ -41,6 +41,9 @@ class ELBO(Cell):
     Outputs:
         Tensor, loss float tensor.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, latent_prior='Normal', output_prior='Normal'):
@@ -34,6 +34,9 @@ class SVI:
         net_with_loss(Cell): Cell with loss function.
         optimizer (Cell): Optimizer for updating the weights.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, net_with_loss, optimizer):
@@ -34,6 +34,9 @@ class VAEAnomalyDetection:
         hidden_size(int): The size of encoder's output tensor.
         latent_size(int): The size of the latent space.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     """

     def __init__(self, encoder, decoder, hidden_size=400, latent_size=20):
@@ -53,6 +53,9 @@ class UncertaintyEvaluation:
            the the path of the uncertainty model; if the path is not given , it will not save or load the
            uncertainty model. Default: False.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> network = LeNet()
         >>> param_dict = load_checkpoint('checkpoint_lenet.ckpt')
@@ -34,6 +34,9 @@ class TransformToBNN:
         dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function. Default: 1.
         bnn_factor (int, float): The coefficient of KL loss, which is KL divergence of Bayesian layer. Default: 1.

+    Supported Platforms:
+        ``Ascend`` ``GPU``
+
     Examples:
         >>> class Net(nn.Cell):
         ...     def __init__(self):
@@ -57,7 +60,7 @@ class TransformToBNN:
         >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
         >>> net_with_loss = WithLossCell(network, criterion)
         >>> train_network = TrainOneStepCell(net_with_loss, optim)
-        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.1)
+        >>> bnn_transformer = TransformToBNN(train_network, 60000, 0.0001)
     """

     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
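
The example now passes `bnn_factor=0.0001` instead of `0.1`. Per the Args above, `dnn_factor` scales the backbone's task loss and `bnn_factor` scales the KL loss; the KL term is summed over every weight of every Bayesian layer, so it usually needs heavy down-weighting to avoid dominating the per-batch task loss. A hedged illustration with made-up magnitudes, assuming the two terms are simply added with these coefficients:

    # Made-up magnitudes, only to illustrate the balance the two factors control.
    task_loss = 2.3          # e.g. cross-entropy on one mini-batch
    kl_loss = 1.5e4          # KL summed over all weights of all Bayesian layers
    dnn_factor, bnn_factor = 60000, 0.0001

    total_loss = dnn_factor * task_loss + bnn_factor * kl_loss
    print(total_loss)        # the KL contribution is a small regularizer, not the main term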
@@ -105,6 +108,9 @@ class TransformToBNN:
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.

+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
             >>> net = Net()
             >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
@@ -147,6 +153,9 @@ class TransformToBNN:
             Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the
             corresponding bayesian layer.

+        Supported Platforms:
+            ``Ascend`` ``GPU``
+
         Examples:
             >>> net = Net()
             >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)