Merge pull request !14 from zheng-huanhuan/2_mastertags/v0.2.0-alpha
| @@ -0,0 +1,118 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import sys | |||||
| import time | |||||
| import numpy as np | |||||
| import pytest | |||||
| from scipy.special import softmax | |||||
| from mindspore import Model | |||||
| from mindspore import Tensor | |||||
| from mindspore import context | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindarmour.attacks.iterative_gradient_method import MomentumDiverseInputIterativeMethod | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.evaluations.attack_evaluation import AttackEvaluate | |||||
| from lenet5_net import LeNet5 | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | |||||
| sys.path.append("..") | |||||
| from data_processing import generate_mnist_dataset | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'M_DI2_FGSM_Test' | |||||
| LOGGER.set_level('INFO') | |||||
| @pytest.mark.level1 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_card | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_momentum_diverse_input_iterative_method(): | |||||
| """ | |||||
| M-DI2-FGSM Attack Test | |||||
| """ | |||||
| # upload trained network | |||||
| ckpt_name = './trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| net = LeNet5() | |||||
| load_dict = load_checkpoint(ckpt_name) | |||||
| load_param_into_net(net, load_dict) | |||||
| # get test data | |||||
| data_list = "./MNIST_unzip/test" | |||||
| batch_size = 32 | |||||
| ds = generate_mnist_dataset(data_list, batch_size, sparse=False) | |||||
| # prediction accuracy before attack | |||||
| model = Model(net) | |||||
| batch_num = 32 # the number of batches of attacking samples | |||||
| test_images = [] | |||||
| test_labels = [] | |||||
| predict_labels = [] | |||||
| i = 0 | |||||
| for data in ds.create_tuple_iterator(): | |||||
| i += 1 | |||||
| images = data[0].astype(np.float32) | |||||
| labels = data[1] | |||||
| test_images.append(images) | |||||
| test_labels.append(labels) | |||||
| pred_labels = np.argmax(model.predict(Tensor(images)).asnumpy(), | |||||
| axis=1) | |||||
| predict_labels.append(pred_labels) | |||||
| if i >= batch_num: | |||||
| break | |||||
| predict_labels = np.concatenate(predict_labels) | |||||
| true_labels = np.argmax(np.concatenate(test_labels), axis=1) | |||||
| accuracy = np.mean(np.equal(predict_labels, true_labels)) | |||||
| LOGGER.info(TAG, "prediction accuracy before attacking is : %s", accuracy) | |||||
| # attacking | |||||
| attack = MomentumDiverseInputIterativeMethod(net) | |||||
| start_time = time.clock() | |||||
| adv_data = attack.batch_generate(np.concatenate(test_images), | |||||
| np.concatenate(test_labels), batch_size=32) | |||||
| stop_time = time.clock() | |||||
| pred_logits_adv = model.predict(Tensor(adv_data)).asnumpy() | |||||
| # rescale predict confidences into (0, 1). | |||||
| pred_logits_adv = softmax(pred_logits_adv, axis=1) | |||||
| pred_labels_adv = np.argmax(pred_logits_adv, axis=1) | |||||
| accuracy_adv = np.mean(np.equal(pred_labels_adv, true_labels)) | |||||
| LOGGER.info(TAG, "prediction accuracy after attacking is : %s", accuracy_adv) | |||||
| attack_evaluate = AttackEvaluate(np.concatenate(test_images).transpose(0, 2, 3, 1), | |||||
| np.concatenate(test_labels), | |||||
| adv_data.transpose(0, 2, 3, 1), | |||||
| pred_logits_adv) | |||||
| LOGGER.info(TAG, 'mis-classification rate of adversaries is : %s', | |||||
| attack_evaluate.mis_classification_rate()) | |||||
| LOGGER.info(TAG, 'The average confidence of adversarial class is : %s', | |||||
| attack_evaluate.avg_conf_adv_class()) | |||||
| LOGGER.info(TAG, 'The average confidence of true class is : %s', | |||||
| attack_evaluate.avg_conf_true_class()) | |||||
| LOGGER.info(TAG, 'The average distance (l0, l2, linf) between original ' | |||||
| 'samples and adversarial samples are: %s', | |||||
| attack_evaluate.avg_lp_distance()) | |||||
| LOGGER.info(TAG, 'The average structural similarity between original ' | |||||
| 'samples and adversarial samples are: %s', | |||||
| attack_evaluate.avg_ssim()) | |||||
| LOGGER.info(TAG, 'The average costing time is %s', | |||||
| (stop_time - start_time)/(batch_num*batch_size)) | |||||
| if __name__ == '__main__': | |||||
| test_momentum_diverse_input_iterative_method() | |||||
| @@ -26,6 +26,8 @@ __all__ = ['FastGradientMethod', | |||||
| 'BasicIterativeMethod', | 'BasicIterativeMethod', | ||||
| 'MomentumIterativeMethod', | 'MomentumIterativeMethod', | ||||
| 'ProjectedGradientDescent', | 'ProjectedGradientDescent', | ||||
| 'DiverseInputIterativeMethod', | |||||
| 'MomentumDiverseInputIterativeMethod', | |||||
| 'DeepFool', | 'DeepFool', | ||||
| 'CarliniWagnerL2Attack', | 'CarliniWagnerL2Attack', | ||||
| 'JSMAAttack', | 'JSMAAttack', | ||||
| @@ -46,7 +46,7 @@ class GradientMethod(Attack): | |||||
| Default: None. | Default: None. | ||||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
| In form of (clip_min, clip_max). Default: None. | In form of (clip_min, clip_max). Default: None. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | """ | ||||
| def __init__(self, network, eps=0.07, alpha=None, bounds=None, | def __init__(self, network, eps=0.07, alpha=None, bounds=None, | ||||
| @@ -151,7 +151,7 @@ class FastGradientMethod(GradientMethod): | |||||
| Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
| attack. Default: False. | attack. Default: False. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| Examples: | Examples: | ||||
| >>> attack = FastGradientMethod(network) | >>> attack = FastGradientMethod(network) | ||||
| @@ -214,7 +214,7 @@ class RandomFastGradientMethod(FastGradientMethod): | |||||
| Possible values: np.inf, 1 or 2. Default: 2. | Possible values: np.inf, 1 or 2. Default: 2. | ||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
| attack. Default: False. | attack. Default: False. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| Raises: | Raises: | ||||
| ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
| @@ -255,7 +255,7 @@ class FastGradientSignMethod(GradientMethod): | |||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
| attack. Default: False. | attack. Default: False. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| Examples: | Examples: | ||||
| >>> attack = FastGradientSignMethod(network) | >>> attack = FastGradientSignMethod(network) | ||||
| @@ -314,7 +314,7 @@ class RandomFastGradientSignMethod(FastGradientSignMethod): | |||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
| is_targeted (bool): True: targeted attack. False: untargeted attack. | is_targeted (bool): True: targeted attack. False: untargeted attack. | ||||
| Default: False. | Default: False. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| Raises: | Raises: | ||||
| ValueError: eps is smaller than alpha! | ValueError: eps is smaller than alpha! | ||||
| @@ -350,7 +350,7 @@ class LeastLikelyClassMethod(FastGradientSignMethod): | |||||
| Default: None. | Default: None. | ||||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| Examples: | Examples: | ||||
| >>> attack = LeastLikelyClassMethod(network) | >>> attack = LeastLikelyClassMethod(network) | ||||
| @@ -15,6 +15,7 @@ | |||||
| from abc import abstractmethod | from abc import abstractmethod | ||||
| import numpy as np | import numpy as np | ||||
| from PIL import Image, ImageOps | |||||
| from mindspore.nn import SoftmaxCrossEntropyWithLogits | from mindspore.nn import SoftmaxCrossEntropyWithLogits | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| @@ -115,7 +116,7 @@ class IterativeGradientMethod(Attack): | |||||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | bounds (tuple): Upper and lower bounds of data, indicating the data range. | ||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | In form of (clip_min, clip_max). Default: (0.0, 1.0). | ||||
| nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | """ | ||||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), nb_iter=5, | ||||
| loss_fn=None): | loss_fn=None): | ||||
| @@ -178,14 +179,13 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | is_targeted (bool): If True, targeted attack. If False, untargeted | ||||
| attack. Default: False. | attack. Default: False. | ||||
| nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| attack (class): The single step gradient method of each iteration. In | attack (class): The single step gradient method of each iteration. In | ||||
| this class, FGSM is used. | this class, FGSM is used. | ||||
| Examples: | Examples: | ||||
| >>> attack = BasicIterativeMethod(network) | >>> attack = BasicIterativeMethod(network) | ||||
| """ | """ | ||||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
| is_targeted=False, nb_iter=5, loss_fn=None): | is_targeted=False, nb_iter=5, loss_fn=None): | ||||
| super(BasicIterativeMethod, self).__init__(network, | super(BasicIterativeMethod, self).__init__(network, | ||||
| @@ -227,14 +227,22 @@ class BasicIterativeMethod(IterativeGradientMethod): | |||||
| clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
| clip_diff = clip_max - clip_min | clip_diff = clip_max - clip_min | ||||
| for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
| adv_x = self._attack.generate(inputs, labels) | |||||
| if 'self.prob' in globals(): | |||||
| d_inputs = _transform_inputs(inputs, self.prob) | |||||
| else: | |||||
| d_inputs = inputs | |||||
| adv_x = self._attack.generate(d_inputs, labels) | |||||
| perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | ||||
| self._eps*clip_diff) | self._eps*clip_diff) | ||||
| adv_x = arr_x + perturs | adv_x = arr_x + perturs | ||||
| inputs = adv_x | inputs = adv_x | ||||
| else: | else: | ||||
| for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
| adv_x = self._attack.generate(inputs, labels) | |||||
| if 'self.prob' in globals(): | |||||
| d_inputs = _transform_inputs(inputs, self.prob) | |||||
| else: | |||||
| d_inputs = inputs | |||||
| adv_x = self._attack.generate(d_inputs, labels) | |||||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
| inputs = adv_x | inputs = adv_x | ||||
| return adv_x | return adv_x | ||||
| @@ -261,7 +269,7 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
| decay_factor (float): Decay factor in iterations. Default: 1.0. | decay_factor (float): Decay factor in iterations. Default: 1.0. | ||||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
| np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | """ | ||||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
| @@ -303,9 +311,13 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
| clip_min, clip_max = self._bounds | clip_min, clip_max = self._bounds | ||||
| clip_diff = clip_max - clip_min | clip_diff = clip_max - clip_min | ||||
| for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
| gradient = self._gradient(inputs, labels) | |||||
| if 'self.prob' in globals(): | |||||
| d_inputs = _transform_inputs(inputs, self.prob) | |||||
| else: | |||||
| d_inputs = inputs | |||||
| gradient = self._gradient(d_inputs, labels) | |||||
| momentum = self._decay_factor*momentum + gradient | momentum = self._decay_factor*momentum + gradient | ||||
| adv_x = inputs + self._eps_iter*np.sign(momentum) | |||||
| adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||||
| perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | perturs = np.clip(adv_x - arr_x, (0 - self._eps)*clip_diff, | ||||
| self._eps*clip_diff) | self._eps*clip_diff) | ||||
| adv_x = arr_x + perturs | adv_x = arr_x + perturs | ||||
| @@ -313,12 +325,15 @@ class MomentumIterativeMethod(IterativeGradientMethod): | |||||
| inputs = adv_x | inputs = adv_x | ||||
| else: | else: | ||||
| for _ in range(self._nb_iter): | for _ in range(self._nb_iter): | ||||
| gradient = self._gradient(inputs, labels) | |||||
| if 'self.prob' in globals(): | |||||
| d_inputs = _transform_inputs(inputs, self.prob) | |||||
| else: | |||||
| d_inputs = inputs | |||||
| gradient = self._gradient(d_inputs, labels) | |||||
| momentum = self._decay_factor*momentum + gradient | momentum = self._decay_factor*momentum + gradient | ||||
| adv_x = inputs + self._eps_iter*np.sign(momentum) | |||||
| adv_x = d_inputs + self._eps_iter*np.sign(momentum) | |||||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
| inputs = adv_x | inputs = adv_x | ||||
| return adv_x | return adv_x | ||||
| def _gradient(self, inputs, labels): | def _gradient(self, inputs, labels): | ||||
| @@ -372,7 +387,7 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
| nb_iter (int): Number of iteration. Default: 5. | nb_iter (int): Number of iteration. Default: 5. | ||||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | ||||
| np.inf, 1 or 2. Default: 'inf'. | np.inf, 1 or 2. Default: 'inf'. | ||||
| loss_fn (Loss): Loss function for optimization. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | """ | ||||
| def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | def __init__(self, network, eps=0.3, eps_iter=0.1, bounds=(0.0, 1.0), | ||||
| @@ -430,3 +445,114 @@ class ProjectedGradientDescent(BasicIterativeMethod): | |||||
| adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | adv_x = np.clip(adv_x, arr_x - self._eps, arr_x + self._eps) | ||||
| inputs = adv_x | inputs = adv_x | ||||
| return adv_x | return adv_x | ||||
| class DiverseInputIterativeMethod(BasicIterativeMethod): | |||||
| """ | |||||
| The Diverse Input Iterative Method attack. | |||||
| References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||||
| Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||||
| Args: | |||||
| network (Cell): Target model. | |||||
| eps (float): Proportion of adversarial perturbation generated by the | |||||
| attack to data range. Default: 0.3. | |||||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||||
| attack. Default: False. | |||||
| prob (float): Transformation probability. Default: 0.5. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | |||||
| def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||||
| is_targeted=False, prob=0.5, loss_fn=None): | |||||
| # reference to paper hyper parameters setting. | |||||
| eps_iter = 16*2/255 | |||||
| nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||||
| super(DiverseInputIterativeMethod, self).__init__(network, | |||||
| eps=eps, | |||||
| eps_iter=eps_iter, | |||||
| bounds=bounds, | |||||
| is_targeted=is_targeted, | |||||
| nb_iter=nb_iter, | |||||
| loss_fn=loss_fn) | |||||
| # FGSM default alpha is None equal alpha=1 | |||||
| self.prob = check_param_type('prob', prob, float) | |||||
| class MomentumDiverseInputIterativeMethod(MomentumIterativeMethod): | |||||
| """ | |||||
| The Momentum Diverse Input Iterative Method attack. | |||||
| References: `Xie, Cihang and Zhang, et al., "Improving Transferability of | |||||
| Adversarial Examples With Input Diversity," in CVPR, 2019 <https://arxiv.org/abs/1803.06978>`_ | |||||
| Args: | |||||
| network (Cell): Target model. | |||||
| eps (float): Proportion of adversarial perturbation generated by the | |||||
| attack to data range. Default: 0.3. | |||||
| bounds (tuple): Upper and lower bounds of data, indicating the data range. | |||||
| In form of (clip_min, clip_max). Default: (0.0, 1.0). | |||||
| is_targeted (bool): If True, targeted attack. If False, untargeted | |||||
| attack. Default: False. | |||||
| norm_level (Union[int, numpy.inf]): Order of the norm. Possible values: | |||||
| np.inf, 1 or 2. Default: 'l1'. | |||||
| prob (float): Transformation probability. Default: 0.5. | |||||
| loss_fn (Loss): Loss function for optimization. Default: None. | |||||
| """ | |||||
| def __init__(self, network, eps=0.3, bounds=(0.0, 1.0), | |||||
| is_targeted=False, norm_level='l1', prob=0.5, loss_fn=None): | |||||
| eps_iter = 16*2 / 255 | |||||
| nb_iter = int(min(eps*255 + 4, 1.25*255*eps)) | |||||
| super(MomentumDiverseInputIterativeMethod, self).__init__(network=network, | |||||
| eps=eps, | |||||
| eps_iter=eps_iter, | |||||
| bounds=bounds, | |||||
| nb_iter=nb_iter, | |||||
| is_targeted=is_targeted, | |||||
| norm_level=norm_level, | |||||
| loss_fn=loss_fn) | |||||
| self.prob = check_param_type('prob', prob, float) | |||||
| def _transform_inputs(inputs, prob, low=29, high=33, full_aug=False): | |||||
| """ | |||||
| Inputs data augmentation. | |||||
| Args: | |||||
| inputs (Union[np.int8, np.float]): Inputs. | |||||
| prob (float): The probability of augmentation. | |||||
| low (int): Lower bound of resize image width. Default: 29. | |||||
| high (int): Upper bound of resize image height. Default: 33. | |||||
| full_aug (bool): type of augmentation method, use interpolation and padding | |||||
| as default. Default: False. | |||||
| Returns: | |||||
| numpy.ndarray, the augmentation data. | |||||
| """ | |||||
| raw_shape = inputs[0].shape | |||||
| tran_mask = np.random.uniform(0, 1, size=inputs.shape[0]) < prob | |||||
| tran_inputs = inputs[tran_mask] | |||||
| raw_inputs = inputs[tran_mask == 0] | |||||
| tran_outputs = [] | |||||
| for sample in tran_inputs: | |||||
| width = np.random.choice(np.arange(low, high)) | |||||
| # resize | |||||
| sample = (sample*255).astype(np.uint8) | |||||
| d_image = Image.fromarray(sample, mode='L').resize((width, width), Image.NEAREST) | |||||
| # pad | |||||
| left_pad = (raw_shape[0] - width) // 2 | |||||
| right_pad = raw_shape[0] - width - left_pad | |||||
| top_pad = (raw_shape[1] - width) // 2 | |||||
| bottom_pad = raw_shape[1] - width - top_pad | |||||
| p_sample = ImageOps.expand(d_image, | |||||
| border=(left_pad, top_pad, right_pad, bottom_pad)) | |||||
| tran_outputs.append(np.array(p_sample).astype(np.float) / 255) | |||||
| if full_aug: | |||||
| # gaussian noise | |||||
| tran_outputs = np.random.normal(tran_outputs.shape) + tran_outputs | |||||
| tran_outputs.extend(raw_inputs) | |||||
| if not np.any(tran_outputs-raw_inputs): | |||||
| LOGGER.error(TAG, 'the transform function does not take effect.') | |||||
| return tran_outputs | |||||
| @@ -242,7 +242,7 @@ def normalize_value(value, norm_level): | |||||
| Raises: | Raises: | ||||
| NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | NotImplementedError: If norm_level is not in [1, 2 , np.inf, '1', '2', | ||||
| 'inf] | |||||
| 'inf', 'l1', 'l2'] | |||||
| """ | """ | ||||
| norm_level = check_norm_level(norm_level) | norm_level = check_norm_level(norm_level) | ||||
| ori_shape = value.shape | ori_shape = value.shape | ||||
| @@ -1,6 +1,7 @@ | |||||
| numpy >= 1.17.0 | numpy >= 1.17.0 | ||||
| scipy >= 1.3.3 | scipy >= 1.3.3 | ||||
| matplotlib >= 3.1.3 | matplotlib >= 3.1.3 | ||||
| Pillow >= 2.0.0 | |||||
| pytest >= 4.3.1 | pytest >= 4.3.1 | ||||
| wheel >= 0.32.0 | wheel >= 0.32.0 | ||||
| setuptools >= 40.8.0 | setuptools >= 40.8.0 | ||||
| @@ -95,7 +95,8 @@ setup( | |||||
| install_requires=[ | install_requires=[ | ||||
| 'scipy >= 1.3.3', | 'scipy >= 1.3.3', | ||||
| 'numpy >= 1.17.0', | 'numpy >= 1.17.0', | ||||
| 'matplotlib >= 3.1.3' | |||||
| 'matplotlib >= 3.1.3', | |||||
| 'Pillow >= 2.0.0' | |||||
| ], | ], | ||||
| ) | ) | ||||
| print(find_packages()) | print(find_packages()) | ||||
| @@ -25,6 +25,8 @@ from mindarmour.attacks import BasicIterativeMethod | |||||
| from mindarmour.attacks import MomentumIterativeMethod | from mindarmour.attacks import MomentumIterativeMethod | ||||
| from mindarmour.attacks import ProjectedGradientDescent | from mindarmour.attacks import ProjectedGradientDescent | ||||
| from mindarmour.attacks import IterativeGradientMethod | from mindarmour.attacks import IterativeGradientMethod | ||||
| from mindarmour.attacks import DiverseInputIterativeMethod | |||||
| from mindarmour.attacks import MomentumDiverseInputIterativeMethod | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") | ||||
| @@ -91,7 +93,7 @@ def test_momentum_iterative_method(): | |||||
| for i in range(5): | for i in range(5): | ||||
| attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | attack = MomentumIterativeMethod(Net(), nb_iter=i + 1) | ||||
| ms_adv_x = attack.generate(input_np, label) | ms_adv_x = attack.generate(input_np, label) | ||||
| assert np.any(ms_adv_x != input_np), 'Basic iterative method: generate' \ | |||||
| assert np.any(ms_adv_x != input_np), 'Momentum iterative method: generate' \ | |||||
| ' value must not be equal to' \ | ' value must not be equal to' \ | ||||
| ' original value.' | ' original value.' | ||||
| @@ -119,6 +121,48 @@ def test_projected_gradient_descent_method(): | |||||
| ' original value.' | ' original value.' | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_card | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_diverse_input_iterative_method(): | |||||
| """ | |||||
| Diverse input iterative method unit test. | |||||
| """ | |||||
| input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||||
| label = np.asarray([2], np.int32) | |||||
| label = np.eye(3)[label].astype(np.float32) | |||||
| for i in range(5): | |||||
| attack = DiverseInputIterativeMethod(Net()) | |||||
| ms_adv_x = attack.generate(input_np, label) | |||||
| assert np.any(ms_adv_x != input_np), 'Diverse input iterative method: generate' \ | |||||
| ' value must not be equal to' \ | |||||
| ' original value.' | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_card | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_momentum_diverse_input_iterative_method(): | |||||
| """ | |||||
| Momentum diverse input iterative method unit test. | |||||
| """ | |||||
| input_np = np.asarray([[0.1, 0.2, 0.7]], np.float32) | |||||
| label = np.asarray([2], np.int32) | |||||
| label = np.eye(3)[label].astype(np.float32) | |||||
| for i in range(5): | |||||
| attack = MomentumDiverseInputIterativeMethod(Net()) | |||||
| ms_adv_x = attack.generate(input_np, label) | |||||
| assert np.any(ms_adv_x != input_np), 'Momentum diverse input iterative method: ' \ | |||||
| 'generate value must not be equal to' \ | |||||
| ' original value.' | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||