Browse Source

!4881 Fix param check

Merge pull request !4881 from byweng/fix_param_check
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f69b4e03b7
5 changed files with 12 additions and 13 deletions
  1. +2
    -2
      mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
  2. +2
    -3
      mindspore/nn/probability/bnn_layers/conv_variational.py
  3. +1
    -2
      mindspore/nn/probability/bnn_layers/dense_variational.py
  4. +5
    -4
      mindspore/nn/probability/bnn_layers/layer_distribution.py
  5. +2
    -2
      mindspore/nn/probability/transforms/transform_bnn.py

+ 2
- 2
mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py View File

@@ -65,9 +65,9 @@ class WithBNNLossCell:
"""

def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')

self.backbone = backbone


+ 2
- 3
mindspore/nn/probability/bnn_layers/conv_variational.py View File

@@ -173,13 +173,12 @@ class ConvReparam(_ConvVariational):
r"""
Convolutional variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or
kernel_size (Union[int, tuple[int]]): The data type is int or
tuple with 2 integers. Specifies the height and width of the 2D
convolution window. Single int means the value if for both
height and width of the kernel. A tuple of 2 ints means the


+ 1
- 2
mindspore/nn/probability/bnn_layers/dense_variational.py View File

@@ -132,8 +132,7 @@ class DenseReparam(_DenseVariational):
r"""
Dense variational layers with Reparameterization.

See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.

Applies dense-connected layer for the input. This layer implements the operation as:



+ 5
- 4
mindspore/nn/probability/bnn_layers/layer_distribution.py View File

@@ -78,16 +78,17 @@ class NormalPosterior(Cell):
if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`')
if not isinstance(loc_mean, (int, float)):
if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)):
raise TypeError('The type of `loc_mean` should be `int` or `float`')
if not isinstance(untransformed_scale_mean, (int, float)):
if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)):
raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0):
raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and
untransformed_scale_std >= 0):
raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
'its value should > 0')


+ 2
- 2
mindspore/nn/probability/transforms/transform_bnn.py View File

@@ -61,9 +61,9 @@ class TransformToBNN:
"""

def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`')

net_with_loss = trainable_dnn.network


Loading…
Cancel
Save