Merge pull request !5132 from byweng/master
@@ -21,13 +21,13 @@ The objective of MDP is to integrate deep learning with Bayesian learning. On th
 **Layer 1-2: Probabilistic inference algorithms**
 
-- SVI([mindspore.nn.probability.dpn](https://gitee.com/mindspore/mindspore/tree/master/mindspore/nn/probability/dpn)): A unified interface for stochastic variational inference.
+- SVI([mindspore.nn.probability.infer.variational](https://gitee.com/mindspore/mindspore/tree/master/mindspore/nn/probability/infer/variational)): A unified interface for stochastic variational inference.
 - MC: Algorithms for approximating integrals via sampling.
 
 **Layer 2: Deep Probabilistic Programming (DPP) aims to provide composable BNN modules**
 
 - Layers([mindspore.nn.probability.bnn_layers](https://gitee.com/mindspore/mindspore/tree/master/mindspore/nn/probability/bnn_layers)): BNN layers, which are used to construct BNN.
-- Bnn: A bunch of BNN models that allow to be integrated into DNN;
+- Dpn([mindspore.nn.probability.dpn](https://gitee.com/mindspore/mindspore/tree/master/mindspore/nn/probability/dpn)): A bunch of BNN models that allow to be integrated into DNN;
 - Transform([mindspore.nn.probability.transforms](https://gitee.com/mindspore/mindspore/tree/master/mindspore/nn/probability/transforms)): Interfaces for the transformation between BNN and DNN;
 - Context: context managers for models and layers.
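For orientation, here is a minimal sketch of how the Layer-2 modules listed above compose. It is a sketch only, assuming the `mindspore.nn.probability.bnn_layers` API of this release; the `ToyBNN` class and its layer sizes are illustrative, not part of the MR.

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn.probability import bnn_layers

class ToyBNN(nn.Cell):
    """A toy network built from Bayesian layers (the Layers module above)."""
    def __init__(self):
        super(ToyBNN, self).__init__()
        # Bayesian fully connected layer with reparameterization sampling.
        self.dense = bnn_layers.DenseReparam(16, 10)

    def construct(self, x):
        return self.dense(x)

net = ToyBNN()
out = net(Tensor(np.ones([2, 16]), mindspore.float32))
print(out.shape)  # (2, 10)
```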
@@ -14,6 +14,7 @@
 # ============================================================================
 """Convolutional variational layers."""
 from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
 from mindspore._checkparam import twice
 from ...layer.conv import _Conv
 from ...cell import Cell
@@ -79,35 +80,45 @@ class _ConvVariational(_Conv):
         self.weight.requires_grad = False
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
+        for prior_name, prior_dist in self.weight_prior.name_cells().items():
+            if prior_name != 'normal':
+                raise TypeError("The type of distribution of `weight_prior_fn` should be `normal`")
+            if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
+                    isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
+                raise TypeError("The input form of `weight_prior_fn` is incorrect")
 
         try:
             self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
         except TypeError:
-            raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+            raise TypeError('The input form of `weight_posterior_fn` is incorrect')
+        for posterior_name, _ in self.weight_posterior.name_cells().items():
+            if posterior_name != 'normal':
+                raise TypeError("The type of distribution of `weight_posterior_fn` should be `normal`")
 
         if self.has_bias:
             self.bias.requires_grad = False
 
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
+            for prior_name, prior_dist in self.bias_prior.name_cells().items():
+                if prior_name != 'normal':
+                    raise TypeError("The type of distribution of `bias_prior_fn` should be `normal`")
+                if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
+                        isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
+                    raise TypeError("The input form of `bias_prior_fn` is incorrect")
 
             try:
                 self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
             except TypeError:
                 raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            for posterior_name, _ in self.bias_posterior.name_cells().items():
+                if posterior_name != 'normal':
+                    raise TypeError("The type of distribution of `bias_posterior_fn` should be `normal`")
 
         # mindspore operations
         self.bias_add = P.BiasAdd()
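The new checks above pin both priors to `NormalPrior` (passed either as a class or as an instance) and then verify the wrapped distribution. A minimal sketch of what now passes and what now fails fast, assuming `ConvReparam` and `NormalPrior` are importable from `mindspore.nn.probability.bnn_layers` as in this release; `NotAPrior` is a hypothetical stand-in for a wrong argument.

```python
from mindspore.nn.probability.bnn_layers import ConvReparam, NormalPrior

# Accepted: the NormalPrior class itself, or an instance of it.
net = ConvReparam(3, 8, 3, weight_prior_fn=NormalPrior)
net = ConvReparam(3, 8, 3, weight_prior_fn=NormalPrior())

# Rejected: anything else now raises immediately with a clear message,
# instead of failing later inside the layer.
class NotAPrior:  # hypothetical wrong argument
    pass

try:
    net = ConvReparam(3, 8, 3, weight_prior_fn=NotAPrior)
except TypeError as err:
    print(err)  # The type of `weight_prior_fn` should be `NormalPrior`
```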
@@ -221,16 +232,16 @@ class ConvReparam(_ConvVariational):
             normal distribution). The current version only supports NormalPrior.
         weight_posterior_fn: posterior distribution for sampling weight.
             It should be a function handle which returns a mindspore
-            distribution instance. Default: NormalPosterior. The current
-            version only supports NormalPosterior.
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
         bias_prior_fn: prior distribution for bias vector. It should return
             a mindspore distribution. Default: NormalPrior(which creates an
             instance of standard normal distribution). The current version
-            only supports NormalPrior.
+            only supports normal distribution.
         bias_posterior_fn: posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
-            distribution instance. Default: NormalPosterior. The current
-            version only supports NormalPosterior.
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
 
@@ -239,7 +250,6 @@ class ConvReparam(_ConvVariational):
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
 
     Examples:
-    Examples:
         >>> net = ConvReparam(120, 240, 4, has_bias=False)
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
         >>> net(input).shape
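Since the docstring now spells the posterior default as an explicit lambda, passing it by hand is equivalent to omitting it. A sketch, assuming `NormalPosterior` keeps the `name`/`shape` keyword interface shown in the documented default:

```python
from mindspore.nn.probability.bnn_layers import ConvReparam, NormalPosterior

# Equivalent to leaving weight_posterior_fn at its documented default.
net = ConvReparam(
    120, 240, 4, has_bias=False,
    weight_posterior_fn=lambda name, shape: NormalPosterior(name=name, shape=shape))
```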
@@ -14,6 +14,7 @@
 # ============================================================================
 """dense_variational"""
 from mindspore.ops import operations as P
+from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool
 from ...cell import Cell
 from ...layer.activation import get_activation
@@ -43,33 +44,43 @@ class _DenseVariational(Cell):
         self.has_bias = check_bool(has_bias)
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
+        for prior_name, prior_dist in self.weight_prior.name_cells().items():
+            if prior_name != 'normal':
+                raise TypeError("The type of distribution of `weight_prior_fn` should be `normal`")
+            if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
+                    isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
+                raise TypeError("The input form of `weight_prior_fn` is incorrect")
 
         try:
             self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
         except TypeError:
             raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+        for posterior_name, _ in self.weight_posterior.name_cells().items():
+            if posterior_name != 'normal':
+                raise TypeError("The type of distribution of `weight_posterior_fn` should be `normal`")
 
         if self.has_bias:
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
+            for prior_name, prior_dist in self.bias_prior.name_cells().items():
+                if prior_name != 'normal':
+                    raise TypeError("The type of distribution of `bias_prior_fn` should be `normal`")
+                if not (isinstance(getattr(prior_dist, '_mean_value'), Tensor) and
+                        isinstance(getattr(prior_dist, '_sd_value'), Tensor)):
+                    raise TypeError("The input form of `bias_prior_fn` is incorrect")
 
             try:
                 self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
             except TypeError:
                 raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            for posterior_name, _ in self.bias_posterior.name_cells().items():
+                if posterior_name != 'normal':
+                    raise TypeError("The type of distribution of `bias_posterior_fn` should be `normal`")
 
         self.activation = activation
         if not self.activation:
@@ -160,16 +171,16 @@ class DenseReparam(_DenseVariational):
             normal distribution). The current version only supports NormalPrior.
         weight_posterior_fn: posterior distribution for sampling weight.
             It should be a function handle which returns a mindspore
-            distribution instance. Default: NormalPosterior. The current
-            version only supports NormalPosterior.
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
         bias_prior_fn: prior distribution for bias vector. It should return
             a mindspore distribution. Default: NormalPrior(which creates an
             instance of standard normal distribution). The current version
             only supports NormalPrior.
         bias_posterior_fn: posterior distribution for sampling bias vector.
             It should be a function handle which returns a mindspore
-            distribution instance. Default: NormalPosterior. The current
-            version only supports NormalPosterior.
+            distribution instance. Default: lambda name, shape: NormalPosterior(name=name, shape=shape).
+            The current version only supports normal distribution.
 
     Inputs:
         - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
 
@@ -180,7 +191,8 @@ class DenseReparam(_DenseVariational):
     Examples:
         >>> net = DenseReparam(3, 4)
         >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
-        >>> net(input)
+        >>> net(input).shape
+        (2, 4)
     """
 
     def __init__(
@@ -31,8 +31,8 @@ class NormalPrior(Cell):
     Args:
         dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor.
             Default: mindspore.float32.
-        mean (int, float): Mean of normal distribution.
-        std (int, float): Standard deviation of normal distribution.
+        mean (int, float): Mean of normal distribution. Default: 0.
+        std (int, float): Standard deviation of normal distribution. Default: 0.1.
 
     Returns:
         Cell, a normal distribution.
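The newly documented defaults in action; a short sketch (the positional call assumes the argument order `dtype, mean, std` implied by the docstring):

```python
import mindspore
from mindspore.nn.probability.bnn_layers import NormalPrior

prior = NormalPrior()                                # mean=0, std=0.1
wide_prior = NormalPrior(mindspore.float32, 0, 1.0)  # wider standard deviation
```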
@@ -99,7 +99,7 @@ class ConditionalVAE(Cell):
         Randomly sample from latent space to generate sample.
 
         Args:
-            sample_y (Tensor): Define the label of sample, int tensor, the shape is (generate_nums, ).
+            sample_y (Tensor): Define the label of sample. Tensor of shape (generate_nums, ) and type mindspore.int32.
             generate_nums (int): The number of samples to generate.
             shape(tuple): The shape of sample, it should be (generate_nums, C, H, W) or (-1, C, H, W).
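With the clarified contract, `sample_y` is an int32 tensor holding one class label per requested sample. A sketch of building one; the `cvae` object and the `generate_sample` call are assumed from the surrounding class, so the last line is left as a comment.

```python
import numpy as np
import mindspore
from mindspore import Tensor

generate_nums = 8
# One label per sample: int32 tensor of shape (generate_nums,) = (8,).
sample_y = Tensor(np.random.randint(0, 10, generate_nums), mindspore.int32)
# samples = cvae.generate_sample(sample_y, generate_nums, (generate_nums, 1, 28, 28))
```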
@@ -68,6 +68,10 @@ class UncertaintyEvaluation:
         >>>                                save_model=False)
         >>> epistemic_uncertainty = evaluation.eval_epistemic_uncertainty(eval_data)
         >>> aleatoric_uncertainty = evaluation.eval_aleatoric_uncertainty(eval_data)
+        >>> epistemic_uncertainty.shape
+        (32, 10)
+        >>> aleatoric_uncertainty.shape
+        (32,)
     """
 
     def __init__(self, model, train_dataset, task_type, num_classes=None, epochs=1,
@@ -31,8 +31,8 @@ class TransformToBNN:
     Args:
         trainable_dnn (Cell): A trainable DNN model (backbone) wrapped by TrainOneStepCell.
-        dnn_factor ((int, float): The coefficient of backbone's loss, which is computed by loss function.
-        bnn_factor (int, float): The coefficient of kl loss, which is kl divergence of Bayesian layer.
+        dnn_factor (int, float): The coefficient of the backbone's loss, which is computed by the loss function. Default: 1.
+        bnn_factor (int, float): The coefficient of the KL loss, which is the KL divergence of the Bayesian layers. Default: 1.
 
     Examples:
         >>> class Net(nn.Cell):
@@ -93,15 +93,15 @@ class TransformToBNN:
         Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell.
 
         Args:
-            get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp:
+            get_dense_args: The arguments gotten from the DNN full connection layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}.
-            get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
+            get_conv_args: The arguments gotten from the DNN convolutional layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
                 "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
             add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
-                `add_dense_args` should not duplicate arguments in `get_dense_args`. Default: {}.
+                `add_dense_args` should not duplicate arguments in `get_dense_args`. Default: None.
             add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
-                `add_conv_args` should not duplicate arguments in `get_conv_args`. Default: {}.
+                `add_conv_args` should not duplicate arguments in `get_conv_args`. Default: None.
 
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.
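A sketch of the whole-model transform with the defaults documented above. `Net` stands in for the backbone from the Examples section, and the loss and optimizer choices are illustrative assumptions rather than part of the MR.

```python
from mindspore.nn import (WithLossCell, TrainOneStepCell,
                          SoftmaxCrossEntropyWithLogits, AdamWeightDecay)
from mindspore.nn.probability.transforms import TransformToBNN

network = Net()  # hypothetical DNN backbone, e.g. the Net from the Examples above
criterion = SoftmaxCrossEntropyWithLogits(sparse=True)
optimizer = AdamWeightDecay(params=network.trainable_params(), learning_rate=1e-4)
train_network = TrainOneStepCell(WithLossCell(network, criterion), optimizer)

bnn_transformer = TransformToBNN(train_network, dnn_factor=1, bnn_factor=1)
# Omitting the get_*/add_* arguments uses the documented defaults (add_* now None).
train_bnn_network = bnn_transformer.transform_to_bnn_model()
```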
@@ -142,7 +142,7 @@ class TransformToBNN:
             nn.Dense, nn.Conv2d.
         bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
             DenseReparam, ConvReparam.
-        get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None.
+        get_args: The arguments gotten from the DNN layer. Default: None.
         add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
             duplicate arguments in `get_args`. Default: None.
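The per-layer variant swaps only one layer type. A sketch reusing `bnn_transformer` from the previous snippet; the method name `transform_to_bnn_layer` matches the MindSpore API of this era, but treat the exact signature as an assumption.

```python
import mindspore.nn as nn
from mindspore.nn.probability.bnn_layers import DenseReparam

# Replace only the fully connected layers; convolutions stay as DNN layers.
train_bnn_network = bnn_transformer.transform_to_bnn_layer(
    nn.Dense, DenseReparam, get_args=None, add_args=None)
```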