diff --git a/mindspore/nn/probability/toolbox/__init__.py b/mindspore/nn/probability/toolbox/__init__.py
index 8391cd9185..867a00a724 100644
--- a/mindspore/nn/probability/toolbox/__init__.py
+++ b/mindspore/nn/probability/toolbox/__init__.py
@@ -17,5 +17,6 @@ Uncertainty toolbox.
 """
 from .uncertainty_evaluation import UncertaintyEvaluation
+from .anomaly_detection import VAEAnomalyDetection
 
-__all__ = ['UncertaintyEvaluation']
+__all__ = ['UncertaintyEvaluation', 'VAEAnomalyDetection']
diff --git a/mindspore/nn/probability/toolbox/anomaly_detection.py b/mindspore/nn/probability/toolbox/anomaly_detection.py
new file mode 100644
index 0000000000..4673bace70
--- /dev/null
+++ b/mindspore/nn/probability/toolbox/anomaly_detection.py
@@ -0,0 +1,91 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Toolbox for anomaly detection using a VAE."""
+import numpy as np
+
+from ..dpn import VAE
+from ..infer import ELBO, SVI
+from ...optim import Adam
+from ...wrap.cell_wrapper import WithLossCell
+
+
+class VAEAnomalyDetection:
+    r"""
+    Toolbox for anomaly detection using a VAE.
+
+    A Variational Auto-Encoder (VAE) can be used for unsupervised anomaly detection. The anomaly score is the
+    reconstruction error between a sample X and its reconstruction; a high score indicates X is likely an outlier.
+
+    Args:
+        encoder (Cell): The Deep Neural Network (DNN) model defined as the encoder.
+        decoder (Cell): The DNN model defined as the decoder.
+        hidden_size (int): The size of the encoder's output tensor. Default: 400.
+        latent_size (int): The size of the latent space. Default: 20.
+
+    """
+
+    def __init__(self, encoder, decoder, hidden_size=400, latent_size=20):
+        self.vae = VAE(encoder, decoder, hidden_size, latent_size)
+
+    def train(self, train_dataset, epochs=5):
+        """
+        Train the VAE model.
+
+        Args:
+            train_dataset (Dataset): A dataset iterator to train the model.
+            epochs (int): Total number of iterations on the data. Default: 5.
+
+        Returns:
+            Cell, the trained model.
+        """
+        net_loss = ELBO()
+        optimizer = Adam(params=self.vae.trainable_params(), learning_rate=0.001)
+        net_with_loss = WithLossCell(self.vae, net_loss)
+        vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer)
+        self.vae = vi.run(train_dataset, epochs)
+        return self.vae
+
+    def predict_outlier_score(self, sample_x):
+        """
+        Predict the outlier score of the sample.
+
+        Args:
+            sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W).
+
+        Returns:
+            float, the predicted outlier score of the sample.
+        """
+        reconstructed_sample = self.vae.reconstruct_sample(sample_x)
+        return self._calculate_euclidean_distance(sample_x.asnumpy(), reconstructed_sample.asnumpy())
+
+    def predict_outlier(self, sample_x, threshold=100.0):
+        """
+        Predict whether the sample is an outlier.
+
+        Args:
+            sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W).
+            threshold (float): The score threshold above which the sample is judged an outlier. Default: 100.0.
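For context, here is a minimal usage sketch of the new class. The `Encoder`/`Decoder` cells and the single-channel 32x32 input shape are illustrative assumptions modeled on the MNIST-style VAE examples elsewhere in `mindspore.nn.probability`; they are not part of this patch.

```python
# Hypothetical encoder/decoder pair for (N, 1, 32, 32) inputs,
# sized to match hidden_size=400 and latent_size=20.
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.probability.toolbox import VAEAnomalyDetection


class Encoder(nn.Cell):
    """Maps a (N, 1, 32, 32) image to a 400-dim hidden vector."""
    def __init__(self):
        super(Encoder, self).__init__()
        self.flatten = nn.Flatten()
        self.fc = nn.Dense(1024, 400)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.fc(self.flatten(x)))


class Decoder(nn.Cell):
    """Maps the 400-dim hidden vector back to a (N, 1, 32, 32) image."""
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Dense(400, 1024)
        self.sigmoid = nn.Sigmoid()
        self.reshape = P.Reshape()

    def construct(self, z):
        return self.reshape(self.sigmoid(self.fc(z)), (-1, 1, 32, 32))


# `train_dataset` is assumed to be a mindspore Dataset yielding image batches.
tool = VAEAnomalyDetection(Encoder(), Decoder(), hidden_size=400, latent_size=20)
tool.train(train_dataset, epochs=5)
score = tool.predict_outlier_score(sample_x)                  # Euclidean reconstruction error
is_outlier = tool.predict_outlier(sample_x, threshold=100.0)  # score >= threshold
```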
+
+        Returns:
+            bool, whether the sample is an outlier.
+        """
+        score = self.predict_outlier_score(sample_x)
+        return score >= threshold
+
+    def _calculate_euclidean_distance(self, sample_x, reconstructed_sample):
+        """
+        Calculate the Euclidean distance between sample_x and reconstructed_sample.
+        """
+        return np.sqrt(np.sum(np.square(sample_x - reconstructed_sample)))
diff --git a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
index 20a75b38eb..113c081eee 100644
--- a/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
+++ b/mindspore/nn/probability/toolbox/uncertainty_evaluation.py
@@ -47,7 +47,6 @@ class UncertaintyEvaluation:
             Default: None.
         epochs (int): Total number of iterations on the data. Default: 1.
         epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None.
-            If the epi_uncer_model_path is 'Untrain', the epistemic model need not to be trained.
         ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None.
         save_model (bool): Whether to save the uncertainty model or not, if true, the epi_uncer_model_path and
             ale_uncer_model_path must not be None. If false, the model to evaluate will be loaded from
@@ -82,7 +81,7 @@ class UncertaintyEvaluation:
         self.epi_model = model
         self.ale_model = deepcopy(model)
         self.epi_train_dataset = train_dataset
-        self.ale_train_dataset = deepcopy(train_dataset)
+        self.ale_train_dataset = train_dataset
         self.task_type = task_type
         self.epochs = Validator.check_positive_int(epochs)
         self.epi_uncer_model_path = epi_uncer_model_path
@@ -112,7 +111,7 @@ class UncertaintyEvaluation:
         """
         if self.epi_uncer_model is None:
             self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model)
-        if self.epi_uncer_model.drop_count == 0 and self.epi_uncer_model_path != 'Untrain':
+        if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None:
             if self.task_type == 'classification':
                 net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
                 net_opt = Adam(self.epi_uncer_model.trainable_params())
@@ -156,6 +155,8 @@ class UncertaintyEvaluation:
         """
         Get the model which can obtain the aleatoric uncertainty.
         """
+        if self.ale_train_dataset is None:
+            raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.')
         if self.ale_uncer_model is None:
             self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type)
             net_loss = AleatoricLoss(self.task_type)
@@ -239,17 +240,17 @@ class EpistemicUncertaintyModel(Cell):
         The dropout rate is set to 0.5 by default.
         """
         for (name, layer) in epi_model.name_cells().items():
-            if isinstance(layer, Dropout):
-                self.drop_count += 1
-                return epi_model
-        for (name, layer) in epi_model.name_cells().items():
-            if isinstance(layer, (Conv2d, Dense)):
+            if isinstance(layer, (Conv2d, Dense, Dropout)):
+                if isinstance(layer, Dropout):
+                    self.drop_count += 1
+                    return epi_model
                 uncertainty_layer = layer
                 uncertainty_name = name
                 drop = Dropout(keep_prob=dropout_rate)
                 bnn_drop = SequentialCell([uncertainty_layer, drop])
                 setattr(epi_model, uncertainty_name, bnn_drop)
                 return epi_model
+            self._make_epistemic(layer)
         raise ValueError("The model has not Dense Layer or Convolution Layer, "
                          "it can not evaluate epistemic uncertainty so far.")
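The net effect on the caller's contract is that `train_dataset=None` replaces the removed `'Untrain'` sentinel, and aleatoric evaluation fails fast when no data is available. A sketch of both paths; the `net`, `eval_images`, and the checkpoint path are illustrative assumptions, not part of the patch:

```python
from mindspore.nn.probability.toolbox import UncertaintyEvaluation

# With train_dataset=None the epistemic training branch is skipped, so a
# previously saved epistemic model can be loaded as-is (this replaces the
# old epi_uncer_model_path='Untrain' sentinel).
evaluation = UncertaintyEvaluation(model=net,
                                   train_dataset=None,
                                   task_type='classification',
                                   num_classes=10,
                                   epi_uncer_model_path='./epi_uncer.ckpt')  # hypothetical path
epistemic = evaluation.eval_epistemic_uncertainty(eval_images)

# Aleatoric uncertainty still requires training data to fit the aleatoric
# head, so with train_dataset=None this now raises ValueError up front
# instead of failing inside the training loop.
aleatoric = evaluation.eval_aleatoric_uncertainty(eval_images)
```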
""" for (name, layer) in epi_model.name_cells().items(): - if isinstance(layer, Dropout): - self.drop_count += 1 - return epi_model - for (name, layer) in epi_model.name_cells().items(): - if isinstance(layer, (Conv2d, Dense)): + if isinstance(layer, (Conv2d, Dense, Dropout)): + if isinstance(layer, Dropout): + self.drop_count += 1 + return epi_model uncertainty_layer = layer uncertainty_name = name drop = Dropout(keep_prob=dropout_rate) bnn_drop = SequentialCell([uncertainty_layer, drop]) setattr(epi_model, uncertainty_name, bnn_drop) return epi_model + self._make_epistemic(layer) raise ValueError("The model has not Dense Layer or Convolution Layer, " "it can not evaluate epistemic uncertainty so far.") diff --git a/mindspore/nn/probability/transforms/transform_bnn.py b/mindspore/nn/probability/transforms/transform_bnn.py index 67c7659174..7bb15088c6 100644 --- a/mindspore/nn/probability/transforms/transform_bnn.py +++ b/mindspore/nn/probability/transforms/transform_bnn.py @@ -36,21 +36,21 @@ class TransformToBNN: Examples: >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') - >>> self.bn = nn.BatchNorm2d(64) - >>> self.relu = nn.ReLU() - >>> self.flatten = nn.Flatten() - >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 - >>> - >>> def construct(self, x): - >>> x = self.conv(x) - >>> x = self.bn(x) - >>> x = self.relu(x) - >>> x = self.flatten(x) - >>> out = self.fc(x) - >>> return out + ... def __init__(self): + ... super(Net, self).__init__() + ... self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') + ... self.bn = nn.BatchNorm2d(64) + ... self.relu = nn.ReLU() + ... self.flatten = nn.Flatten() + ... self.fc = nn.Dense(64*224*224, 12) # padding=0 + ... + ... def construct(self, x): + ... x = self.conv(x) + ... x = self.bn(x) + ... x = self.relu(x) + ... x = self.flatten(x) + ... out = self.fc(x) + ... return out >>> >>> net = Net() >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True)