Merge pull request !4717 from byweng/fix_error
@@ -26,8 +26,8 @@ class ClassWrap:
         self._cls = cls
         self.bnn_loss_file = None
 
-    def __call__(self, backbone, loss_fn, backbone_factor, kl_factor):
-        obj = self._cls(backbone, loss_fn, backbone_factor, kl_factor)
+    def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
+        obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
         bnn_with_loss = obj()
         self.bnn_loss_file = obj.bnn_loss_file
         return bnn_with_loss
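The rename above keeps the decorator's parameters in step with `WithBNNLossCell`'s `dnn_factor`/`bnn_factor`, which matters for callers that pass keywords through the wrapper. A minimal sketch of the pattern, using a hypothetical `ToyLossCell` rather than the real MindSpore class:

    class ClassWrap:
        """Decorator that forwards constructor arguments to the wrapped class."""
        def __init__(self, cls):
            self._cls = cls
            self.bnn_loss_file = None

        def __call__(self, backbone, loss_fn, dnn_factor, bnn_factor):
            obj = self._cls(backbone, loss_fn, dnn_factor, bnn_factor)
            self.bnn_loss_file = getattr(obj, 'bnn_loss_file', None)
            return obj

    @ClassWrap
    class ToyLossCell:  # hypothetical stand-in for the wrapped loss cell
        def __init__(self, backbone, loss_fn, dnn_factor, bnn_factor):
            self.factors = (dnn_factor, bnn_factor)

    # Keyword arguments now match on both sides of the wrapper:
    cell = ToyLossCell(backbone=None, loss_fn=None, dnn_factor=1, bnn_factor=0.1)
    print(cell.factors)  # (1, 0.1)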
@@ -65,6 +65,11 @@ class WithBNNLossCell:
     """
 
     def __init__(self, backbone, loss_fn, dnn_factor=1, bnn_factor=1):
+        if not isinstance(dnn_factor, (int, float)):
+            raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if not isinstance(bnn_factor, (int, float)):
+            raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+
         self.backbone = backbone
         self.loss_fn = loss_fn
         self.dnn_factor = dnn_factor
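A standalone sketch of the new factor validation (not the MindSpore class itself), showing what callers see on a bad argument:

    def _check_factor(name, value):
        # Mirrors the isinstance guard added above; note that bools also pass,
        # since bool is a subclass of int in Python.
        if not isinstance(value, (int, float)):
            raise TypeError('The type of `%s` should be `int` or `float`' % name)
        return value

    _check_factor('dnn_factor', 1)     # ok
    _check_factor('bnn_factor', 5e-5)  # ok
    try:
        _check_factor('dnn_factor', '1')
    except TypeError as e:
        print(e)  # The type of `dnn_factor` should be `int` or `float`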
@@ -79,20 +79,40 @@ class _ConvVariational(_Conv):
         self.weight.requires_grad = False
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
+
+        if isinstance(weight_posterior_fn, Cell):
+            if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+        else:
+            if weight_posterior_fn.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
         self.weight_posterior = weight_posterior_fn(shape=self.shape, name='bnn_weight')
 
         if self.has_bias:
             self.bias.requires_grad = False
 
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
+
+            if isinstance(bias_posterior_fn, Cell):
+                if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            else:
+                if bias_posterior_fn.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
             self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
 
         # mindspore operations
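The prior/posterior arguments may arrive either as an instance (the `Cell` branch) or as a class, so the new checks read `__class__.__name__` in one branch and `__name__` in the other. A self-contained toy reproduction of that dispatch, using `isinstance(x, type)` in place of the `Cell` test and placeholder classes instead of MindSpore's:

    class NormalPrior:   # placeholder, not mindspore's NormalPrior
        pass

    class LaplacePrior:  # anything else should be rejected
        pass

    def resolve_prior(weight_prior_fn):
        if isinstance(weight_prior_fn, type):  # a class was passed: validate, then instantiate
            if weight_prior_fn.__name__ != 'NormalPrior':
                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
            return weight_prior_fn()
        # an instance was passed: validate its class name and use it as-is
        if weight_prior_fn.__class__.__name__ != 'NormalPrior':
            raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
        return weight_prior_fn

    resolve_prior(NormalPrior)    # class form: instantiated internally
    resolve_prior(NormalPrior())  # instance form: used directly
    try:
        resolve_prior(LaplacePrior())
    except TypeError as e:
        print(e)  # The type of `weight_prior_fn` should be `NormalPrior`

One design note: comparing class names rather than class identities means any class named `NormalPrior` passes, which keeps the check import-light at the cost of strictness.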
@@ -43,18 +43,38 @@ class _DenseVariational(Cell):
         self.has_bias = check_bool(has_bias)
 
         if isinstance(weight_prior_fn, Cell):
+            if weight_prior_fn.__class__.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn
         else:
+            if weight_prior_fn.__name__ != 'NormalPrior':
+                raise TypeError('The type of `weight_prior_fn` should be `NormalPrior`')
             self.weight_prior = weight_prior_fn()
+
+        if isinstance(weight_posterior_fn, Cell):
+            if weight_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
+        else:
+            if weight_posterior_fn.__name__ != 'NormalPosterior':
+                raise TypeError('The type of `weight_posterior_fn` should be `NormalPosterior`')
         self.weight_posterior = weight_posterior_fn(shape=[self.out_channels, self.in_channels], name='bnn_weight')
 
         if self.has_bias:
             if isinstance(bias_prior_fn, Cell):
+                if bias_prior_fn.__class__.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn
             else:
+                if bias_prior_fn.__name__ != 'NormalPrior':
+                    raise TypeError('The type of `bias_prior_fn` should be `NormalPrior`')
                 self.bias_prior = bias_prior_fn()
+
+            if isinstance(bias_posterior_fn, Cell):
+                if bias_posterior_fn.__class__.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
+            else:
+                if bias_posterior_fn.__name__ != 'NormalPosterior':
+                    raise TypeError('The type of `bias_posterior_fn` should be `NormalPosterior`')
             self.bias_posterior = bias_posterior_fn(shape=[self.out_channels], name='bnn_bias')
 
         self.activation = activation
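One detail visible in the dense variant: the weight posterior is built with `shape=[self.out_channels, self.in_channels]` while the bias posterior uses `[self.out_channels]`. A small numpy sketch of those sampling shapes, illustrative only and not MindSpore code:

    import numpy as np

    def sample_dense_params(in_channels, out_channels, seed=0):
        # Shapes follow the constructor calls above:
        # weight posterior [out_channels, in_channels], bias posterior [out_channels]
        rng = np.random.default_rng(seed)
        weight = rng.normal(loc=0.0, scale=0.1, size=(out_channels, in_channels))
        bias = rng.normal(loc=0.0, scale=0.1, size=(out_channels,))
        return weight, bias

    w, b = sample_dense_params(in_channels=3, out_channels=4)
    print(w.shape, b.shape)  # (4, 3) (4,)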
@@ -75,7 +75,18 @@ class NormalPosterior(Cell):
                  untransformed_scale_std=0.1):
         super(NormalPosterior, self).__init__()
         if not isinstance(name, str):
-            raise ValueError('The type of `name` should be `str`')
+            raise TypeError('The type of `name` should be `str`')
+        if not isinstance(shape, (tuple, list)):
+            raise TypeError('The type of `shape` should be `tuple` or `list`')
+        if not (np.array(shape) > 0).all():
+            raise ValueError('Negative dimensions are not allowed')
+        if not (np.array(loc_std) >= 0).all():
+            raise ValueError('The value of `loc_std` < 0')
+        if not (np.array(untransformed_scale_std) >= 0).all():
+            raise ValueError('The value of `untransformed_scale_std` < 0')
         self.mean = Parameter(
             Tensor(np.random.normal(loc_mean, loc_std, shape), dtype=dtype), name=name + '_mean')
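Besides adding the new guards, this hunk corrects the `name` check to raise `TypeError` rather than `ValueError`, matching the kind of error. A standalone sketch of the validation block, using the same numpy predicates as the hunk above (not the MindSpore class itself):

    import numpy as np

    def validate_posterior_args(name, shape, loc_std, untransformed_scale_std):
        if not isinstance(name, str):
            raise TypeError('The type of `name` should be `str`')
        if not isinstance(shape, (tuple, list)):
            raise TypeError('The type of `shape` should be `tuple` or `list`')
        if not (np.array(shape) > 0).all():
            raise ValueError('Negative dimensions are not allowed')
        if not (np.array(loc_std) >= 0).all():
            raise ValueError('The value of `loc_std` < 0')
        if not (np.array(untransformed_scale_std) >= 0).all():
            raise ValueError('The value of `untransformed_scale_std` < 0')

    validate_posterior_args('bnn_weight', [4, 3], 0.1, 0.1)  # passes silently
    try:
        validate_posterior_args('bnn_weight', [4, -3], 0.1, 0.1)
    except ValueError as e:
        print(e)  # Negative dimensions are not allowed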
@@ -15,7 +15,7 @@
 """
 Transforms.
 
-The high-level components used to transform model between DNN and DNN.
+The high-level components used to transform model between DNN and BNN.
 """
 from . import transform_bnn
 from .transform_bnn import TransformToBNN
@@ -54,3 +54,13 @@ class WithBNNLossCell(nn.Cell):
                 self.kl_loss.append(layer.compute_kl_loss)
             else:
                 self._add_kl_loss(layer)
+
+    @property
+    def backbone_network(self):
+        """
+        Returns the backbone network.
+
+        Returns:
+            Cell, the backbone network.
+        """
+        return self._backbone
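The new property gives a read-only accessor for the wrapped backbone, which `TransformToBNN` consumes below via `net_with_loss.backbone_network`. A toy version of the pattern:

    class ToyWithLossCell:  # hypothetical stand-in for WithBNNLossCell
        def __init__(self, backbone):
            self._backbone = backbone

        @property
        def backbone_network(self):
            """Returns the backbone network."""
            return self._backbone

    net = object()
    assert ToyWithLossCell(net).backbone_network is net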
@@ -61,6 +61,11 @@ class TransformToBNN:
     """
 
     def __init__(self, trainable_dnn, dnn_factor=1, bnn_factor=1):
+        if not isinstance(dnn_factor, (int, float)):
+            raise TypeError('The type of `dnn_factor` should be `int` or `float`')
+        if not isinstance(bnn_factor, (int, float)):
+            raise TypeError('The type of `bnn_factor` should be `int` or `float`')
+
         net_with_loss = trainable_dnn.network
         self.optimizer = trainable_dnn.optimizer
         self.backbone = net_with_loss.backbone_network
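Beyond repeating the factor checks, this constructor shows the contract `TransformToBNN` expects: `trainable_dnn` exposes `.network` and `.optimizer`, and the inner loss cell exposes the `backbone_network` property added above. A toy reproduction of that wiring (classes redefined here so the sketch runs standalone):

    class ToyWithLossCell:
        def __init__(self, backbone):
            self._backbone = backbone

        @property
        def backbone_network(self):
            return self._backbone

    class ToyTrainOneStep:  # stands in for the TrainOneStepCell wrapper
        def __init__(self, network, optimizer):
            self.network = network
            self.optimizer = optimizer

    backbone = object()
    trainable_dnn = ToyTrainOneStep(ToyWithLossCell(backbone), optimizer='sgd')
    # The same attribute chain the constructor above walks:
    assert trainable_dnn.network.backbone_network is backbone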
@@ -88,8 +93,10 @@ class TransformToBNN:
             get_conv_args (function): The arguments gotten from the DNN convolutional layer. Default: lambda dp:
                 {"in_channels": dp.in_channels, "out_channels": dp.out_channels, "pad_mode": dp.pad_mode,
                 "kernel_size": dp.kernel_size, "stride": dp.stride, "has_bias": dp.has_bias}.
-            add_dense_args (dict): The new arguments added to BNN full connection layer. Default: {}.
-            add_conv_args (dict): The new arguments added to BNN convolutional layer. Default: {}.
+            add_dense_args (dict): The new arguments added to BNN full connection layer. Note that the arguments in
+                `add_dense_args` should not duplicate arguments in `get_dense_args`. Default: {}.
+            add_conv_args (dict): The new arguments added to BNN convolutional layer. Note that the arguments in
+                `add_conv_args` should not duplicate arguments in `get_conv_args`. Default: {}.
 
         Returns:
             Cell, a trainable BNN model wrapped by TrainOneStepCell.
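The duplication caveat documented here (and for `add_args`/`get_args` in `transform_to_bnn_layer` below) is easy to trip over: the harvested arguments and the added arguments end up in a single layer constructor call, and Python rejects repeated keywords. A toy reproduction, assuming the merge happens via double-star expansion:

    def make_layer(**kwargs):  # stand-in for a BNN layer constructor
        return kwargs

    get_args = {'in_channels': 3, 'out_channels': 4}  # harvested from the DNN layer
    add_args = {'activation': 'relu'}                 # extra BNN-only arguments

    make_layer(**get_args, **add_args)  # fine: key sets are disjoint

    try:
        make_layer(**get_args, **{'in_channels': 8})  # duplicates `in_channels`
    except TypeError as e:
        print(e)  # make_layer() got multiple values for keyword argument 'in_channels'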
@@ -131,7 +138,8 @@ class TransformToBNN:
             bnn_layer_type (Cell): The type of BNN layer to be transformed to. The optional values are
                 DenseReparameterization, ConvReparameterization.
             get_args (dict): The arguments gotten from the DNN layer. Default: None.
-            add_args (dict): The new arguments added to BNN layer. Default: None.
+            add_args (dict): The new arguments added to BNN layer. Note that the arguments in `add_args` should not
+                duplicate arguments in `get_args`. Default: None.
 
         Returns:
             Cell, a trainable model wrapped by TrainOneStepCell, whose specific type of layer is transformed to the