@@ -27,7 +27,7 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_param_into_net, load_checkpoint
 from mindarmour.utils import LogUtil
@@ -187,12 +187,13 @@ if __name__ == '__main__':
                       amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
     # checkpoint save
+    callbacks = [LossMonitor()]
     if args.rank_save_ckpt_flag:
         ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval*args.steps_per_epoch,
                                        keep_checkpoint_max=args.ckpt_save_max)
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
-        callbacks = ckpt_cb
+        callbacks.append(ckpt_cb)
     model.train(args.max_epoch, dataset, callbacks=callbacks)
@@ -51,7 +51,7 @@ def test_lenet_mnist_coverage():
     train_images = np.concatenate(train_images, axis=0)
     # initialize fuzz test with training dataset
-    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
+    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
     # fuzz test with original test data
     # get test data
@@ -41,12 +41,20 @@ def test_lenet_mnist_fuzzing():
     mutate_config = [{'method': 'Blur',
                       'params': {'auto_param': True}},
                      {'method': 'Contrast',
-                      'params': {'factor': 2}},
+                      'params': {'auto_param': True}},
                      {'method': 'Translate',
-                      'params': {'x_bias': 0.1, 'y_bias': 0.2}},
+                      'params': {'auto_param': True}},
+                     {'method': 'Brightness',
+                      'params': {'auto_param': True}},
+                     {'method': 'Noise',
+                      'params': {'auto_param': True}},
+                     {'method': 'Scale',
+                      'params': {'auto_param': True}},
+                     {'method': 'Shear',
+                      'params': {'auto_param': True}},
                      {'method': 'FGSM',
-                      'params': {'eps': 0.1, 'alpha': 0.1}}
+                      'params': {'eps': 0.3, 'alpha': 0.1}}
                     ]
     # get training data
     data_list = "./MNIST_unzip/train"
@@ -59,7 +67,7 @@ def test_lenet_mnist_fuzzing():
     train_images = np.concatenate(train_images, axis=0)
     # initialize fuzz test with training dataset
-    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
     # fuzz test with original test data
     # get test data
@@ -79,7 +87,7 @@ def test_lenet_mnist_fuzzing():
     # make initial seeds
     for img, label in zip(test_images, test_labels):
-        initial_seeds.append([img, label, 0])
+        initial_seeds.append([img, label])
     initial_seeds = initial_seeds[:100]
     model_coverage_test.calculate_coverage(
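Taken together, the test now passes `neuron_num` before `segmented_num` and supplies each seed as a plain [image, label] pair. A minimal sketch of the updated calling convention (array shapes and variable names are assumptions for illustration):

    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)  # neuron_num=10, segmented_num=1000
    initial_seeds = [[img, label] for img, label in zip(test_images, test_labels)][:100]
    model_coverage_test.calculate_coverage(np.array(test_images[:100]).astype(np.float32))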
@@ -11,6 +11,7 @@ from .monitor.monitor import RDPMonitor
 from .monitor.monitor import ZCDPMonitor
 from .optimizer.optimizer import DPOptimizerClassFactory
 from .train.model import DPModel
+from .evaluation.membership_inference import MembershipInference
 __all__ = ['NoiseGaussianRandom',
            'NoiseAdaGaussianRandom',
@@ -21,4 +22,5 @@ __all__ = ['NoiseGaussianRandom',
            'RDPMonitor',
            'ZCDPMonitor',
            'DPOptimizerClassFactory',
-           'DPModel']
+           'DPModel',
+           'MembershipInference']
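Since `MembershipInference` is now re-exported from the package `__init__`, it can be imported directly from `mindarmour.diff_privacy` (the package path used by the absolute imports later in this patch). A one-line usage sketch:

    from mindarmour.diff_privacy import MembershipInference
    inference_model = MembershipInference(model)  # model: a mindspore.train.Model instance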
@@ -21,6 +21,11 @@ from sklearn.ensemble import RandomForestClassifier
 from sklearn.model_selection import GridSearchCV
 from sklearn.model_selection import RandomizedSearchCV
+from mindarmour.utils.logger import LogUtil
+LOGGER = LogUtil.get_instance()
+TAG = "Attacker"
 def _attack_knn(features, labels, param_grid):
     """
@@ -114,17 +119,31 @@ def get_attack_model(features, labels, config):
         features (numpy.ndarray): Loss and logits characteristics of each sample.
         labels (numpy.ndarray): Labels of each sample whether belongs to training set.
         config (dict): Config of attacker, with key in ["method", "params"].
+            The format is {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
+            params of each method must be within the range of changeable parameters.
+            Tips of params implementation can be found in
+            "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
     Returns:
         sklearn.BaseEstimator, trained model specify by config["method"].
+    Examples:
+        >>> features = np.random.randn(10, 10)
+        >>> labels = np.random.randint(0, 2, 10)
+        >>> config = {"method": "knn", "params": {"n_neighbors": [3, 5, 7]}}
+        >>> attack_model = get_attack_model(features, labels, config)
     """
     method = str.lower(config["method"])
     if method == "knn":
         return _attack_knn(features, labels, config["params"])
-    if method in ["lr", "logitic regression"]:
+    if method == "lr":
         return _attack_lr(features, labels, config["params"])
     if method == "mlp":
         return _attack_mlpc(features, labels, config["params"])
-    if method in ["rf", "random forest"]:
+    if method == "rf":
         return _attack_rf(features, labels, config["params"])
-    raise ValueError("Method {} is not support.".format(config["method"]))
+    msg = "Method {} is not supported.".format(config["method"])
+    LOGGER.error(TAG, msg)
+    raise ValueError(msg)
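With the shortened method names, only the lowercase keys "knn", "lr", "mlp" and "rf" are matched here. A small sketch of a logistic-regression attacker config in that style (the C grid mirrors the example in MembershipInference.train further down; features and labels are placeholders):

    import numpy as np
    config = {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}
    attack_model = get_attack_model(features, labels, config)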
@@ -19,10 +19,14 @@ import numpy as np
 import mindspore as ms
 from mindspore.train import Model
-import mindspore.nn as nn
-import mindspore.context as context
+from mindspore.dataset.engine import Dataset
 from mindspore import Tensor
 from mindarmour.diff_privacy.evaluation.attacker import get_attack_model
+from mindarmour.utils.logger import LogUtil
+LOGGER = LogUtil.get_instance()
+TAG = "MembershipInference"
 def _eval_info(pred, truth, option):
     """
@@ -42,7 +46,9 @@ def _eval_info(pred, truth, option):
         ValueError, value of parameter option must be in ["precision", "accuracy", "recall"].
     """
     if pred.size == 0 or truth.size == 0:
-        raise ValueError("Size of pred or truth is 0.")
+        msg = "Size of pred or truth is 0."
+        LOGGER.error(TAG, msg)
+        raise ValueError(msg)
     if option == "accuracy":
         count = np.sum(pred == truth)
@@ -58,7 +64,25 @@ def _eval_info(pred, truth, option):
             return -1
         return count / np.sum(truth)
-    raise ValueError("The metric value {} is undefined.".format(option))
+    msg = "The metric value {} is undefined.".format(option)
+    LOGGER.error(TAG, msg)
+    raise ValueError(msg)
+def _softmax_cross_entropy(logits, labels):
+    """
+    Calculate the SoftmaxCrossEntropy result between logits and labels.
+    Args:
+        logits (numpy.ndarray): Numpy array of shape(N, C).
+        labels (numpy.ndarray): Numpy array of shape(N, ).
+    Returns:
+        numpy.ndarray: Numpy array of shape(N, ), containing loss value for each vector in logits.
+    """
+    labels = np.eye(logits.shape[1])[labels].astype(np.int32)
+    logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+    return -1*np.sum(labels*np.log(logits), axis=1)
 class MembershipInference:
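The new `_softmax_cross_entropy` helper reproduces the per-sample loss in plain NumPy, so feature extraction no longer needs a MindSpore loss cell or an Ascend-only device check. A quick shape sketch (the numbers are made up):

    import numpy as np
    logits = np.array([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])      # shape (N, C) = (2, 3)
    labels = np.array([0, 2])                 # shape (N,)
    losses = _softmax_cross_entropy(logits, labels)
    print(losses.shape)                       # (2,): one cross-entropy value per sample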
@@ -66,22 +90,23 @@ class MembershipInference:
     Evaluation proposed by Shokri, Stronati, Song and Shmatikov is a grey-box attack.
     The attack requires obtain loss or logits results of training samples.
-    References: Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
+    References: `Reza Shokri, Marco Stronati, Congzheng Song, Vitaly Shmatikov.
     Membership Inference Attacks against Machine Learning Models. 2017.
-    arXiv:1610.05820v2 <https://arxiv.org/abs/1610.05820v2>`_
+    <https://arxiv.org/abs/1610.05820v2>`_
     Args:
         model (Model): Target model.
     Examples:
-        >>> # ds_train, eval_train are non-overlapping datasets from training dataset.
-        >>> # eval_train, eval_test are non-overlapping datasets from test dataset.
+        >>> # train_1, train_2 are non-overlapping datasets from training dataset of target model.
+        >>> # test_1, test_2 are non-overlapping datasets from test dataset of target model.
+        >>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model.
        >>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'})
         >>> inference_model = MembershipInference(model)
         >>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}]
-        >>> inference_model.train(ds_train, ds_test, config)
+        >>> inference_model.train(train_1, test_1, config)
         >>> metrics = ["precision", "recall", "accuracy"]
-        >>> result = inference_model.eval(eval_train, eval_test, metrics)
+        >>> result = inference_model.eval(train_2, test_2, metrics)
     Raises:
         TypeError: If type of model is not mindspore.train.Model.
@@ -89,8 +114,12 @@ class MembershipInference:
     def __init__(self, model):
         if not isinstance(model, Model):
-            raise TypeError("Type of model must be {}, but got {}.".format(type(Model), type(model)))
+            msg = "Type of parameter 'model' must be Model, but got {}.".format(type(model))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
         self.model = model
+        self.method_list = ["knn", "lr", "mlp", "rf"]
         self.attack_list = []
     def train(self, dataset_train, dataset_test, attack_config):
@@ -101,11 +130,48 @@ class MembershipInference:
         Args:
             dataset_train (mindspore.dataset): The training dataset for the target model.
             dataset_test (mindspore.dataset): The test set for the target model.
-            attack_config (list): Parameter setting for the attack model.
+            attack_config (list): Parameter setting for the attack model. The format is
+                [{"method": "knn", "params": {"n_neighbors": [3, 5, 7]}},
+                 {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}].
+                The supported methods list is in self.method_list, and the params of each method
+                must be within the range of changeable parameters. Tips of params implementation
+                can be found in
+                "https://scikit-learn.org/0.16/modules/generated/sklearn.grid_search.GridSearchCV.html".
         Raises:
-            ValueError: If the method in attack_config is not in ["LR", "KNN", "RF", "MLPC"].
+            KeyError: If each config in attack_config doesn't have keys {"method", "params"}.
+            ValueError: If the method (case insensitive) in attack_config is not in ["lr", "knn", "rf", "mlp"].
         """
+        if not isinstance(dataset_train, Dataset):
+            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}.".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        if not isinstance(dataset_test, Dataset):
+            msg = "Type of parameter 'dataset_test' must be Dataset, but got {}.".format(type(dataset_test))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        if not isinstance(attack_config, list):
+            msg = "Type of parameter 'attack_config' must be list, but got {}.".format(type(attack_config))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        for config in attack_config:
+            if not isinstance(config, dict):
+                msg = "Type of each config in 'attack_config' must be dict, but got {}.".format(type(config))
+                LOGGER.error(TAG, msg)
+                raise TypeError(msg)
+            if {"params", "method"} != set(config.keys()):
+                msg = "Each config in attack_config must have keys 'method' and 'params', " \
+                      "but your key value is {}.".format(set(config.keys()))
+                LOGGER.error(TAG, msg)
+                raise KeyError(msg)
+            if str.lower(config["method"]) not in self.method_list:
+                msg = "Method {} is not supported.".format(config["method"])
+                LOGGER.error(TAG, msg)
+                raise ValueError(msg)
         features, labels = self._transform(dataset_train, dataset_test)
         for config in attack_config:
             self.attack_list.append(get_attack_model(features, labels, config))
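Given the checks above, every entry of `attack_config` must be a dict with exactly the keys 'method' and 'params', and the method name must lower-case to one of 'knn', 'lr', 'mlp' or 'rf'. A minimal config that passes validation (the dataset objects are assumed to be MindSpore datasets prepared elsewhere):

    import numpy as np
    attack_config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}},
                     {"method": "lr", "params": {"C": np.logspace(-4, 2, 10)}}]
    inference_model.train(train_1, test_1, attack_config)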
@@ -124,6 +190,28 @@ class MembershipInference:
         Returns:
             list, Each element contains an evaluation indicator for the attack model.
         """
+        if not isinstance(dataset_train, Dataset):
+            msg = "Type of parameter 'dataset_train' must be Dataset, but got {}.".format(type(dataset_train))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        if not isinstance(dataset_test, Dataset):
+            msg = "Type of parameter 'dataset_test' must be Dataset, but got {}.".format(type(dataset_test))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        if not isinstance(metrics, (list, tuple)):
+            msg = "Type of parameter 'metrics' must be Union[list, tuple], but got {}.".format(type(metrics))
+            LOGGER.error(TAG, msg)
+            raise TypeError(msg)
+        metrics = set(metrics)
+        metrics_list = {"precision", "accuracy", "recall"}
+        if not metrics <= metrics_list:
+            msg = "Element in 'metrics' must be in {}, but got {}.".format(metrics_list, metrics)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
         result = []
         features, labels = self._transform(dataset_train, dataset_test)
         for attacker in self.attack_list:
@@ -170,17 +258,12 @@ class MembershipInference:
             N is the number of sample. C = 1 + dim(logits).
             - numpy.ndarray, Labels for each sample, Shape is (N,).
         """
-        if context.get_context("device_target") != "Ascend":
-            raise RuntimeError("The target device must be Ascend, "
-                               "but current is {}.".format(context.get_context("device_target")))
         loss_logits = np.array([])
         for batch in dataset_x.create_dict_iterator():
             batch_data = Tensor(batch['image'], ms.float32)
-            batch_labels = Tensor(batch['label'], ms.int32)
-            batch_logits = self.model.predict(batch_data)
-            loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction=None)
-            batch_loss = loss(batch_logits, batch_labels).asnumpy()
-            batch_logits = batch_logits.asnumpy()
+            batch_labels = batch['label'].astype(np.int32)
+            batch_logits = self.model.predict(batch_data).asnumpy()
+            batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
             batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
             if loss_logits.size == 0:
@@ -193,5 +276,7 @@ class MembershipInference:
         elif label == 0:
             labels = np.zeros(len(loss_logits), np.int32)
         else:
-            raise ValueError("The value of label must be 0 or 1, but got {}.".format(label))
+            msg = "The value of label must be 0 or 1, but got {}.".format(label)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
         return loss_logits, labels
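The rewritten `_transform` thus builds each attack feature as the per-sample loss stacked in front of the raw logits, entirely in NumPy. A quick sketch of the resulting feature layout (sizes are illustrative):

    import numpy as np
    batch_logits = np.random.randn(32, 10)        # model predictions for one batch, shape (N, C)
    batch_labels = np.random.randint(0, 10, 32)
    batch_loss = _softmax_cross_entropy(batch_logits, batch_labels)
    batch_feature = np.hstack((batch_loss.reshape(-1, 1), batch_logits))
    print(batch_feature.shape)                    # (32, 11): 1 loss column + dim(logits)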
@@ -22,7 +22,8 @@ from mindspore import Tensor
 from mindarmour.fuzzing.model_coverage_metrics import ModelCoverageMetrics
 from mindarmour.utils._check_param import check_model, check_numpy_param, \
-    check_param_multi_types, check_norm_level, check_param_in_range
+    check_param_multi_types, check_norm_level, check_param_in_range, \
+    check_param_type, check_int_positive
 from mindarmour.fuzzing.image_transform import Contrast, Brightness, Blur, \
     Noise, Translate, Scale, Shear, Rotate
 from mindarmour.attacks import FastGradientSignMethod, \
@@ -93,7 +94,7 @@ class Fuzzer:
         >>> {'method': 'Translate', 'params': {'x_bias': 0.1, 'y_bias': 0.2}},
         >>> {'method': 'FGSM', 'params': {'eps': 0.1, 'alpha': 0.1}}]
         >>> train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
-        >>> model_fuzz_test = Fuzzer(model, train_images, 1000, 10)
+        >>> model_fuzz_test = Fuzzer(model, train_images, 10, 1000)
         >>> samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(mutate_config, initial_seeds)
         """
@@ -101,8 +102,9 @@ class Fuzzer:
         self._target_model = check_model('model', target_model, Model)
         train_dataset = check_numpy_param('train_dataset', train_dataset)
         self._coverage_metrics = ModelCoverageMetrics(target_model,
+                                                      neuron_num,
                                                       segmented_num,
-                                                      neuron_num, train_dataset)
+                                                      train_dataset)
         # Allowed mutate strategies so far.
         self._strategies = {'Contrast': Contrast, 'Brightness': Brightness,
                             'Blur': Blur, 'Noise': Noise, 'Translate': Translate,
@@ -115,23 +117,21 @@ class Fuzzer:
                             'Noise']
         self._attacks_list = ['FGSM', 'PGD', 'MDIIM']
         self._attack_param_checklists = {
-            'FGSM': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
-                                'alpha': {'dtype': [float, int],
+            'FGSM': {'params': {'eps': {'dtype': [float], 'range': [0, 1]},
+                                'alpha': {'dtype': [float],
                                           'range': [0, 1]},
-                                'bounds': {'dtype': [list, tuple],
-                                           'range': None}}},
-            'PGD': {'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
-                               'eps_iter': {'dtype': [float, int],
-                                            'range': [0, 1e5]},
-                               'nb_iter': {'dtype': [float, int],
+                                'bounds': {'dtype': [tuple]}}},
+            'PGD': {'params': {'eps': {'dtype': [float], 'range': [0, 1]},
+                               'eps_iter': {'dtype': [float],
+                                            'range': [0, 1]},
+                               'nb_iter': {'dtype': [int],
                                            'range': [0, 1e5]},
-                               'bounds': {'dtype': [list, tuple],
-                                          'range': None}}},
+                               'bounds': {'dtype': [tuple]}}},
             'MDIIM': {
-                'params': {'eps': {'dtype': [float, int], 'range': [0, 1]},
-                           'norm_level': {'dtype': [str], 'range': None},
-                           'prob': {'dtype': [float, int], 'range': [0, 1]},
-                           'bounds': {'dtype': [list, tuple], 'range': None}}}}
+                'params': {'eps': {'dtype': [float], 'range': [0, 1]},
+                           'norm_level': {'dtype': [str]},
+                           'prob': {'dtype': [float], 'range': [0, 1]},
+                           'bounds': {'dtype': [tuple]}}}}
     def fuzzing(self, mutate_config, initial_seeds, coverage_metric='KMNC',
                 eval_metrics='auto', max_iters=10000, mutate_num_per_seed=20):
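Under the tightened checklist, attack parameters must be plain floats (with an int only where listed, such as 'nb_iter'), and 'bounds' must be a tuple. An entry that satisfies the new PGD checks might look like this (the values are illustrative):

    attack_entry = {'method': 'PGD',
                    'params': {'eps': 0.3, 'eps_iter': 0.1, 'nb_iter': 5, 'bounds': (0.0, 1.0)}}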
@@ -140,16 +140,29 @@ class Fuzzer:
         Args:
             mutate_config (list): Mutate configs. The format is
-                [{'method': 'Blur', 'params': {'auto_param': True}}, {'method': 'Contrast', 'params': {'factor': 2}}].
-                The support methods list is in `self._strategies`, and the params of each
-                method must within the range of changeable parameters.
-            initial_seeds (numpy.ndarray): Initial seeds used to generate
-                mutated samples.
-            coverage_metric (str): Model coverage metric of neural networks.
-                Default: 'KMNC'.
-            eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the type is 'auto',
-                it will calculate all the metrics, else if the type is list or tuple, it will
-                calculate the metrics specified by user. Default: 'auto'.
+                [{'method': 'Blur', 'params': {'auto_param': True}},
+                 {'method': 'Contrast', 'params': {'factor': 2}}]. The
+                supported methods list is in `self._strategies`, and the
+                params of each method must be within the range of changeable parameters.
+                Supported methods are grouped in three types:
+                Firstly, pixel value based transform methods include:
+                'Contrast', 'Brightness', 'Blur' and 'Noise'. Secondly, affine
+                transform methods include: 'Translate', 'Scale', 'Shear' and
+                'Rotate'. Thirdly, attack methods include: 'FGSM', 'PGD' and 'MDIIM'.
+                `mutate_config` must include at least one method of the pixel value based
+                transform type. The way of setting parameters for first and
+                second type methods can be seen in 'mindarmour/fuzzing/image_transform.py'.
+                For third type methods, you can refer to the corresponding class.
+            initial_seeds (list[list]): Initial seeds used to generate mutated
+                samples. The format of initial seeds is [[image_data, label],
+                [...], ...].
+            coverage_metric (str): Model coverage metric of neural networks. All
+                supported metrics are: 'KMNC', 'NBC', 'SNAC'. Default: 'KMNC'.
+            eval_metrics (Union[list, tuple, str]): Evaluation metrics. If the
+                type is 'auto', it will calculate all the metrics, else if the
+                type is list or tuple, it will calculate the metrics specified
+                by user. All supported evaluate methods are 'accuracy',
+                'attack_success_rate', 'kmnc', 'nbc', 'snac'. Default: 'auto'.
             max_iters (int): Max number of select a seed to mutate.
                 Default: 10000.
             mutate_num_per_seed (int): The number of mutate times for a seed.
@@ -173,16 +186,10 @@ class Fuzzer:
             ValueError: If metric in list `eval_metrics` is not in ['accuracy', 'attack_success_rate',
                 'kmnc', 'nbc', 'snac'].
         """
-        eval_metrics_ = None
         if isinstance(eval_metrics, (list, tuple)):
             eval_metrics_ = []
             avaliable_metrics = ['accuracy', 'attack_success_rate', 'kmnc', 'nbc', 'snac']
             for elem in eval_metrics:
-                if not isinstance(elem, str):
-                    msg = 'the type of metric in list `eval_metrics` must be str, but got {}.' \
-                        .format(type(elem))
-                    LOGGER.error(TAG, msg)
-                    raise TypeError(msg)
                 if elem not in avaliable_metrics:
                     msg = 'metric in list `eval_metrics` must be in {}, but got {}.' \
                         .format(avaliable_metrics, elem)
@@ -203,7 +210,33 @@ class Fuzzer:
             raise TypeError(msg)
         # Check whether the mutate_config meet the specification.
+        mutate_config = check_param_type('mutate_config', mutate_config, list)
+        for config in mutate_config:
+            check_param_type("config['params']", config['params'], dict)
+            if set(config.keys()) != {'method', 'params'}:
+                msg = "Config must contain 'method' and 'params', but got {}." \
+                    .format(set(config.keys()))
+                LOGGER.error(TAG, msg)
+                raise TypeError(msg)
+            if config['method'] not in self._strategies.keys():
+                msg = "Config methods must be in {}, but got {}." \
+                    .format(self._strategies.keys(), config['method'])
+                LOGGER.error(TAG, msg)
+                raise TypeError(msg)
+        if coverage_metric not in ['KMNC', 'NBC', 'SNAC']:
+            msg = "coverage_metric must be in ['KMNC', 'NBC', 'SNAC'], but got {}." \
+                .format(coverage_metric)
+            LOGGER.error(TAG, msg)
+            raise ValueError(msg)
+        max_iters = check_int_positive('max_iters', max_iters)
+        mutate_num_per_seed = check_int_positive('mutate_num_per_seed', mutate_num_per_seed)
         mutates = self._init_mutates(mutate_config)
+        initial_seeds = check_param_type('initial_seeds', initial_seeds, list)
+        for seed in initial_seeds:
+            check_param_type('seed', seed, list)
+            check_numpy_param('seed[0]', seed[0])
+            check_numpy_param('seed[1]', seed[1])
+            seed.append(0)
         seed, initial_seeds = _select_next(initial_seeds)
         fuzz_samples = []
         gt_labels = []
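As the updated docstring and checks describe, callers now pass seeds as bare [image_data, label] pairs and the fuzzer appends the initial fitness value itself. A minimal sketch of a valid call (variable names and shapes follow the LeNet-style examples above and are assumptions):

    initial_seeds = [[img, label] for img, label in zip(test_images, test_labels)][:100]
    samples, labels, preds, strategies, report = model_fuzz_test.fuzzing(
        mutate_config, initial_seeds, coverage_metric='KMNC', max_iters=10000, mutate_num_per_seed=20)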
@@ -248,7 +281,7 @@ class Fuzzer:
         for index in range(len(samples)):
             mutate = samples[:index + 1]
             self._coverage_metrics.calculate_coverage(mutate.astype(np.float32))
-            if coverage_metric == "KMNC":
+            if coverage_metric == 'KMNC':
                 coverages.append(self._coverage_metrics.get_kmnc())
             if coverage_metric == 'NBC':
                 coverages.append(self._coverage_metrics.get_nbc())
@@ -357,18 +390,24 @@ class Fuzzer:
             dict, evaluate metrics include accuarcy, attack success rate
             and neural coverage.
         """
+        gt_labels = np.asarray(gt_labels)
+        fuzz_preds = np.asarray(fuzz_preds)
         temp = np.argmax(gt_labels, axis=1) == np.argmax(fuzz_preds, axis=1)
         metrics_report = {}
         if metrics == 'auto' or 'accuracy' in metrics:
-            gt_labels = np.asarray(gt_labels)
-            fuzz_preds = np.asarray(fuzz_preds)
-            acc = np.sum(temp) / np.size(temp)
+            if temp.any():
+                acc = np.sum(temp) / np.size(temp)
+            else:
+                acc = 0
             metrics_report['Accuracy'] = acc
         if metrics == 'auto' or 'attack_success_rate' in metrics:
             cond = [elem in self._attacks_list for elem in fuzz_strategies]
             temp = temp[cond]
-            attack_success_rate = 1 - np.sum(temp) / np.size(temp)
+            if temp.any():
+                attack_success_rate = 1 - np.sum(temp) / np.size(temp)
+            else:
+                attack_success_rate = None
             metrics_report['Attack_success_rate'] = attack_success_rate
         if metrics == 'auto' or 'kmnc' in metrics or 'nbc' in metrics or 'snac' in metrics:
@@ -350,8 +350,10 @@ class Translate(ImageTransform):
     Translate an image.
     Args:
-        x_bias ([int, float): X-direction translation, x=x+x_bias. Default: 0.
-        y_bias ([int, float): Y-direction translation, y=y+y_bias. Default: 0.
+        x_bias (Union[int, float]): X-direction translation, x=x+x_bias*image_length.
+            Default: 0.
+        y_bias (Union[int, float]): Y-direction translation, y=y+y_bias*image_width.
+            Default: 0.
     """
     def __init__(self, x_bias=0, y_bias=0):
@@ -363,8 +365,10 @@ class Translate(ImageTransform):
         Set translate parameters.
         Args:
-            x_bias ([float, int]): X-direction translation, x=x+x_bias. Default: 0.
-            y_bias ([float, int]): Y-direction translation, y=y+y_bias. Default: 0.
+            x_bias (Union[float, int]): X-direction translation, x=x+x_bias*image_length.
+                Default: 0.
+            y_bias (Union[float, int]): Y-direction translation, y=y+y_bias*image_width.
+                Default: 0.
             auto_param (bool): True if auto generate parameters. Default: False.
         """
         self.auto_param = auto_param
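The bias values are now interpreted as fractions of the image size rather than absolute pixel offsets, matching the auto_param-style configs used in the fuzzing tests above. A small sketch (the transform() call and the 32x32 image shape are assumptions for illustration):

    translate = Translate()
    translate.set_params(x_bias=0.1, y_bias=0.2, auto_param=False)  # shift by 10% of width, 20% of height
    translated = translate.transform(image)                         # image: numpy array, e.g. shape (32, 32)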
@@ -579,7 +583,7 @@ class Rotate(ImageTransform):
         """
         _, chw, normalized, gray3dim, image = self._check(image)
         img = to_pil(image)
-        trans_image = img.rotate(self.angle, expand=True)
+        trans_image = img.rotate(self.angle, expand=False)
         trans_image = self._original_format(trans_image, chw, normalized,
                                             gray3dim)
         return trans_image
@@ -21,7 +21,7 @@ from mindspore import Tensor
 from mindspore import Model
 from mindarmour.utils._check_param import check_model, check_numpy_param, \
-    check_int_positive
+    check_int_positive, check_param_multi_types
 from mindarmour.utils.logger import LogUtil
 LOGGER = LogUtil.get_instance()
@@ -43,8 +43,8 @@ class ModelCoverageMetrics:
     Args:
         model (Model): The pre-trained model which waiting for testing.
-        segmented_num (int): The number of segmented sections of neurons' output intervals.
         neuron_num (int): The number of testing neurons.
+        segmented_num (int): The number of segmented sections of neurons' output intervals.
         train_dataset (numpy.ndarray): Training dataset used for determine
             the neurons' output boundaries.
@@ -52,17 +52,18 @@ class ModelCoverageMetrics:
         ValueError: If neuron_num is too big (for example, bigger than 1e+9).
     Examples:
-        >>> train_images = np.random.random((10000, 128)).astype(np.float32)
-        >>> test_images = np.random.random((5000, 128)).astype(np.float32)
+        >>> net = LeNet5()
+        >>> train_images = np.random.random((10000, 1, 32, 32)).astype(np.float32)
+        >>> test_images = np.random.random((5000, 1, 32, 32)).astype(np.float32)
        >>> model = Model(net)
-        >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
-        >>> model_fuzz_test.test_adequacy_coverage_calculate(test_images)
+        >>> model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
+        >>> model_fuzz_test.calculate_coverage(test_images)
         >>> print('KMNC of this test is : %s', model_fuzz_test.get_kmnc())
         >>> print('NBC of this test is : %s', model_fuzz_test.get_nbc())
         >>> print('SNAC of this test is : %s', model_fuzz_test.get_snac())
     """
-    def __init__(self, model, segmented_num, neuron_num, train_dataset):
+    def __init__(self, model, neuron_num, segmented_num, train_dataset):
         self._model = check_model('model', model, Model)
         self._segmented_num = check_int_positive('segmented_num', segmented_num)
         self._neuron_num = check_int_positive('neuron_num', neuron_num)
@@ -139,8 +140,8 @@ class ModelCoverageMetrics:
         Args:
             dataset (numpy.ndarray): Data for fuzz test.
-            bias_coefficient (float): The coefficient used for changing the
-                neurons' output boundaries. Default: 0.
+            bias_coefficient (Union[int, float]): The coefficient used
+                for changing the neurons' output boundaries. Default: 0.
             batch_size (int): The number of samples in a predict batch.
                 Default: 32.
@@ -148,8 +149,10 @@ class ModelCoverageMetrics:
             >>> model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, train_images)
             >>> model_fuzz_test.calculate_coverage(test_images)
         """
         dataset = check_numpy_param('dataset', dataset)
         batch_size = check_int_positive('batch_size', batch_size)
+        bias_coefficient = check_param_multi_types('bias_coefficient', bias_coefficient, [int, float])
         self._lower_bounds -= bias_coefficient*self._var
         self._upper_bounds += bias_coefficient*self._var
         intervals = (self._upper_bounds - self._lower_bounds) / self._segmented_num
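With the added check, `bias_coefficient` may be an int or a float; it widens the recorded lower/upper neuron-output bounds before coverage is computed. A small usage sketch (the value 0.5 is illustrative):

    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, train_images)
    model_fuzz_test.calculate_coverage(test_images, bias_coefficient=0.5, batch_size=32)
    print(model_fuzz_test.get_nbc())   # boundary-based metrics reflect the widened bounds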
@@ -78,7 +78,11 @@ class LogUtil:
     def set_level(self, level):
         """
         Set the logging level of this logger, level must be an integer or a
-        string.
+        string. Supported levels are 'NOTSET'(integer: 0), 'ERROR'(integer: 1-40),
+        'WARNING'('WARN', integer: 1-30), 'INFO'(integer: 1-20) and 'DEBUG'(integer: 1-10).
+        For example, if logger.set_level('WARNING') or logger.set_level(21), then
+        logger.warn() and logger.error() in scripts would be printed while running,
+        while logger.info() or logger.debug() would not be printed.
         Args:
             level (Union[int, str]): Level of logger.
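Following the docstring, setting the level to 'WARNING' (or an integer between 21 and 30) keeps warnings and errors while silencing info and debug output. A minimal sketch (the tag-and-message call signature is assumed from the LOGGER.error calls elsewhere in this patch):

    from mindarmour.utils.logger import LogUtil
    LOGGER = LogUtil.get_instance()
    LOGGER.set_level('WARNING')
    LOGGER.error('Demo', 'printed at WARNING level and above')
    LOGGER.info('Demo', 'suppressed at WARNING level')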
@@ -98,7 +98,7 @@ class GradWrapWithLoss(Cell):
     Examples:
         >>> data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32)*0.01)
-        >>> label = Tensor(np.ones([1, 10]).astype(np.float32))
+        >>> labels = Tensor(np.ones([1, 10]).astype(np.float32))
         >>> net = NET()
         >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
         >>> loss_net = WithLossCell(net, loss_fn)
@@ -71,7 +71,7 @@ def test_lenet_mnist_coverage_cpu():
     # initialize fuzz test with training dataset
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
-    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)
+    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)
     # fuzz test with original test data
     # get test data
@@ -105,7 +105,7 @@ def test_lenet_mnist_coverage_ascend():
     # initialize fuzz test with training dataset
     training_data = (np.random.random((10000, 10))*20).astype(np.float32)
-    model_fuzz_test = ModelCoverageMetrics(model, 10000, 10, training_data)
+    model_fuzz_test = ModelCoverageMetrics(model, 10, 1000, training_data)
     # fuzz test with original test data
     # get test data
@@ -102,7 +102,7 @@ def test_fuzzing_ascend():
                     ]
     # initialize fuzz test with training dataset
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
-    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
     # fuzz test with original test data
     # get test data
@@ -113,7 +113,7 @@ def test_fuzzing_ascend():
     initial_seeds = []
     # make initial seeds
     for img, label in zip(test_images, test_labels):
-        initial_seeds.append([img, label, 0])
+        initial_seeds.append([img, label])
     initial_seeds = initial_seeds[:100]
     model_coverage_test.calculate_coverage(
@@ -148,7 +148,7 @@ def test_fuzzing_cpu():
                     ]
     # initialize fuzz test with training dataset
     train_images = np.random.rand(32, 1, 32, 32).astype(np.float32)
-    model_coverage_test = ModelCoverageMetrics(model, 1000, 10, train_images)
+    model_coverage_test = ModelCoverageMetrics(model, 10, 1000, train_images)
     # fuzz test with original test data
     # get test data
@@ -159,7 +159,7 @@ def test_fuzzing_cpu():
     initial_seeds = []
     # make initial seeds
     for img, label in zip(test_images, test_labels):
-        initial_seeds.append([img, label, 0])
+        initial_seeds.append([img, label])
     initial_seeds = initial_seeds[:100]
     model_coverage_test.calculate_coverage(