Browse Source

!4881 Fix param check

Merge pull request !4881 from byweng/fix_param_check
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
f69b4e03b7
5 changed files with 12 additions and 13 deletions
  1. +2
    -2
      mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py
  2. +2
    -3
      mindspore/nn/probability/bnn_layers/conv_variational.py
  3. +1
    -2
      mindspore/nn/probability/bnn_layers/dense_variational.py
  4. +5
    -4
      mindspore/nn/probability/bnn_layers/layer_distribution.py
  5. +2
    -2
      mindspore/nn/probability/transforms/transform_bnn.py

+ 2
- 2
mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py View File

@@ -65,9 +65,9 @@ class WithBNNLossCell:
""" """


def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`') raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`') raise TypeError('The type of `bnn_factor` should be `int` or `float`')


self.backbone = backbone self.backbone = backbone


+ 2
- 3
mindspore/nn/probability/bnn_layers/conv_variational.py View File

@@ -173,13 +173,12 @@ class ConvReparam(_ConvVariational):
r""" r"""
Convolutional variational layers with Reparameterization. Convolutional variational layers with Reparameterization.
See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.
Args: Args:
in_channels (int): The number of input channel :math:`C_{in}`. in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`. out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or
kernel_size (Union[int, tuple[int]]): The data type is int or
tuple with 2 integers. Specifies the height and width of the 2D tuple with 2 integers. Specifies the height and width of the 2D
convolution window. Single int means the value if for both convolution window. Single int means the value if for both
height and width of the kernel. A tuple of 2 ints means the height and width of the kernel. A tuple of 2 ints means the


+ 1
- 2
mindspore/nn/probability/bnn_layers/dense_variational.py View File

@@ -132,8 +132,7 @@ class DenseReparam(_DenseVariational):
r""" r"""
Dense variational layers with Reparameterization. Dense variational layers with Reparameterization.


See more details in paper `Auto-Encoding Variational Bayes
<https://arxiv.org/abs/1312.6114>`
See more details in paper `Auto-Encoding Variational Bayes <https://arxiv.org/abs/1312.6114>`_.


Applies dense-connected layer for the input. This layer implements the operation as: Applies dense-connected layer for the input. This layer implements the operation as:




+ 5
- 4
mindspore/nn/probability/bnn_layers/layer_distribution.py View File

@@ -78,16 +78,17 @@ class NormalPosterior(Cell):
if not isinstance(shape, (tuple, list)): if not isinstance(shape, (tuple, list)):
raise TypeError('The type of `shape` should be `tuple` or `list`') raise TypeError('The type of `shape` should be `tuple` or `list`')
if not isinstance(loc_mean, (int, float)):
if isinstance(loc_mean, bool) or not isinstance(loc_mean, (int, float)):
raise TypeError('The type of `loc_mean` should be `int` or `float`') raise TypeError('The type of `loc_mean` should be `int` or `float`')
if not isinstance(untransformed_scale_mean, (int, float)):
if isinstance(untransformed_scale_mean, bool) or not isinstance(untransformed_scale_mean, (int, float)):
raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`') raise TypeError('The type of `untransformed_scale_mean` should be `int` or `float`')
if not (isinstance(loc_std, (int, float)) and loc_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(loc_std, (int, float)) and loc_std >= 0):
raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0') raise TypeError('The type of `loc_std` should be `int` or `float` and its value should > 0')
if not (isinstance(untransformed_scale_std, (int, float)) and untransformed_scale_std >= 0):
if isinstance(loc_std, bool) or not (isinstance(untransformed_scale_std, (int, float)) and
untransformed_scale_std >= 0):
raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and ' raise TypeError('The type of `untransformed_scale_std` should be `int` or `float` and '
'its value should > 0') 'its value should > 0')


+ 2
- 2
mindspore/nn/probability/transforms/transform_bnn.py View File

@@ -61,9 +61,9 @@ class TransformToBNN:
""" """


def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1): def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
if not isinstance(dnn_factor, (int, float)):
if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)):
raise TypeError('The type of `dnn_factor` should be `int` or `float`') raise TypeError('The type of `dnn_factor` should be `int` or `float`')
if not isinstance(bnn_factor, (int, float)):
if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)):
raise TypeError('The type of `bnn_factor` should be `int` or `float`') raise TypeError('The type of `bnn_factor` should be `int` or `float`')


net_with_loss = trainable_dnn.network net_with_loss = trainable_dnn.network


Loading…
Cancel
Save