You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_vgg_cifar.py 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """
  16. Examples of membership inference
  17. """
  18. import argparse
  19. import sys
  20. import numpy as np
  21. from mindspore.train import Model
  22. from mindspore.train.serialization import load_param_into_net, load_checkpoint
  23. import mindspore.nn as nn
  24. from mindarmour.privacy.evaluation import MembershipInference
  25. from mindarmour.utils import LogUtil
  26. from examples.common.networks.vgg.vgg import vgg16
  27. from examples.common.networks.vgg.config import cifar_cfg as cfg
  28. from examples.common.networks.vgg.utils.util import get_param_groups
  29. from examples.common.dataset.data_processing import vgg_create_dataset100
  30. logging = LogUtil.get_instance()
  31. logging.set_level(20)
  32. sys.path.append("../../../")
  33. TAG = "membership inference example"
  34. if __name__ == "__main__":
  35. parser = argparse.ArgumentParser("main case arg parser.")
  36. parser.add_argument("--device_target", type=str, default="Ascend",
  37. choices=["Ascend"])
  38. parser.add_argument("--data_path", type=str, required=True,
  39. help="Data home path for Cifar100.")
  40. parser.add_argument("--pre_trained", type=str, required=True,
  41. help="Checkpoint path.")
  42. args = parser.parse_args()
  43. args.num_classes = cfg.num_classes
  44. args.batch_norm = cfg.batch_norm
  45. args.has_dropout = cfg.has_dropout
  46. args.has_bias = cfg.has_bias
  47. args.initialize_mode = cfg.initialize_mode
  48. args.padding = cfg.padding
  49. args.pad_mode = cfg.pad_mode
  50. args.weight_decay = cfg.weight_decay
  51. args.loss_scale = cfg.loss_scale
  52. # load the pretrained model
  53. net = vgg16(args.num_classes, args)
  54. loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
  55. opt = nn.Momentum(params=get_param_groups(net), learning_rate=0.1, momentum=0.9,
  56. weight_decay=args.weight_decay, loss_scale=args.loss_scale)
  57. load_param_into_net(net, load_checkpoint(args.pre_trained))
  58. model = Model(network=net, loss_fn=loss, optimizer=opt)
  59. logging.info(TAG, "The model is loaded.")
  60. attacker = MembershipInference(model)
  61. config = [
  62. {
  63. "method": "knn",
  64. "params": {
  65. "n_neighbors": [3, 5, 7]
  66. }
  67. },
  68. {
  69. "method": "lr",
  70. "params": {
  71. "C": np.logspace(-4, 2, 10)
  72. }
  73. },
  74. {
  75. "method": "mlp",
  76. "params": {
  77. "hidden_layer_sizes": [(64,), (32, 32)],
  78. "solver": ["adam"],
  79. "alpha": [0.0001, 0.001, 0.01]
  80. }
  81. },
  82. {
  83. "method": "rf",
  84. "params": {
  85. "n_estimators": [100],
  86. "max_features": ["auto", "sqrt"],
  87. "max_depth": [5, 10, 20, None],
  88. "min_samples_split": [2, 5, 10],
  89. "min_samples_leaf": [1, 2, 4]
  90. }
  91. }
  92. ]
  93. # load and split dataset
  94. train_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
  95. batch_size=64, num_samples=10000, shuffle=False)
  96. test_dataset = vgg_create_dataset100(data_home=args.data_path, image_size=(224, 224),
  97. batch_size=64, num_samples=10000, shuffle=False, training=False)
  98. train_train, eval_train = train_dataset.split([0.8, 0.2])
  99. train_test, eval_test = test_dataset.split([0.8, 0.2])
  100. logging.info(TAG, "Data loading is complete.")
  101. logging.info(TAG, "Start training the inference model.")
  102. attacker.train(train_train, train_test, config)
  103. logging.info(TAG, "The inference model is training complete.")
  104. logging.info(TAG, "Start the evaluation phase")
  105. metrics = ["precision", "accuracy", "recall"]
  106. result = attacker.eval(eval_train, eval_test, metrics)
  107. # Show the metrics for each attack method.
  108. count = len(config)
  109. for i in range(count):
  110. print("Method: {}, {}".format(config[i]["method"], result[i]))

MindArmour关注AI的安全和隐私问题，致力于增强模型的安全可信、保护用户的数据隐私。主要包含3个模块：对抗样本鲁棒性模块、Fuzz Testing模块、隐私保护与评估模块。

对抗样本鲁棒性模块

对抗样本鲁棒性模块用于评估模型对于对抗样本的鲁棒性，并提供模型增强方法，用于增强模型抗对抗样本攻击的能力，提升模型鲁棒性。对抗样本鲁棒性模块包含4个子模块：对抗样本的生成、对抗样本的检测、模型防御、攻防评估。