From: @zhangxinfeng3 Reviewed-by: @sunnybeike,@wang_zi_dong Signed-off-by: @sunnybeiketags/v1.1.0
| @@ -17,5 +17,6 @@ Uncertainty toolbox. | |||
| """ | |||
| from .uncertainty_evaluation import UncertaintyEvaluation | |||
| from .anomaly_detection import VAEAnomalyDetection | |||
| __all__ = ['UncertaintyEvaluation'] | |||
| __all__ = ['UncertaintyEvaluation', 'VAEAnomalyDetection'] | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Toolbox for anomaly detection by using VAE.""" | |||
| import numpy as np | |||
| from ..dpn import VAE | |||
| from ..infer import ELBO, SVI | |||
| from ...optim import Adam | |||
| from ...wrap.cell_wrapper import WithLossCell | |||
| class VAEAnomalyDetection: | |||
| r""" | |||
| Toolbox for anomaly detection by using VAE. | |||
| Variational Auto-Encoder(VAE) can be used for Unsupervised Anomaly Detection. The anomaly score is the error | |||
| between the X and the reconstruction. If the score is high, the X is mostly outlier. | |||
| Args: | |||
| encoder(Cell): The Deep Neural Network (DNN) model defined as encoder. | |||
| decoder(Cell): The DNN model defined as decoder. | |||
| hidden_size(int): The size of encoder's output tensor. | |||
| latent_size(int): The size of the latent space. | |||
| """ | |||
| def __init__(self, encoder, decoder, hidden_size=400, latent_size=20): | |||
| self.vae = VAE(encoder, decoder, hidden_size, latent_size) | |||
| def train(self, train_dataset, epochs=5): | |||
| """ | |||
| Train the VAE model. | |||
| Args: | |||
| train_dataset (Dataset): A dataset iterator to train model. | |||
| epochs (int): Total number of iterations on the data. Default: 5. | |||
| Returns: | |||
| Cell, the trained model. | |||
| """ | |||
| net_loss = ELBO() | |||
| optimizer = Adam(params=self.vae.trainable_params(), learning_rate=0.001) | |||
| net_with_loss = WithLossCell(self.vae, net_loss) | |||
| vi = SVI(net_with_loss=net_with_loss, optimizer=optimizer) | |||
| self.vae = vi.run(train_dataset, epochs) | |||
| return self.vae | |||
| def predict_outlier_score(self, sample_x): | |||
| """ | |||
| Predict the outlier score. | |||
| Args: | |||
| sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W). | |||
| Returns: | |||
| numpy.dtype, the predicted outlier score of the sample. | |||
| """ | |||
| reconstructed_sample = self.vae.reconstruct_sample(sample_x) | |||
| return self._calculate_euclidean_distance(sample_x.asnumpy(), reconstructed_sample.asnumpy()) | |||
| def predict_outlier(self, sample_x, threshold=100.0): | |||
| """ | |||
| Predict whether the sample is an outlier. | |||
| Args: | |||
| sample_x (Tensor): The sample to be predicted, the shape is (N, C, H, W). | |||
| threshold (float): the threshold of the outlier. Default: 100.0. | |||
| Returns: | |||
| Bool, whether the sample is an outlier. | |||
| """ | |||
| score = self.predict_outlier_score(sample_x) | |||
| return score >= threshold | |||
| def _calculate_euclidean_distance(self, sample_x, reconstructed_sample): | |||
| """ | |||
| Calculate the euclidean distance of the sample_x and reconstructed_sample. | |||
| """ | |||
| return np.sqrt(np.sum(np.square(sample_x - reconstructed_sample))) | |||
| @@ -47,7 +47,6 @@ class UncertaintyEvaluation: | |||
| Default: None. | |||
| epochs (int): Total number of iterations on the data. Default: 1. | |||
| epi_uncer_model_path (str): The save or read path of the epistemic uncertainty model. Default: None. | |||
| If the epi_uncer_model_path is 'Untrain', the epistemic model need not to be trained. | |||
| ale_uncer_model_path (str): The save or read path of the aleatoric uncertainty model. Default: None. | |||
| save_model (bool): Whether to save the uncertainty model or not, if true, the epi_uncer_model_path | |||
| and ale_uncer_model_path must not be None. If false, the model to evaluate will be loaded from | |||
| @@ -82,7 +81,7 @@ class UncertaintyEvaluation: | |||
| self.epi_model = model | |||
| self.ale_model = deepcopy(model) | |||
| self.epi_train_dataset = train_dataset | |||
| self.ale_train_dataset = deepcopy(train_dataset) | |||
| self.ale_train_dataset = train_dataset | |||
| self.task_type = task_type | |||
| self.epochs = Validator.check_positive_int(epochs) | |||
| self.epi_uncer_model_path = epi_uncer_model_path | |||
| @@ -112,7 +111,7 @@ class UncertaintyEvaluation: | |||
| """ | |||
| if self.epi_uncer_model is None: | |||
| self.epi_uncer_model = EpistemicUncertaintyModel(self.epi_model) | |||
| if self.epi_uncer_model.drop_count == 0 and self.epi_uncer_model_path != 'Untrain': | |||
| if self.epi_uncer_model.drop_count == 0 and self.epi_train_dataset is not None: | |||
| if self.task_type == 'classification': | |||
| net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||
| net_opt = Adam(self.epi_uncer_model.trainable_params()) | |||
| @@ -156,6 +155,8 @@ class UncertaintyEvaluation: | |||
| """ | |||
| Get the model which can obtain the aleatoric uncertainty. | |||
| """ | |||
| if self.ale_train_dataset is None: | |||
| raise ValueError('The train dataset should not be None when evaluating aleatoric uncertainty.') | |||
| if self.ale_uncer_model is None: | |||
| self.ale_uncer_model = AleatoricUncertaintyModel(self.ale_model, self.num_classes, self.task_type) | |||
| net_loss = AleatoricLoss(self.task_type) | |||
| @@ -239,17 +240,17 @@ class EpistemicUncertaintyModel(Cell): | |||
| The dropout rate is set to 0.5 by default. | |||
| """ | |||
| for (name, layer) in epi_model.name_cells().items(): | |||
| if isinstance(layer, Dropout): | |||
| self.drop_count += 1 | |||
| return epi_model | |||
| for (name, layer) in epi_model.name_cells().items(): | |||
| if isinstance(layer, (Conv2d, Dense)): | |||
| if isinstance(layer, (Conv2d, Dense, Dropout)): | |||
| if isinstance(layer, Dropout): | |||
| self.drop_count += 1 | |||
| return epi_model | |||
| uncertainty_layer = layer | |||
| uncertainty_name = name | |||
| drop = Dropout(keep_prob=dropout_rate) | |||
| bnn_drop = SequentialCell([uncertainty_layer, drop]) | |||
| setattr(epi_model, uncertainty_name, bnn_drop) | |||
| return epi_model | |||
| self._make_epistemic(layer) | |||
| raise ValueError("The model has not Dense Layer or Convolution Layer, " | |||
| "it can not evaluate epistemic uncertainty so far.") | |||
| @@ -36,21 +36,21 @@ class TransformToBNN: | |||
| Examples: | |||
| >>> class Net(nn.Cell): | |||
| >>> def __init__(self): | |||
| >>> super(Net, self).__init__() | |||
| >>> self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') | |||
| >>> self.bn = nn.BatchNorm2d(64) | |||
| >>> self.relu = nn.ReLU() | |||
| >>> self.flatten = nn.Flatten() | |||
| >>> self.fc = nn.Dense(64*224*224, 12) # padding=0 | |||
| >>> | |||
| >>> def construct(self, x): | |||
| >>> x = self.conv(x) | |||
| >>> x = self.bn(x) | |||
| >>> x = self.relu(x) | |||
| >>> x = self.flatten(x) | |||
| >>> out = self.fc(x) | |||
| >>> return out | |||
| ... def __init__(self): | |||
| ... super(Net, self).__init__() | |||
| ... self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal') | |||
| ... self.bn = nn.BatchNorm2d(64) | |||
| ... self.relu = nn.ReLU() | |||
| ... self.flatten = nn.Flatten() | |||
| ... self.fc = nn.Dense(64*224*224, 12) # padding=0 | |||
| ... | |||
| ... def construct(self, x): | |||
| ... x = self.conv(x) | |||
| ... x = self.bn(x) | |||
| ... x = self.relu(x) | |||
| ... x = self.flatten(x) | |||
| ... out = self.fc(x) | |||
| ... return out | |||
| >>> | |||
| >>> net = Net() | |||
| >>> criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | |||