Merge pull request !13 from ZhidanLiu/master
tags/v0.2.0-alpha
@@ -0,0 +1,89 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import numpy as np

from mindspore import Model
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn import SoftmaxCrossEntropyWithLogits

from mindarmour.attacks.gradient_method import FastGradientSignMethod
from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.fuzzing.fuzzing import Fuzzing

from lenet5_net import LeNet5

sys.path.append("..")
from data_processing import generate_mnist_dataset

LOGGER = LogUtil.get_instance()
TAG = 'Fuzz_test'
LOGGER.set_level('INFO')


def test_lenet_mnist_fuzzing():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # load the trained network
    ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt'
    net = LeNet5()
    load_dict = load_checkpoint(ckpt_name)
    load_param_into_net(net, load_dict)
    model = Model(net)

    # get training data
    data_list = "./MNIST_datasets/train"
    batch_size = 32
    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
    train_images = []
    for data in ds.create_tuple_iterator():
        images = data[0].astype(np.float32)
        train_images.append(images)
    train_images = np.concatenate(train_images, axis=0)

    # initialize the fuzz test with the training dataset
    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)

    # fuzz test with original test data
    # get test data
    data_list = "./MNIST_datasets/test"
    batch_size = 32
    ds = generate_mnist_dataset(data_list, batch_size, sparse=True)
    test_images = []
    test_labels = []
    for data in ds.create_tuple_iterator():
        images = data[0].astype(np.float32)
        labels = data[1]
        test_images.append(images)
        test_labels.append(labels)
    test_images = np.concatenate(test_images, axis=0)
    test_labels = np.concatenate(test_labels, axis=0)

    # make initial seeds in the format [image, label, 0]
    initial_seeds = []
    for img, label in zip(test_images, test_labels):
        initial_seeds.append([img, label, 0])
    initial_seeds = initial_seeds[:100]

    # coverage of the original test data
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(test_images[:100]).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the original test data is : %s',
                model_coverage_test.get_kmnc())

    # coverage of the mutated tests that the model mis-predicts
    model_fuzz_test = Fuzzing(initial_seeds, model, train_images, 20)
    failed_tests = model_fuzz_test.fuzzing()
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(failed_tests).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the failed tests is : %s',
                model_coverage_test.get_kmnc())


if __name__ == '__main__':
    test_lenet_mnist_fuzzing()
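One detail worth noting (an observation, not a change made in the PR): `fuzzing()` compares `np.argmax(seed[1])` against the model prediction, so seeds carrying one-hot labels, as the unit tests further down do, make that comparison explicit. A one-hot variant of the seeds above could be built like this, with `num_classes` assumed to be 10 for MNIST:

    # hypothetical variant, not part of the PR
    num_classes = 10
    one_hot_labels = np.eye(num_classes)[test_labels].astype(np.float32)
    initial_seeds = [[img, label, 0] for img, label in zip(test_images, one_hot_labels)]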
@@ -1,3 +1,8 @@
"""
This module includes various metrics for fuzz testing of DNNs.
"""
from .fuzzing import Fuzzing
from .model_coverage_metrics import ModelCoverageMetrics

__all__ = ['ModelCoverageMetrics']
__all__ = ['Fuzzing',
           'ModelCoverageMetrics']
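With `Fuzzing` added to `__all__`, both classes are re-exported by the package, so downstream code can import them directly; a minimal sketch of the intended import path:

    from mindarmour.fuzzing import Fuzzing, ModelCoverageMetrics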
@@ -0,0 +1,169 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fuzzing.
"""
from random import choice

import numpy as np

from mindspore import Tensor
from mindspore import Model

from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
    Translate, Scale, Shear, Rotate
from mindarmour.utils._check_param import check_model, check_numpy_param, \
    check_int_positive


class Fuzzing:
    """
    Fuzzing test framework for deep neural networks.

    Reference: `DeepHunter: A Coverage-Guided Fuzz Testing Framework for Deep
    Neural Networks <https://dl.acm.org/doi/10.1145/3293882.3330579>`_

    Args:
        initial_seeds (list): Initial fuzzing seeds, format: [[image, label, 0],
            [image, label, 0], ...].
        target_model (Model): Target fuzz model.
        train_dataset (numpy.ndarray): Training dataset used to determine
            the neurons' output boundaries.
        const_K (int): The number of mutated tests generated from each seed.
        mode (str): Image mode used in image transform, 'L' means grey image.
            Default: 'L'.
        max_seed_num (int): The maximum number of seeds to be fuzzed.
            Default: 1000.
    """

    def __init__(self, initial_seeds, target_model, train_dataset, const_K,
                 mode='L', max_seed_num=1000):
        self.initial_seeds = initial_seeds
        self.target_model = check_model('model', target_model, Model)
        self.train_dataset = check_numpy_param('train_dataset', train_dataset)
        self.K = check_int_positive('const_k', const_K)
        self.mode = mode
        self.max_seed_num = check_int_positive('max_seed_num', max_seed_num)
        self.coverage_metrics = ModelCoverageMetrics(target_model, 1000, 10,
                                                     train_dataset)

    def _image_value_expand(self, image):
        return image*255

    def _image_value_compress(self, image):
        return image / 255

    def _metamorphic_mutate(self, seed, try_num=50):
        if self.mode == 'L':
            seed = seed[0]
        # info[0] is the original seed; info[1] is updated once a transform
        # from affine_trans has produced a valid mutation, which restricts
        # the transform types allowed afterwards.
        info = [seed, seed]
        mutate_tests = []
        affine_trans = ['Contrast', 'Brightness', 'Blur', 'Noise']
        pixel_value_trans = ['Translate', 'Scale', 'Shear', 'Rotate']
        strategies = {'Contrast': Contrast, 'Brightness': Brightness,
                      'Blur': Blur, 'Noise': Noise,
                      'Translate': Translate, 'Scale': Scale, 'Shear': Shear,
                      'Rotate': Rotate}
        for _ in range(self.K):
            for _ in range(try_num):
                if (info[0] == info[1]).all():
                    trans_strategy = self._random_pick_mutate(
                        affine_trans, pixel_value_trans)
                else:
                    trans_strategy = self._random_pick_mutate(affine_trans, [])
                transform = strategies[trans_strategy](
                    self._image_value_expand(seed), self.mode)
                transform.random_param()
                mutate_test = transform.transform()
                mutate_test = np.expand_dims(
                    self._image_value_compress(mutate_test), 0)
                if self._is_trans_valid(seed, mutate_test):
                    if trans_strategy in affine_trans:
                        info[1] = mutate_test
                    mutate_tests.append(mutate_test)
        if len(mutate_tests) == 0:
            mutate_tests.append(seed)
        return np.array(mutate_tests)

    def fuzzing(self, coverage_metric='KMNC'):
        """
        Fuzzing tests for deep neural networks.

        Args:
            coverage_metric (str): Model coverage metric of neural networks.
                Default: 'KMNC'.

        Returns:
            list, mutated tests mis-predicted by the target DNN model.
        """
        seed = self._select_next()
        failed_tests = []
        seed_num = 0
        while len(seed) > 0 and seed_num < self.max_seed_num:
            mutate_tests = self._metamorphic_mutate(seed[0])
            coverages, results = self._run(mutate_tests, coverage_metric)
            coverage_gains = self._coverage_gains(coverages)
            for mutate, cov, res in zip(mutate_tests, coverage_gains, results):
                # a mis-predicted mutation is reported as a failed test
                if np.argmax(seed[1]) != np.argmax(res):
                    failed_tests.append(mutate)
                    continue
                # a mutation that increases coverage becomes a new seed
                if cov > 0:
                    self.initial_seeds.append([mutate, seed[1], 0])
            seed = self._select_next()
            seed_num += 1
        return failed_tests

    def _coverage_gains(self, coverages):
        gains = [0] + coverages[:-1]
        gains = np.array(coverages) - np.array(gains)
        return gains

    def _run(self, mutate_tests, coverage_metric="KMNC"):
        coverages = []
        result = self.target_model.predict(
            Tensor(mutate_tests.astype(np.float32)))
        result = result.asnumpy()
        for index in range(len(mutate_tests)):
            mutate = np.expand_dims(mutate_tests[index], 0)
            self.coverage_metrics.test_adequacy_coverage_calculate(
                mutate.astype(np.float32), batch_size=1)
            if coverage_metric == "KMNC":
                coverages.append(self.coverage_metrics.get_kmnc())
        return coverages, result

    def _select_next(self):
        seed = choice(self.initial_seeds)
        return seed

    def _random_pick_mutate(self, affine_trans_list, pixel_value_trans_list):
        strategy = choice(affine_trans_list + pixel_value_trans_list)
        return strategy

    def _is_trans_valid(self, seed, mutate_test):
        # Validity rule for a mutation: if more than alpha of the pixels
        # changed, the largest per-pixel change must stay below 256;
        # otherwise it must stay below beta*255.
        is_valid = False
        alpha = 0.02
        beta = 0.2
        diff = np.array(seed - mutate_test).flatten()
        size = np.shape(diff)[0]
        L0 = np.linalg.norm(diff, ord=0)
        Linf = np.linalg.norm(diff, ord=np.inf)
        if L0 > alpha*size:
            if Linf < 256:
                is_valid = True
        else:
            if Linf < beta*255:
                is_valid = True
        return is_valid
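For readers unfamiliar with the validity check above, here is a small self-contained sketch of the same L0/Linf rule applied to a toy perturbation (the values are illustrative, not taken from the PR):

    import numpy as np

    seed = np.zeros((28, 28))
    mutate = seed.copy()
    mutate[:2, :] = 0.1                      # 56 of 784 pixels changed by 0.1

    diff = (seed - mutate).flatten()
    alpha, beta = 0.02, 0.2
    l0 = np.linalg.norm(diff, ord=0)         # number of changed pixels (56)
    linf = np.linalg.norm(diff, ord=np.inf)  # largest single change (0.1)
    if l0 > alpha * diff.size:               # more than 2% of pixels changed ...
        valid = linf < 256                   # ... so only the loose cap applies
    else:
        valid = linf < beta * 255            # few pixels changed: each change must be small
    print(valid)                             # True for this toy example

Note that `Fuzzing` compresses images back to roughly [0, 1] before this check (assuming the inputs are scaled to [0, 1], as in the MNIST example), so the `Linf < 256` branch is effectively always satisfied once enough pixels change.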
@@ -76,6 +76,11 @@ class ModelCoverageMetrics:
            upper_compare_array = np.concatenate(
                [output, np.array([self._upper_bounds])], axis=0)
            self._upper_bounds = np.max(upper_compare_array, axis=0)
        if batches == 0:
            # the dataset is smaller than one batch, so the loop above never
            # ran; compute the neuron output bounds from the whole dataset
            output = self._model.predict(Tensor(train_dataset)).asnumpy()
            self._lower_bounds = np.min(output, axis=0)
            self._upper_bounds = np.max(output, axis=0)
            output_mat.append(output)
        self._var = np.std(np.concatenate(np.array(output_mat), axis=0),
                           axis=0)
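The `get_kmnc` value used throughout this PR is k-multisection neuron coverage (KMNC, from the DeepGauge line of work). The following is a hypothetical, simplified sketch of the idea only, not the library's implementation, assuming per-neuron bounds like the ones computed above:

    import numpy as np

    def kmnc_sketch(outputs, lower, upper, k=1000):
        """outputs: (n_samples, n_neurons); lower/upper: per-neuron bounds from training data."""
        n_neurons = outputs.shape[1]
        hits = np.zeros((n_neurons, k), dtype=bool)
        section = np.maximum((upper - lower) / k, 1e-12)
        for sample in outputs:
            idx = np.floor((sample - lower) / section).astype(int)
            in_range = (idx >= 0) & (idx < k)
            hits[np.where(in_range)[0], idx[in_range]] = True
        # fraction of all k * n_neurons sections hit by at least one sample
        return hits.sum() / hits.size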
@@ -0,0 +1,267 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform.
"""
import random

import numpy as np
from PIL import Image, ImageEnhance, ImageFilter

from mindarmour.utils._check_param import check_numpy_param


class ImageTransform:
    """
    The abstract base class for all image transform classes.
    """

    def __init__(self):
        pass

    def random_param(self):
        pass

    def transform(self):
        pass


class Contrast(ImageTransform):
    """
    Adjust the contrast of an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Contrast, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.factor = random.uniform(-10, 10)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        img_contrast = ImageEnhance.Contrast(img)
        trans_image = img_contrast.enhance(self.factor)
        trans_image = np.array(trans_image)
        return trans_image


class Brightness(ImageTransform):
    """
    Adjust the brightness of an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Brightness, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.factor = random.uniform(-10, 10)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        img_brightness = ImageEnhance.Brightness(img)
        trans_image = img_brightness.enhance(self.factor)
        trans_image = np.array(trans_image)
        return trans_image


class Blur(ImageTransform):
    """
    Gaussian blur of an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Blur, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.radius = random.uniform(-10, 10)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        trans_image = img.filter(ImageFilter.GaussianBlur(radius=self.radius))
        trans_image = np.array(trans_image)
        return trans_image


class Noise(ImageTransform):
    """
    Add noise to an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Noise, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.factor = random.uniform(-1, 1)

    def transform(self):
        """ Transform the image. """
        noise = np.random.uniform(low=-1, high=1, size=self.image.shape)
        trans_image = np.copy(self.image)
        trans_image[noise < -self.factor] = 0
        trans_image[noise > self.factor] = 255
        trans_image = np.array(trans_image)
        return trans_image


class Translate(ImageTransform):
    """
    Translate an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Translate, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        image_shape = np.shape(self.image)
        self.x_bias = random.uniform(0, image_shape[0])
        self.y_bias = random.uniform(0, image_shape[1])

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        trans_image = img.transform(img.size, Image.AFFINE,
                                    (1, 0, self.x_bias, 0, 1, self.y_bias))
        trans_image = np.array(trans_image)
        return trans_image


class Scale(ImageTransform):
    """
    Scale an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Scale, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.factor_x = random.uniform(0, 1)
        self.factor_y = random.uniform(0, 1)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        trans_image = img.transform(img.size, Image.AFFINE,
                                    (self.factor_x, 0, 0, 0, self.factor_y, 0))
        trans_image = np.array(trans_image)
        return trans_image


class Shear(ImageTransform):
    """
    Shear an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Shear, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.factor = random.uniform(0, 1)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        if np.random.random() > 0.5:
            level = -self.factor
        else:
            level = self.factor
        # shear either horizontally or vertically, chosen at random
        if np.random.random() > 0.5:
            trans_image = img.transform(img.size, Image.AFFINE,
                                        (1, level, 0, 0, 1, 0))
        else:
            trans_image = img.transform(img.size, Image.AFFINE,
                                        (1, 0, 0, level, 1, 0))
        trans_image = np.array(trans_image, dtype=np.float)
        return trans_image


class Rotate(ImageTransform):
    """
    Rotate an image.

    Args:
        image (numpy.ndarray): Original image to be transformed.
        mode (str): Mode used in PIL, here mode must be in ['L', 'RGB'],
            'L' means grey image.
    """

    def __init__(self, image, mode):
        super(Rotate, self).__init__()
        self.image = check_numpy_param('image', image)
        self.mode = mode

    def random_param(self):
        """ Randomly generate parameters. """
        self.angle = random.uniform(0, 360)

    def transform(self):
        """ Transform the image. """
        img = Image.fromarray(self.image, self.mode)
        trans_image = img.rotate(self.angle)
        trans_image = np.array(trans_image)
        return trans_image
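For reference (a fact about PIL, not something introduced by this PR): `Image.transform(size, Image.AFFINE, data)` treats `data` as the 6-tuple (a, b, c, d, e, f) and fills each output pixel (x, y) from the input pixel (a*x + b*y + c, d*x + e*y + f), which is how the Translate, Scale and Shear classes above build their transforms. A minimal sketch of the translate case, with illustrative bias values:

    import numpy as np
    from PIL import Image

    img = Image.fromarray(np.zeros((32, 32), dtype=np.uint8), 'L')
    x_bias, y_bias = 5, 3   # illustrative; the PR draws these with random.uniform(...)
    # output (x, y) samples input (x + x_bias, y + y_bias), i.e. the content shifts
    shifted = img.transform(img.size, Image.AFFINE, (1, 0, x_bias, 0, 1, y_bias))
    shifted_array = np.array(shifted)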
@@ -0,0 +1,161 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model-fuzz coverage test.
"""
import sys

import numpy as np
import pytest

from mindspore.train import Model
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.initializer import TruncatedNormal

from mindarmour.utils.logger import LogUtil
from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
from mindarmour.fuzzing.fuzzing import Fuzzing

LOGGER = LogUtil.get_instance()
TAG = 'Fuzzing test'
LOGGER.set_level('INFO')


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    return TruncatedNormal(0.02)


class Net(nn.Cell):
    """
    LeNet network.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = conv(1, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16*5*5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, 10)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.reshape(x, (-1, 16*5*5))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    # load network
    net = Net()
    model = Model(net)
    batch_size = 8
    num_classes = 10

    # initialize fuzz test with training dataset
    training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)

    # fuzz test with original test data
    # get test data
    test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
    test_labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
    test_labels = (np.eye(num_classes)[test_labels]).astype(np.float32)
    initial_seeds = []
    for img, label in zip(test_data, test_labels):
        initial_seeds.append([img, label, 0])
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(test_data).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the original test data is : %s',
                model_coverage_test.get_kmnc())

    model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
                              max_seed_num=10)
    failed_tests = model_fuzz_test.fuzzing()
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(failed_tests).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the failed tests is : %s',
                model_coverage_test.get_kmnc())


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_fuzzing_cpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    # load network
    net = Net()
    model = Model(net)
    batch_size = 8
    num_classes = 10

    # initialize fuzz test with training dataset
    training_data = np.random.rand(32, 1, 32, 32).astype(np.float32)
    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, training_data)

    # fuzz test with original test data
    # get test data
    test_data = np.random.rand(batch_size, 1, 32, 32).astype(np.float32)
    test_labels = np.random.randint(num_classes, size=batch_size).astype(np.int32)
    test_labels = (np.eye(num_classes)[test_labels]).astype(np.float32)
    initial_seeds = []
    for img, label in zip(test_data, test_labels):
        initial_seeds.append([img, label, 0])
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(test_data).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the original test data is : %s',
                model_coverage_test.get_kmnc())

    model_fuzz_test = Fuzzing(initial_seeds, model, training_data, 5,
                              max_seed_num=10)
    failed_tests = model_fuzz_test.fuzzing()
    model_coverage_test.test_adequacy_coverage_calculate(
        np.array(failed_tests).astype(np.float32))
    LOGGER.info(TAG, 'KMNC of the failed tests is : %s',
                model_coverage_test.get_kmnc())
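As a usage note (the command and file path are illustrative, not taken from the PR), these cases are selected by their custom marker when run under pytest, for example:

    pytest -m component_mindarmour test_fuzzing.py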
@@ -0,0 +1,136 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Image transform test.
"""
import numpy as np
import pytest

from mindarmour.utils.logger import LogUtil
from mindarmour.utils.image_transform import Contrast, Brightness, Blur, Noise, \
    Translate, Scale, Shear, Rotate

LOGGER = LogUtil.get_instance()
TAG = 'Image transform test'
LOGGER.set_level('INFO')


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_contrast():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Contrast(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_brightness():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Brightness(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_blur():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Blur(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_noise():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Noise(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_translate():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Translate(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_shear():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Shear(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_scale():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Scale(image, mode)
    trans.random_param()
    trans_image = trans.transform()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
@pytest.mark.component_mindarmour
def test_rotate():
    image = (np.random.rand(32, 32)*255).astype(np.float32)
    mode = 'L'
    trans = Rotate(image, mode)
    trans.random_param()
    trans_image = trans.transform()
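These cases are smoke tests: they only check that each transform runs without raising. A possible extension, not part of the PR, is a minimal sanity check on the returned value at the end of each test, for example:

    # hypothetical extension, not in the PR
    assert isinstance(trans_image, np.ndarray)
    assert trans_image.size > 0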