Browse Source

fix error of conv_variational

tags/v0.7.0-beta
bingyaweng 5 years ago
parent
commit
b126272b39
4 changed files with 7 additions and 7 deletions
  1. +4
    -4
      mindspore/nn/probability/bnn_layers/conv_variational.py
  2. +1
    -1
      tests/st/probability/test_bnn_layer.py
  3. +1
    -1
      tests/st/probability/test_transform_bnn_layer.py
  4. +1
    -1
      tests/st/probability/test_transform_bnn_model.py

+ 4
- 4
mindspore/nn/probability/bnn_layers/conv_variational.py View File

@@ -61,10 +61,10 @@ class _ConvVariational(_Conv):
raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed '
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
if not isinstance(stride, (int, tuple)):
if isinstance(stride, bool) or not isinstance(stride, (int, tuple)):
raise TypeError('The type of `stride` should be `int` of `tuple`')
if not isinstance(dilation, (int, tuple)):
if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)):
raise TypeError('The type of `dilation` should be `int` of `tuple`')
# convolution args
@@ -136,8 +136,8 @@ class _ConvVariational(_Conv):
return outputs
def extend_repr(self):
str_info = 'in_channels={}, out_channels={}, kernel_size={}, weight_mean={}, stride={}, pad_mode={}, ' \
'padding={}, dilation={}, group={}, weight_std={}, has_bias={}'\
str_info = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, pad_mode={}, ' \
'padding={}, dilation={}, group={}, weight_mean={}, weight_std={}, has_bias={}'\
.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
self.dilation, self.group, self.weight_posterior.mean, self.weight_posterior.untransformed_std,
self.has_bias)


+ 1
- 1
tests/st/probability/test_bnn_layer.py View File

@@ -137,7 +137,7 @@ if __name__ == "__main__":
epoch = 100

for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
train_loss, train_acc = train_model(train_bnn_network, network, train_set)

valid_acc = validate_model(network, test_set)



+ 1
- 1
tests/st/probability/test_transform_bnn_layer.py View File

@@ -142,7 +142,7 @@ if __name__ == "__main__":
epoch = 100

for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
train_loss, train_acc = train_model(train_bnn_network, network, train_set)

valid_acc = validate_model(network, test_set)



+ 1
- 1
tests/st/probability/test_transform_bnn_model.py View File

@@ -141,7 +141,7 @@ if __name__ == "__main__":
epoch = 500

for i in range(epoch):
train_loss, train_acc = train_model(train_bnn_network, network, test_set)
train_loss, train_acc = train_model(train_bnn_network, network, train_set)

valid_acc = validate_model(network, test_set)



Loading…
Cancel
Save