| @@ -67,8 +67,13 @@ class WithBNNLossCell: | |||
| def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1): | |||
| if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): | |||
| raise TypeError('The type of `dnn_factor` should be `int` or `float`') | |||
| if dnn_factor < 0: | |||
| raise ValueError('The value of `dnn_factor` should >= 0') | |||
| if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)): | |||
| raise TypeError('The type of `bnn_factor` should be `int` or `float`') | |||
| if bnn_factor < 0: | |||
| raise ValueError('The value of `bnn_factor` should >= 0') | |||
| self.backbone = backbone | |||
| self.loss_fn = loss_fn | |||
| @@ -61,12 +61,6 @@ class _ConvVariational(_Conv): | |||
| raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' | |||
| + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') | |||
| if isinstance(stride, bool) or not isinstance(stride, (int, tuple)): | |||
| raise TypeError('The type of `stride` should be `int` of `tuple`') | |||
| if isinstance(dilation, bool) or not isinstance(dilation, (int, tuple)): | |||
| raise TypeError('The type of `dilation` should be `int` of `tuple`') | |||
| # convolution args | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| @@ -29,7 +29,7 @@ class NormalPrior(Cell): | |||
| To initialize a normal distribution of mean 0 and standard deviation 0.1. | |||
| Args: | |||
| dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| Default: mindspore.float32. | |||
| mean (int, float): Mean of normal distribution. | |||
| std (int, float): Standard deviation of normal distribution. | |||
| @@ -52,7 +52,7 @@ class NormalPosterior(Cell): | |||
| Args: | |||
| name (str): Name prepended to trainable parameter. | |||
| shape (list, tuple): Shape of the mean and standard deviation. | |||
| dtype (class `mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| dtype (:class:`mindspore.dtype`): The argument is used to define the data type of the output tensor. | |||
| Default: mindspore.float32. | |||
| loc_mean (int, float): Mean of distribution to initialize trainable parameters. Default: 0. | |||
| loc_std (int, float): Standard deviation of distribution to initialize trainable parameters. Default: 0.1. | |||
| @@ -63,8 +63,13 @@ class TransformToBNN: | |||
| def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1): | |||
| if isinstance(dnn_factor, bool) or not isinstance(dnn_factor, (int, float)): | |||
| raise TypeError('The type of `dnn_factor` should be `int` or `float`') | |||
| if dnn_factor < 0: | |||
| raise ValueError('The value of `dnn_factor` should >= 0') | |||
| if isinstance(bnn_factor, bool) or not isinstance(bnn_factor, (int, float)): | |||
| raise TypeError('The type of `bnn_factor` should be `int` or `float`') | |||
| if bnn_factor < 0: | |||
| raise ValueError('The value of `bnn_factor` should >= 0') | |||
| net_with_loss = trainable_dnn.network | |||
| self.optimizer = trainable_dnn.optimizer | |||
| @@ -88,9 +93,9 @@ class TransformToBNN: | |||
| Transform the whole DNN model to BNN model, and wrap BNN model by TrainOneStepCell. | |||
| Args: | |||
| get_dense_args (function): The arguments gotten from the DNN full connection layer. Default: lambda dp: | |||
| get_dense_args (:class:`function`): The arguments gotten from the DNN full connection layer. Default: lambda dp: | |||
| {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "has_bias": dp.has_bias}. | |||
| get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp: | |||
| get_conv_args (:class:`function`): The arguments gotten from the DNN convolutional layer. Default: lambda dp: | |||
| {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode, | |||
| "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}. | |||
| add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in | |||
| @@ -134,10 +139,10 @@ class TransformToBNN: | |||
| Args: | |||
| dnn_layer_type (Cell): The type of DNN layer to be transformed to BNN layer. The optional values are | |||
| nn.Dense, nn.Conv2d. | |||
| nn.Dense, nn.Conv2d. | |||
| bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are | |||
| DenseReparameterization, ConvReparameterization. | |||
| get_args (dict): The arguments gotten from the DNN layer. Default: None. | |||
| DenseReparam, ConvReparam. | |||
| get_args (:class:`function`): The arguments gotten from the DNN layer. Default: None. | |||
| add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not | |||
| duplicate arguments in `get_args`. Default: None. | |||
| @@ -108,22 +108,22 @@ class VaeGan(nn.Cell): | |||
| return ld_real, ld_fake, ld_p, recon_x, x, mu, std | |||
| class VaeGanLoss(nn.Cell): | |||
| class VaeGanLoss(ELBO): | |||
| def __init__(self): | |||
| super(VaeGanLoss, self).__init__() | |||
| self.zeros = P.ZerosLike() | |||
| self.mse = nn.MSELoss(reduction='sum') | |||
| self.elbo = ELBO(latent_prior='Normal', output_prior='Normal') | |||
| def construct(self, data, label): | |||
| ld_real, ld_fake, ld_p, recon_x, x, mean, std = data | |||
| ld_real, ld_fake, ld_p, recon_x, x, mu, std = data | |||
| y_real = self.zeros(ld_real) + 1 | |||
| y_fake = self.zeros(ld_fake) | |||
| elbo_data = (recon_x, x, mean, std) | |||
| loss_D = self.mse(ld_real, y_real) | |||
| loss_GD = self.mse(ld_p, y_fake) | |||
| loss_G = self.mse(ld_fake, y_real) | |||
| elbo_loss = self.elbo(elbo_data, label) | |||
| reconstruct_loss = self.recon_loss(x, recon_x) | |||
| kl_loss = self.posterior('kl_loss', 'Normal', self.zeros(mu), self.zeros(mu) + 1, mu, std) | |||
| elbo_loss = reconstruct_loss + self.sum(kl_loss) | |||
| return loss_D + loss_G + loss_GD + elbo_loss | |||