https://arxiv.org/abs/1906.08935tags/v1.2.0-rc1
| @@ -1,33 +1,54 @@ | |||||
| # Application demos of privacy stealing and privacy protection | # Application demos of privacy stealing and privacy protection | ||||
| ## Introduction | ## Introduction | ||||
| Although machine learning could obtain a generic model based on training data, it has been proved that the trained | Although machine learning could obtain a generic model based on training data, it has been proved that the trained | ||||
| model may disclose the information of training data (such as the membership inference attack). Differential | |||||
| privacy training | |||||
| is an effective | |||||
| method proposed | |||||
| to overcome this problem, in which Gaussian noise is added while training. There are mainly three parts for | |||||
| differential privacy(DP) training: noise-generating mechanism, DP optimizer and DP monitor. We have implemented | |||||
| a novel noise-generating mechanisms: adaptive decay noise mechanism. DP | |||||
| monitor is used to compute the privacy budget while training. | |||||
| model may disclose the information of training data (such as the membership inference attack). | |||||
| Differential privacy training is an effective method proposed to overcome this problem, in which Gaussian noise is | |||||
| added while training. There are mainly three parts for differential privacy(DP) training: noise-generating | |||||
| mechanism, DP optimizer and DP monitor. We have implemented a novel noise-generating mechanisms: adaptive decay | |||||
| noise mechanism. DP monitor is used to compute the privacy budget while training. | |||||
| Suppress Privacy training is a novel method to protect privacy distinct from the noise addition method | |||||
| (such as DP), in which the negligible model parameter is removed gradually to achieve a better balance between | |||||
| accuracy and privacy. | |||||
| ## 1. Adaptive decay DP training | ## 1. Adaptive decay DP training | ||||
| With adaptive decay mechanism, the magnitude of the Gaussian noise would be decayed as the training step grows, which | With adaptive decay mechanism, the magnitude of the Gaussian noise would be decayed as the training step grows, which | ||||
| resulting a stable convergence. | resulting a stable convergence. | ||||
| ```sh | ```sh | ||||
| $ cd examples/privacy/diff_privacy | |||||
| $ python lenet5_dp_ada_gaussian.py | |||||
| cd examples/privacy/diff_privacy | |||||
| python lenet5_dp_ada_gaussian.py | |||||
| ``` | ``` | ||||
| ## 2. Adaptive norm clip training | ## 2. Adaptive norm clip training | ||||
| With adaptive norm clip mechanism, the norm clip of the gradients would be changed according to the norm values of | With adaptive norm clip mechanism, the norm clip of the gradients would be changed according to the norm values of | ||||
| them, which can adjust the ratio of noise and original gradients. | them, which can adjust the ratio of noise and original gradients. | ||||
| ```sh | ```sh | ||||
| $ cd examples/privacy/diff_privacy | |||||
| $ python lenet5_dp.py | |||||
| cd examples/privacy/diff_privacy | |||||
| python lenet5_dp.py | |||||
| ``` | ``` | ||||
| ## 3. Membership inference evaluation | ## 3. Membership inference evaluation | ||||
| By this evaluation method, we could judge whether a sample is belongs to training dataset or not. | By this evaluation method, we could judge whether a sample is belongs to training dataset or not. | ||||
| ```sh | |||||
| cd examples/privacy/membership_inference_attack | |||||
| python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||||
| python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||||
| ``` | |||||
| ## 4. suppress privacy training | |||||
| With suppress privacy mechanism, the values of some trainable parameters (such as conv layers and fully connected | |||||
| layers) are set to zero as the training step grows, which can | |||||
| achieve a better balance between accuracy and privacy | |||||
| ```sh | ```sh | ||||
| $ cd examples/privacy/membership_inference_attack | |||||
| $ python train.py --data_path home_path_to_cifar100 --ckpt_path ./ | |||||
| $ python example_vgg_cifar.py --data_path home_path_to_cifar100 --pre_trained 0-100_781.ckpt | |||||
| cd examples/privacy/sup_privacy | |||||
| python sup_privacy.py | |||||
| ``` | ``` | ||||
| @@ -0,0 +1,154 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Training example of suppress-based privacy. | |||||
| """ | |||||
| import os | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.train.callback import ModelCheckpoint | |||||
| from mindspore.train.callback import CheckpointConfig | |||||
| from mindspore.train.callback import LossMonitor | |||||
| from mindspore.nn.metrics import Accuracy | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| import mindspore.dataset as ds | |||||
| import mindspore.dataset.vision.c_transforms as CV | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| from mindspore.dataset.vision.utils import Inter | |||||
| import mindspore.common.dtype as mstype | |||||
| from examples.common.networks.lenet5.lenet5_net import LeNet5 | |||||
| from sup_privacy_config import mnist_cfg as cfg | |||||
| from mindarmour.privacy.sup_privacy import SuppressModel | |||||
| from mindarmour.privacy.sup_privacy import SuppressMasker | |||||
| from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
| from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| LOGGER = LogUtil.get_instance() | |||||
| LOGGER.set_level('INFO') | |||||
| TAG = 'Lenet5_Suppress_train' | |||||
| def generate_mnist_dataset(data_path, batch_size=32, repeat_size=1, samples=None, num_parallel_workers=1, sparse=True): | |||||
| """ | |||||
| create dataset for training or testing | |||||
| """ | |||||
| # define dataset | |||||
| ds1 = ds.MnistDataset(data_path, num_samples=samples) | |||||
| # define operation parameters | |||||
| resize_height, resize_width = 32, 32 | |||||
| rescale = 1.0 / 255.0 | |||||
| shift = 0.0 | |||||
| # define map operations | |||||
| resize_op = CV.Resize((resize_height, resize_width), | |||||
| interpolation=Inter.LINEAR) | |||||
| rescale_op = CV.Rescale(rescale, shift) | |||||
| hwc2chw_op = CV.HWC2CHW() | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| # apply map operations on images | |||||
| if not sparse: | |||||
| one_hot_enco = C.OneHot(10) | |||||
| ds1 = ds1.map(input_columns="label", operations=one_hot_enco, num_parallel_workers=num_parallel_workers) | |||||
| type_cast_op = C.TypeCast(mstype.float32) | |||||
| ds1 = ds1.map(input_columns="label", operations=type_cast_op, | |||||
| num_parallel_workers=num_parallel_workers) | |||||
| ds1 = ds1.map(input_columns="image", operations=resize_op, | |||||
| num_parallel_workers=num_parallel_workers) | |||||
| ds1 = ds1.map(input_columns="image", operations=rescale_op, | |||||
| num_parallel_workers=num_parallel_workers) | |||||
| ds1 = ds1.map(input_columns="image", operations=hwc2chw_op, | |||||
| num_parallel_workers=num_parallel_workers) | |||||
| # apply DatasetOps | |||||
| buffer_size = 10000 | |||||
| ds1 = ds1.shuffle(buffer_size=buffer_size) | |||||
| ds1 = ds1.batch(batch_size, drop_remainder=True) | |||||
| ds1 = ds1.repeat(repeat_size) | |||||
| return ds1 | |||||
| def mnist_suppress_train(epoch_size=10, start_epoch=3, lr=0.05, samples=10000, mask_times=1000, | |||||
| sparse_thd=0.90, sparse_start=0.0, masklayers=None): | |||||
| """ | |||||
| local train by suppress-based privacy | |||||
| """ | |||||
| networks_l5 = LeNet5() | |||||
| suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
| end_epoch=epoch_size, | |||||
| batch_num=(int)(samples/cfg.batch_size), | |||||
| start_epoch=start_epoch, | |||||
| mask_times=mask_times, | |||||
| networks=networks_l5, | |||||
| lr=lr, | |||||
| sparse_end=sparse_thd, | |||||
| sparse_start=sparse_start, | |||||
| mask_layers=masklayers) | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.SGD(networks_l5.trainable_params(), lr) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), | |||||
| keep_checkpoint_max=10) | |||||
| # Create the SuppressModel model for training. | |||||
| model_instance = SuppressModel(network=networks_l5, | |||||
| loss_fn=net_loss, | |||||
| optimizer=net_opt, | |||||
| metrics={"Accuracy": Accuracy()}) | |||||
| model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
| # Create a Masker for Suppress training. The function of the Masker is to | |||||
| # enforce suppress operation while training. | |||||
| suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) | |||||
| mnist_path = "./MNIST_unzip/" #"../../MNIST_unzip/" | |||||
| ds_train = generate_mnist_dataset(os.path.join(mnist_path, "train"), | |||||
| batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
| directory="./trained_ckpt_file/", | |||||
| config=config_ck) | |||||
| print("============== Starting SUPP Training ==============") | |||||
| model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
| dataset_sink_mode=False) | |||||
| print("============== Starting SUPP Testing ==============") | |||||
| ckpt_file_name = 'trained_ckpt_file/checkpoint_lenet-10_1875.ckpt' | |||||
| param_dict = load_checkpoint(ckpt_file_name) | |||||
| load_param_into_net(networks_l5, param_dict) | |||||
| ds_eval = generate_mnist_dataset(os.path.join(mnist_path, 'test'), | |||||
| batch_size=cfg.batch_size) | |||||
| acc = model_instance.eval(ds_eval, dataset_sink_mode=False) | |||||
| print("============== SUPP Accuracy: %s ==============", acc) | |||||
| if __name__ == "__main__": | |||||
| # This configure can run in pynative mode | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target=cfg.device_target) | |||||
| masklayers_lenet5 = [] # determine which layer should be masked | |||||
| masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, True, 10)) | |||||
| masklayers_lenet5.append(MaskLayerDes("conv2.weight", False, True, 150)) | |||||
| masklayers_lenet5.append(MaskLayerDes("fc1.weight", True, False, -1)) | |||||
| masklayers_lenet5.append(MaskLayerDes("fc2.weight", True, False, -1)) | |||||
| masklayers_lenet5.append(MaskLayerDes("fc3.weight", True, False, 50)) | |||||
| # do suppreess privacy train, with stronger privacy protection and better performance than Differential Privacy | |||||
| mnist_suppress_train(10, 3, 0.10, 60000, 1000, 0.95, 0.0, masklayers=masklayers_lenet5) # used | |||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting, will be used in sup_privacy.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| mnist_cfg = edict({ | |||||
| 'num_classes': 10, # the number of classes of model's output | |||||
| 'epoch_size': 1, # training epochs | |||||
| 'batch_size': 32, # batch size for training | |||||
| 'image_height': 32, # the height of training samples | |||||
| 'image_width': 32, # the width of training samples | |||||
| 'save_checkpoint_steps': 1875, # the interval steps for saving checkpoint file of the model | |||||
| 'keep_checkpoint_max': 10, # the maximum number of checkpoint files would be saved | |||||
| 'device_target': 'Ascend', # device used | |||||
| 'data_path': './MNIST_unzip', # the path of training and testing data set | |||||
| 'dataset_sink_mode': False, # whether deliver all training data to device one time | |||||
| }) | |||||
| @@ -0,0 +1,27 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| This module provides Suppress Privacy feature to protect user privacy. | |||||
| """ | |||||
| from .mask_monitor.masker import SuppressMasker | |||||
| from .train.model import SuppressModel | |||||
| from .sup_ctrl.conctrl import SuppressPrivacyFactory | |||||
| from .sup_ctrl.conctrl import SuppressCtrl | |||||
| from .sup_ctrl.conctrl import MaskLayerDes | |||||
| __all__ = ['SuppressMasker', | |||||
| 'SuppressModel', | |||||
| 'SuppressPrivacyFactory', | |||||
| 'SuppressCtrl', | |||||
| 'MaskLayerDes'] | |||||
| @@ -0,0 +1,98 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Masker module of suppress-based privacy.. | |||||
| """ | |||||
| from mindspore.train.callback import Callback | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.utils._check_param import check_param_type | |||||
| from mindarmour.privacy.sup_privacy.train.model import SuppressModel | |||||
| from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'suppress masker' | |||||
| class SuppressMasker(Callback): | |||||
| """ | |||||
| Args: | |||||
| args (Union[int, float, numpy.ndarray, list, str]): Parameters | |||||
| used for creating a suppress privacy monitor. | |||||
| kwargs (Union[int, float, numpy.ndarray, list, str]): Keyword | |||||
| parameters used for creating a suppress privacy monitor. | |||||
| model (SuppressModel): SuppressModel instance. | |||||
| suppress_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
| Examples: | |||||
| networks_l5 = LeNet5() | |||||
| masklayers = [] | |||||
| masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
| suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
| end_epoch=10, | |||||
| batch_num=(int)(10000/cfg.batch_size), | |||||
| start_epoch=3, | |||||
| mask_times=100, | |||||
| networks=networks_l5, | |||||
| lr=lr, | |||||
| sparse_end=0.90, | |||||
| sparse_start=0.0, | |||||
| mask_layers=masklayers) | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
| model_instance = SuppressModel(network=networks_l5, | |||||
| loss_fn=net_loss, | |||||
| optimizer=net_opt, | |||||
| metrics={"Accuracy": Accuracy()}) | |||||
| model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
| ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
| batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
| directory="./trained_ckpt_file/", | |||||
| config=config_ck) | |||||
| model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
| dataset_sink_mode=False) | |||||
| """ | |||||
| def __init__(self, model=None, suppress_ctrl=None): | |||||
| super(SuppressMasker, self).__init__() | |||||
| self._model = check_param_type('model', model, SuppressModel) | |||||
| self._suppress_ctrl = check_param_type('suppress_ctrl', suppress_ctrl, SuppressCtrl) | |||||
| def step_end(self, run_context): | |||||
| """ | |||||
| Update mask matrix tensor used for SuppressModel instance. | |||||
| Args: | |||||
| run_context (RunContext): Include some information of the model. | |||||
| """ | |||||
| cb_params = run_context.original_args() | |||||
| cur_step = cb_params.cur_step_num | |||||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
| if self._suppress_ctrl is not None and self._model.network_end is not None: | |||||
| self._suppress_ctrl.update_status(cb_params.cur_epoch_num, cur_step, cur_step_in_epoch) | |||||
| if not self._suppress_ctrl.mask_initialized: | |||||
| raise ValueError("Not initialize network!") | |||||
| if self._suppress_ctrl.to_do_mask: | |||||
| self._suppress_ctrl.update_mask(self._suppress_ctrl.networks, cur_step) | |||||
| LOGGER.info(TAG, "suppress update") | |||||
| elif not self._suppress_ctrl.to_do_mask and self._suppress_ctrl.mask_started: | |||||
| self._suppress_ctrl.reset_zeros() | |||||
| if cur_step_in_epoch % 100 == 1: | |||||
| self._suppress_ctrl.calc_theoretical_sparse_for_conv() | |||||
| _, _, _ = self._suppress_ctrl.calc_actual_sparse_for_conv( | |||||
| self._suppress_ctrl.networks) | |||||
| @@ -0,0 +1,640 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| control function of suppress-based privacy. | |||||
| """ | |||||
| import math | |||||
| import numpy as np | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn import Cell | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.utils._check_param import check_int_positive, check_value_positive, \ | |||||
| check_value_non_negative, check_param_type | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'Suppression training.' | |||||
| class SuppressPrivacyFactory: | |||||
| """ Factory class of SuppressCtrl mechanisms""" | |||||
| def __init__(self): | |||||
| pass | |||||
| @staticmethod | |||||
| def create(policy="local_train", end_epoch=10, batch_num=2, start_epoch=3, mask_times=100, networks=None, | |||||
| lr=0.05, sparse_end=0.60, sparse_start=0.0, mask_layers=None): | |||||
| """ | |||||
| Args: | |||||
| policy (str): Training policy for suppress privacy training. "local_train" means local training. | |||||
| end_epoch (int): The last epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . | |||||
| batch_num (int): The num of batch in an epoch, should be equal to num_samples/batch_size . | |||||
| start_epoch (int): The first epoch in suppress operations, 0 < start_epoch <= end_epoch <= 100 . | |||||
| mask_times (int): The num of suppress operations. | |||||
| networks (Cell): The training network. | |||||
| lr (float): Learning rate. | |||||
| sparse_end (float): The sparsity to reach, 0.0 <= sparse_start < sparse_end < 1.0 . | |||||
| sparse_start (float): The sparsity to start, 0.0 <= sparse_start < sparse_end < 1.0 . | |||||
| mask_layers (list): Description of the training network layers that need to be suppressed. | |||||
| Returns: | |||||
| SuppressCtrl, class of Suppress Privavy Mechanism. | |||||
| Examples: | |||||
| networks_l5 = LeNet5() | |||||
| masklayers = [] | |||||
| masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
| suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
| end_epoch=10, | |||||
| batch_num=(int)(10000/cfg.batch_size), | |||||
| start_epoch=3, | |||||
| mask_times=100, | |||||
| networks=networks_l5, | |||||
| lr=lr, | |||||
| sparse_end=0.90, | |||||
| sparse_start=0.0, | |||||
| mask_layers=masklayers) | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
| model_instance = SuppressModel(network=networks_l5, | |||||
| loss_fn=net_loss, | |||||
| optimizer=net_opt, | |||||
| metrics={"Accuracy": Accuracy()}) | |||||
| model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
| ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
| batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
| directory="./trained_ckpt_file/", | |||||
| config=config_ck) | |||||
| model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
| dataset_sink_mode=False) | |||||
| """ | |||||
| if policy == "local_train": | |||||
| return SuppressCtrl(networks, end_epoch, batch_num, start_epoch, mask_times, lr, sparse_end, | |||||
| sparse_start, mask_layers) | |||||
| msg = "Only local training is supported now, federal training will be supported " \ | |||||
| "in the future. But got {}.".format(policy) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| class SuppressCtrl(Cell): | |||||
| """ | |||||
| Args: | |||||
| networks (Cell): The training network. | |||||
| end_epoch (int): The last epoch in suppress operations. | |||||
| batch_num (int): The num of grad operation in an epoch. | |||||
| mask_start_epoch (int): The first epoch in suppress operations. | |||||
| mask_times (int): The num of suppress operations. | |||||
| lr (Union[float, int]): Learning rate. | |||||
| sparse_end (Union[float, int]): The sparsity to reach. | |||||
| sparse_start (float): The sparsity to start. | |||||
| mask_layers (list): Description of those layers that need to be suppressed. | |||||
| """ | |||||
| def __init__(self, networks, end_epoch, batch_num, mask_start_epoch=3, mask_times=500, lr=0.05, | |||||
| sparse_end=0.60, | |||||
| sparse_start=0.0, | |||||
| mask_layers=None): | |||||
| super(SuppressCtrl, self).__init__() | |||||
| self.networks = check_param_type('networks', networks, Cell) | |||||
| self.mask_end_epoch = check_int_positive('end_epoch', end_epoch) | |||||
| self.batch_num = check_int_positive('batch_num', batch_num) | |||||
| self.mask_start_epoch = check_int_positive('mask_start_epoch', mask_start_epoch) | |||||
| self.mask_times = check_int_positive('mask_times', mask_times) | |||||
| self.lr = check_value_positive('lr', lr) | |||||
| self.sparse_end = check_value_non_negative('sparse_end', sparse_end) | |||||
| self.sparse_start = check_value_non_negative('sparse_start', sparse_start) | |||||
| self.mask_layers = check_param_type('mask_layers', mask_layers, list) | |||||
| self.weight_lower_bound = 0.005 # all network weight will be larger than this value | |||||
| self.sparse_vibra = 0.02 # the sparsity may have certain range of variations | |||||
| self.sparse_valid_max_weight = 0.20 # if max network weight is less than this value, suppress operation stop temporarily | |||||
| self.add_noise_thd = 0.50 # if network weight is more than this value, noise is forced | |||||
| self.noise_volume = 0.01 # noise volume 0.01 | |||||
| self.base_ground_thd = 0.0000001 # if network weight is less than this value, will be considered as 0 | |||||
| self.model = None # SuppressModel instance | |||||
| self.grads_mask_list = [] # list for Grad Mask Matrix tensor | |||||
| self.de_weight_mask_list = [] # list for weight Mask Matrix tensor | |||||
| self.to_do_mask = False # the flag means suppress operation is toggled immediately | |||||
| self.mask_started = False # the flag means suppress operation has been started | |||||
| self.mask_start_step = 0 # suppress operation is actually started at this step | |||||
| self.mask_prev_step = 0 # previous suppress operation is done at this step | |||||
| self.cur_sparse = 0.0 # current sparsity to which one suppress will get | |||||
| self.mask_all_steps = (self.mask_end_epoch-mask_start_epoch+1)*batch_num # the amount of step contained in all suppress operation | |||||
| self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation | |||||
| self.mask_initialized = False # flag means the initialization is done | |||||
| if self.mask_start_epoch > self.mask_end_epoch: | |||||
| msg = "start_epoch error: {}".format(self.mask_start_epoch) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if self.mask_end_epoch > 100: | |||||
| msg = "end_epoch error: {}".format(self.mask_end_epoch) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if self.mask_step_interval < 0: | |||||
| msg = "step_interval error: {}".format(self.mask_step_interval) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if self.sparse_end > 1.00 or self.sparse_end <= 0: | |||||
| msg = "sparse_end error: {}".format(self.sparse_end) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if self.sparse_start >= self.sparse_end: | |||||
| msg = "sparse_start error: {}".format(self.sparse_start) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| if mask_layers is not None: | |||||
| mask_layer_id = 0 | |||||
| for one_mask_layer in mask_layers: | |||||
| if not isinstance(one_mask_layer, MaskLayerDes): | |||||
| msg = "mask_layer instance error!" | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| layer_name = one_mask_layer.layer_name | |||||
| mask_layer_id2 = 0 | |||||
| for one_mask_layer_2 in mask_layers: | |||||
| if mask_layer_id != mask_layer_id2 and layer_name in one_mask_layer_2.layer_name: | |||||
| msg = "mask_layers repeat item : {} in {} and {}".format(layer_name, | |||||
| mask_layer_id, | |||||
| mask_layer_id2) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| mask_layer_id2 = mask_layer_id2 + 1 | |||||
| mask_layer_id = mask_layer_id + 1 | |||||
| if networks is not None: | |||||
| m = 0 | |||||
| for layer in networks.get_parameters(expand=True): | |||||
| one_mask_layer = None | |||||
| if mask_layers is not None: | |||||
| one_mask_layer = get_one_mask_layer(mask_layers, layer.name) | |||||
| if one_mask_layer is not None and not one_mask_layer.inited: | |||||
| one_mask_layer.inited = True | |||||
| shape = P.Shape()(layer) | |||||
| mul_mask_array = np.ones(shape, dtype=np.float32) | |||||
| grad_mask_cell = GradMaskInCell(mul_mask_array, | |||||
| one_mask_layer.is_add_noise, | |||||
| one_mask_layer.is_lower_clip, | |||||
| one_mask_layer.min_num, | |||||
| one_mask_layer.upper_bound) | |||||
| grad_mask_cell.mask_able = True | |||||
| self.grads_mask_list.append(grad_mask_cell) | |||||
| add_mask_array = np.zeros(shape, dtype=np.float32) | |||||
| de_weight_cell = DeWeightInCell(add_mask_array) | |||||
| de_weight_cell.mask_able = True | |||||
| self.de_weight_mask_list.append(de_weight_cell) | |||||
| msg = "do mask {}, {}".format(m, one_mask_layer.layer_name) | |||||
| LOGGER.info(TAG, msg) | |||||
| elif one_mask_layer is not None and one_mask_layer.inited: | |||||
| msg = "repeated match masked setting {}=>{}.".format(one_mask_layer.layer_name, layer.name) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| else: | |||||
| shape = np.shape([1]) | |||||
| mul_mask_array = np.ones(shape, dtype=np.float32) | |||||
| grad_mask_cell = GradMaskInCell(mul_mask_array, False, False, -1) | |||||
| grad_mask_cell.mask_able = False | |||||
| self.grads_mask_list.append(grad_mask_cell) | |||||
| add_mask_array = np.zeros(shape, dtype=np.float32) | |||||
| de_weight_cell = DeWeightInCell(add_mask_array) | |||||
| de_weight_cell.mask_able = False | |||||
| self.de_weight_mask_list.append(de_weight_cell) | |||||
| m += 1 | |||||
| self.mask_initialized = True | |||||
| msg = "init SuppressCtrl by networks" | |||||
| LOGGER.info(TAG, msg) | |||||
| msg = "complete init mask for lenet5.step_interval: {}".format(self.mask_step_interval) | |||||
| LOGGER.info(TAG, msg) | |||||
| for one_mask_layer in mask_layers: | |||||
| if not one_mask_layer.inited: | |||||
| msg = "can't match this mask layer: {} ".format(one_mask_layer.layer_name) | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| def update_status(self, cur_epoch, cur_step, cur_step_in_epoch): | |||||
| """ | |||||
| Update the suppress operation status. | |||||
| Args: | |||||
| cur_epoch (int): Current epoch of the whole training process. | |||||
| cur_step (int): Current step of the whole training process. | |||||
| cur_step_in_epoch (int): Current step of the current epoch. | |||||
| """ | |||||
| if not self.mask_initialized: | |||||
| self.mask_started = False | |||||
| elif (self.mask_start_epoch <= cur_epoch <= self.mask_end_epoch) or self.mask_started: | |||||
| if not self.mask_started: | |||||
| self.mask_started = True | |||||
| self.mask_start_step = cur_step | |||||
| if cur_step >= (self.mask_prev_step + self.mask_step_interval): | |||||
| self.mask_prev_step = cur_step | |||||
| self.to_do_mask = True | |||||
| # execute the last suppression operation | |||||
| elif cur_epoch == self.mask_end_epoch and cur_step_in_epoch == self.batch_num-2: | |||||
| self.mask_prev_step = cur_step | |||||
| self.to_do_mask = True | |||||
| else: | |||||
| self.to_do_mask = False | |||||
| else: | |||||
| self.to_do_mask = False | |||||
| self.mask_started = False | |||||
| def update_mask(self, networks, cur_step): | |||||
| """ | |||||
| Update add mask arrays and multiply mask arrays of network layers. | |||||
| Args: | |||||
| networks (Cell): The training network. | |||||
| cur_step (int): Current epoch of the whole training process. | |||||
| """ | |||||
| if self.sparse_end <= 0.0: | |||||
| return | |||||
| self.cur_sparse = self.sparse_end +\ | |||||
| (self.sparse_start - self.sparse_end)*\ | |||||
| math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) | |||||
| m = 0 | |||||
| for layer in networks.get_parameters(expand=True): | |||||
| if self.grads_mask_list[m].mask_able: | |||||
| weight_array = layer.data.asnumpy() | |||||
| weight_avg = np.mean(weight_array) | |||||
| weight_array_flat = weight_array.flatten() | |||||
| weight_array_flat_abs = np.abs(weight_array_flat) | |||||
| weight_abs_avg = np.mean(weight_array_flat_abs) | |||||
| weight_array_flat_abs.sort() | |||||
| len_array = weight_array.size | |||||
| weight_abs_max = np.max(weight_array_flat_abs) | |||||
| if m == 0 and weight_abs_max < self.sparse_valid_max_weight: | |||||
| msg = "give up this masking .." | |||||
| LOGGER.info(TAG, msg) | |||||
| return | |||||
| if self.grads_mask_list[m].min_num > 0: | |||||
| sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, | |||||
| self.cur_sparse, m) | |||||
| else: | |||||
| actual_stop_pos = int(len_array * self.cur_sparse) | |||||
| sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] | |||||
| self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) | |||||
| msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( | |||||
| layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, | |||||
| weight_abs_max, weight_avg, weight_abs_avg) | |||||
| LOGGER.info(TAG, msg) | |||||
| m = m + 1 | |||||
| def update_mask_layer(self, weight_array_flat, sparse_weight_thd, sparse_stop_pos, weight_abs_max, layer_index): | |||||
| """ | |||||
| Update add mask arrays and multiply mask arrays of one single layer. | |||||
| Args: | |||||
| weight_array (numpy.ndarray): The weight array of layer's parameters. | |||||
| sparse_weight_thd (float): The weight threshold of sparse operation. | |||||
| sparse_stop_pos (int): The maximum number of elements to be suppressed. | |||||
| weight_abs_max (float): The maximum absolute value of weights. | |||||
| layer_index (int): The index of target layer. | |||||
| """ | |||||
| grad_mask_cell = self.grads_mask_list[layer_index] | |||||
| mul_mask_array_flat = grad_mask_cell.mul_mask_array_flat | |||||
| de_weight_cell = self.de_weight_mask_list[layer_index] | |||||
| add_mask_array_flat = de_weight_cell.add_mask_array_flat | |||||
| min_num = grad_mask_cell.min_num | |||||
| is_add_noise = grad_mask_cell.is_add_noise | |||||
| is_lower_clip = grad_mask_cell.is_lower_clip | |||||
| upper_bound = grad_mask_cell.upper_bound | |||||
| if not self.grads_mask_list[layer_index].mask_able: | |||||
| return | |||||
| m = 0 | |||||
| n = 0 | |||||
| p = 0 | |||||
| q = 0 | |||||
| # add noise on weights if not masking or clipping. | |||||
| weight_noise_bound = min(self.add_noise_thd, max(self.noise_volume*10, weight_abs_max*0.75)) | |||||
| for i in range(0, weight_array_flat.size): | |||||
| if abs(weight_array_flat[i]) <= sparse_weight_thd: | |||||
| if m < weight_array_flat.size - min_num and m < sparse_stop_pos: | |||||
| # to mask | |||||
| mul_mask_array_flat[i] = 0.0 | |||||
| add_mask_array_flat[i] = weight_array_flat[i] / self.lr | |||||
| m = m + 1 | |||||
| else: | |||||
| # not mask | |||||
| if weight_array_flat[i] > 0.0: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr | |||||
| else: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr | |||||
| p = p + 1 | |||||
| elif is_lower_clip and abs(weight_array_flat[i]) <= \ | |||||
| self.weight_lower_bound and sparse_weight_thd > self.weight_lower_bound*0.5: | |||||
| # not mask | |||||
| mul_mask_array_flat[i] = 1.0 | |||||
| if weight_array_flat[i] > 0.0: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] - self.weight_lower_bound) / self.lr | |||||
| else: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] + self.weight_lower_bound) / self.lr | |||||
| p = p + 1 | |||||
| elif abs(weight_array_flat[i]) > upper_bound: | |||||
| mul_mask_array_flat[i] = 1.0 | |||||
| if weight_array_flat[i] > 0.0: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] - upper_bound) / self.lr | |||||
| else: | |||||
| add_mask_array_flat[i] = (weight_array_flat[i] + upper_bound) / self.lr | |||||
| n = n + 1 | |||||
| else: | |||||
| # not mask | |||||
| mul_mask_array_flat[i] = 1.0 | |||||
| if is_add_noise and abs(weight_array_flat[i]) > weight_noise_bound > 0.0: | |||||
| # add noise | |||||
| add_mask_array_flat[i] = np.random.uniform(-self.noise_volume, self.noise_volume) / self.lr | |||||
| q = q + 1 | |||||
| else: | |||||
| add_mask_array_flat[i] = 0.0 | |||||
| grad_mask_cell.update() | |||||
| de_weight_cell.update() | |||||
| msg = "Dimension of mask tensor is {}D, which located in the {}-th layer of the network. \n The number of " \ | |||||
| "suppressed elements, max-clip elements, min-clip elements and noised elements are {}, {}, {}, {}"\ | |||||
| .format(len(grad_mask_cell.mul_mask_array_shape), layer_index, m, n, p, q) | |||||
| LOGGER.info(TAG, msg) | |||||
| def calc_sparse_thd(self, array_flat, sparse_value, layer_index): | |||||
| """ | |||||
| Calculate the suppression threshold of one weight array. | |||||
| Args: | |||||
| array_flat (numpy.ndarray): The flattened weight array. | |||||
| sparse_value (float): The target sparse value of weight array. | |||||
| Returns: | |||||
| - float, the sparse threshold of this array. | |||||
| - int, the number of weight elements to be suppressed. | |||||
| - int, the larger number of weight elements to be suppressed. | |||||
| """ | |||||
| size = len(array_flat) | |||||
| sparse_max_thd = 1.0 - min(self.grads_mask_list[layer_index].min_num, size) / size | |||||
| pos = int(size*min(sparse_max_thd, sparse_value)) | |||||
| thd = array_flat[pos] | |||||
| farther_stop_pos = int(size*min(sparse_max_thd, max(0, sparse_value + self.sparse_vibra / 2.0))) | |||||
| return thd, pos, farther_stop_pos | |||||
| def reset_zeros(self): | |||||
| """ | |||||
| Set add mask arrays to be zero. | |||||
| """ | |||||
| for de_weight_cell in self.de_weight_mask_list: | |||||
| de_weight_cell.reset_zeros() | |||||
| def calc_theoretical_sparse_for_conv(self): | |||||
| """ | |||||
| Compute actually sparsity of mask matrix for conv1 layer and conv2 layer. | |||||
| """ | |||||
| array_mul_mask_flat_conv1 = self.grads_mask_list[0].mul_mask_array_flat | |||||
| array_mul_mask_flat_conv2 = self.grads_mask_list[1].mul_mask_array_flat | |||||
| sparse = 0.0 | |||||
| sparse_value_1 = 0.0 | |||||
| sparse_value_2 = 0.0 | |||||
| full = 0.0 | |||||
| full_conv1 = 0.0 | |||||
| full_conv2 = 0.0 | |||||
| for i in range(0, array_mul_mask_flat_conv1.size): | |||||
| full += 1.0 | |||||
| full_conv1 += 1.0 | |||||
| if array_mul_mask_flat_conv1[i] <= 0.0: | |||||
| sparse += 1.0 | |||||
| sparse_value_1 += 1.0 | |||||
| for i in range(0, array_mul_mask_flat_conv2.size): | |||||
| full = full + 1.0 | |||||
| full_conv2 = full_conv2 + 1.0 | |||||
| if array_mul_mask_flat_conv2[i] <= 0.0: | |||||
| sparse = sparse + 1.0 | |||||
| sparse_value_2 += 1.0 | |||||
| sparse = sparse/full | |||||
| sparse_value_1 = sparse_value_1/full_conv1 | |||||
| sparse_value_2 = sparse_value_2/full_conv2 | |||||
| msg = "conv sparse mask={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) | |||||
| LOGGER.info(TAG, msg) | |||||
| return sparse, sparse_value_1, sparse_value_2 | |||||
| def calc_actual_sparse_for_conv(self, networks): | |||||
| """ | |||||
| Compute actually sparsity of network for conv1 layer and conv2 layer. | |||||
| Args: | |||||
| networks (Cell): The training network. | |||||
| """ | |||||
| sparse = 0.0 | |||||
| sparse_value_1 = 0.0 | |||||
| sparse_value_2 = 0.0 | |||||
| full = 0.0 | |||||
| full_conv1 = 0.0 | |||||
| full_conv2 = 0.0 | |||||
| array_cur_conv1 = np.ones(np.shape([1]), dtype=np.float32) | |||||
| array_cur_conv2 = np.ones(np.shape([1]), dtype=np.float32) | |||||
| for layer in networks.get_parameters(expand=True): | |||||
| if "conv1.weight" in layer.name: | |||||
| array_cur_conv1 = layer.data.asnumpy() | |||||
| if "conv2.weight" in layer.name: | |||||
| array_cur_conv2 = layer.data.asnumpy() | |||||
| array_mul_mask_flat_conv1 = array_cur_conv1.flatten() | |||||
| array_mul_mask_flat_conv2 = array_cur_conv2.flatten() | |||||
| for i in range(0, array_mul_mask_flat_conv1.size): | |||||
| full += 1.0 | |||||
| full_conv1 += 1.0 | |||||
| if abs(array_mul_mask_flat_conv1[i]) <= self.base_ground_thd: | |||||
| sparse += 1.0 | |||||
| sparse_value_1 += 1.0 | |||||
| for i in range(0, array_mul_mask_flat_conv2.size): | |||||
| full = full + 1.0 | |||||
| full_conv2 = full_conv2 + 1.0 | |||||
| if abs(array_mul_mask_flat_conv2[i]) <= self.base_ground_thd: | |||||
| sparse = sparse + 1.0 | |||||
| sparse_value_2 += 1.0 | |||||
| sparse = sparse / full | |||||
| sparse_value_1 = sparse_value_1 / full_conv1 | |||||
| sparse_value_2 = sparse_value_2 / full_conv2 | |||||
| msg = "conv sparse fact={}, sparse_1={}, sparse_2={}".format(sparse, sparse_value_1, sparse_value_2) | |||||
| LOGGER.info(TAG, msg) | |||||
| return sparse, sparse_value_1, sparse_value_2 | |||||
| def calc_actual_sparse_for_fc1(self, networks): | |||||
| self.calc_actual_sparse_for_layer(networks, "fc1.weight") | |||||
| def calc_actual_sparse_for_layer(self, networks, layer_name): | |||||
| """ | |||||
| Compute actually sparsity of one network layer | |||||
| Args: | |||||
| networks (Cell): The training network. | |||||
| layer_name (str): The name of target layer. | |||||
| """ | |||||
| check_param_type('networks', networks, Cell) | |||||
| check_param_type('layer_name', layer_name, str) | |||||
| sparse = 0.0 | |||||
| full = 0.0 | |||||
| array_cur = None | |||||
| for layer in networks.get_parameters(expand=True): | |||||
| if layer_name in layer.name: | |||||
| array_cur = layer.data.asnumpy() | |||||
| if array_cur is None: | |||||
| msg = "no such layer to calc sparse: {} ".format(layer_name) | |||||
| LOGGER.info(TAG, msg) | |||||
| return | |||||
| array_cur_flat = array_cur.flatten() | |||||
| for i in range(0, array_cur_flat.size): | |||||
| full += 1.0 | |||||
| if abs(array_cur_flat[i]) <= self.base_ground_thd: | |||||
| sparse += 1.0 | |||||
| sparse = sparse / full | |||||
| msg = "{} sparse fact={} ".format(layer_name, sparse) | |||||
| LOGGER.info(TAG, msg) | |||||
| def get_one_mask_layer(mask_layers, layer_name): | |||||
| """ | |||||
| Returns the layer definitions that need to be suppressed. | |||||
| Args: | |||||
| mask_layers (list): The layers that need to be suppressed. | |||||
| layer_name (str): The name of target layer. | |||||
| Returns: | |||||
| Union[MaskLayerDes, None], the layer definitions that need to be suppressed. | |||||
| """ | |||||
| for each_mask_layer in mask_layers: | |||||
| if each_mask_layer.layer_name in layer_name: | |||||
| return each_mask_layer | |||||
| return None | |||||
| class MaskLayerDes: | |||||
| """ | |||||
| Describe the layer that need to be suppressed. | |||||
| Args: | |||||
| layer_name (str): Layer name, get the name of one layer as following: | |||||
| for layer in networks.get_parameters(expand=True): | |||||
| if layer.name == "conv": ... | |||||
| is_add_noise (bool): If True, the weight of this layer can add noise. | |||||
| If False, the weight of this layer can not add noise. | |||||
| is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. | |||||
| If False, the weights of this layer won't be clipped. | |||||
| min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. | |||||
| upper_bound (float): max value of weight in this layer, default value is 1.20 . | |||||
| """ | |||||
| def __init__(self, layer_name, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): | |||||
| self.layer_name = check_param_type('layer_name', layer_name, str) | |||||
| self.is_add_noise = check_param_type('is_add_noise', is_add_noise, bool) | |||||
| self.is_lower_clip = check_param_type('is_lower_clip', is_lower_clip, bool) | |||||
| self.min_num = check_param_type('min_num', min_num, int) | |||||
| self.upper_bound = check_value_positive('upper_bound', upper_bound) | |||||
| self.inited = False | |||||
| class GradMaskInCell(Cell): | |||||
| """ | |||||
| Define the mask matrix for gradients masking. | |||||
| Args: | |||||
| array (numpy.ndarray): The mask array. | |||||
| is_add_noise (bool): If True, the weight of this layer can add noise. | |||||
| If False, the weight of this layer can not add noise. | |||||
| is_lower_clip (bool): If true, the weights of this layer would be clipped to greater than an lower bound value. | |||||
| If False, the weights of this layer won't be clipped. | |||||
| min_num (int): The number of weights left that not be suppressed, which need to be greater than 0. | |||||
| upper_bound (float): max value of weight in this layer, default value is 1.20 | |||||
| """ | |||||
| def __init__(self, array, is_add_noise, is_lower_clip, min_num, upper_bound=1.20): | |||||
| super(GradMaskInCell, self).__init__() | |||||
| self.mul_mask_array_shape = array.shape | |||||
| mul_mask_array = array.copy() | |||||
| self.mul_mask_array_flat = mul_mask_array.flatten() | |||||
| self.mul_mask_tensor = Tensor(array, mstype.float32) | |||||
| self.mask_able = False | |||||
| self.is_add_noise = is_add_noise | |||||
| self.is_lower_clip = is_lower_clip | |||||
| self.min_num = min_num | |||||
| self.upper_bound = check_value_positive('upper_bound', upper_bound) | |||||
| def construct(self): | |||||
| """ | |||||
| Return the mask matrix for optimization. | |||||
| """ | |||||
| return self.mask_able, self.mul_mask_tensor | |||||
| def update(self): | |||||
| """ | |||||
| Update the mask tensor. | |||||
| """ | |||||
| self.mul_mask_tensor = Tensor(self.mul_mask_array_flat.reshape(self.mul_mask_array_shape), mstype.float32) | |||||
| class DeWeightInCell(Cell): | |||||
| """ | |||||
| Define the mask matrix for de-weight masking. | |||||
| Args: | |||||
| array (numpy.ndarray): The mask array. | |||||
| """ | |||||
| def __init__(self, array): | |||||
| super(DeWeightInCell, self).__init__() | |||||
| self.add_mask_array_shape = array.shape | |||||
| add_mask_array = array.copy() | |||||
| self.add_mask_array_flat = add_mask_array.flatten() | |||||
| self.add_mask_tensor = Tensor(array, mstype.float32) | |||||
| self.mask_able = False | |||||
| self.zero_mask_tensor = Tensor(np.zeros(array.shape, np.float32), mstype.float32) | |||||
| self.just_update = -1.0 | |||||
| def construct(self): | |||||
| """ | |||||
| Return the mask matrix for optimization. | |||||
| """ | |||||
| if self.just_update > 0.0: | |||||
| return self.mask_able, self.add_mask_tensor | |||||
| return self.mask_able, self.zero_mask_tensor | |||||
| def update(self): | |||||
| """ | |||||
| Update the mask tensor. | |||||
| """ | |||||
| self.just_update = 1.0 | |||||
| self.add_mask_tensor = Tensor(self.add_mask_array_flat.reshape(self.add_mask_array_shape), mstype.float32) | |||||
| def reset_zeros(self): | |||||
| """ | |||||
| Make the de-weight operation expired. | |||||
| """ | |||||
| self.just_update = -1.0 | |||||
| @@ -0,0 +1,325 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| suppress-basd privacy model. | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| from mindspore.train.model import Model | |||||
| from mindspore._checkparam import Validator as validator | |||||
| from mindspore._checkparam import Rel | |||||
| from mindspore.train.amp import _config_level | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell | |||||
| from mindspore.parallel._utils import _get_parallel_mode | |||||
| from mindspore.train.model import ParallelMode | |||||
| from mindspore.train.amp import _do_keep_batchnorm_fp32 | |||||
| from mindspore.train.amp import _add_loss_network | |||||
| from mindspore import nn | |||||
| from mindspore import context | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.parallel._utils import _get_gradients_mean | |||||
| from mindspore.parallel._utils import _get_device_num | |||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| from mindspore.nn import Cell | |||||
| from mindarmour.utils._check_param import check_param_type | |||||
| from mindarmour.utils.logger import LogUtil | |||||
| from mindarmour.privacy.sup_privacy.sup_ctrl.conctrl import SuppressCtrl | |||||
| LOGGER = LogUtil.get_instance() | |||||
| TAG = 'Mask model' | |||||
| GRADIENT_CLIP_TYPE = 1 | |||||
| _grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
| _reciprocal = P.Reciprocal() | |||||
| @_grad_scale.register("Tensor", "Tensor") | |||||
| def tensor_grad_scale(scale, grad): | |||||
| """ grad scaling """ | |||||
| return grad*F.cast(_reciprocal(scale), F.dtype(grad)) | |||||
| class SuppressModel(Model): | |||||
| """ | |||||
| This class is overload mindspore.train.model.Model. | |||||
| Args: | |||||
| network (Cell): The training network. | |||||
| loss_fn (Cell): Computes softmax cross entropy between logits and labels. | |||||
| optimizer (Optimizer): optimizer instance. | |||||
| metrics (Union[dict, set]): Calculates the accuracy for classification and multilabel data. | |||||
| kwargs: Keyword parameters used for creating a suppress model. | |||||
| Examples: | |||||
| networks_l5 = LeNet5() | |||||
| masklayers = [] | |||||
| masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
| suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
| end_epoch=10, | |||||
| batch_num=(int)(10000/cfg.batch_size), | |||||
| start_epoch=3, | |||||
| mask_times=100, | |||||
| networks=networks_l5, | |||||
| lr=lr, | |||||
| sparse_end=0.90, | |||||
| sparse_start=0.0, | |||||
| mask_layers=masklayers) | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
| model_instance = SuppressModel(network=networks_l5, | |||||
| loss_fn=net_loss, | |||||
| optimizer=net_opt, | |||||
| metrics={"Accuracy": Accuracy()}) | |||||
| model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
| ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
| batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
| directory="./trained_ckpt_file/", | |||||
| config=config_ck) | |||||
| model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
| dataset_sink_mode=False) | |||||
| """ | |||||
| def __init__(self, | |||||
| network=None, | |||||
| **kwargs): | |||||
| check_param_type('networks', network, Cell) | |||||
| self.network_end = None | |||||
| self._train_one_step = None | |||||
| super(SuppressModel, self).__init__(network, **kwargs) | |||||
| def link_suppress_ctrl(self, suppress_pri_ctrl): | |||||
| """ | |||||
| Link self and SuppressCtrl instance. | |||||
| Args: | |||||
| suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
| """ | |||||
| check_param_type('suppress_pri_ctrl', suppress_pri_ctrl, Cell) | |||||
| if not isinstance(suppress_pri_ctrl, SuppressCtrl): | |||||
| msg = "SuppressCtrl instance error!" | |||||
| LOGGER.error(TAG, msg) | |||||
| raise ValueError(msg) | |||||
| suppress_pri_ctrl.model = self | |||||
| if self._train_one_step is not None: | |||||
| self._train_one_step.link_suppress_ctrl(suppress_pri_ctrl) | |||||
| def _build_train_network(self): | |||||
| """Build train network""" | |||||
| network = self._network | |||||
| ms_mode = context.get_context("mode") | |||||
| if ms_mode != context.PYNATIVE_MODE: | |||||
| raise ValueError("Only PYNATIVE_MODE is supported for suppress privacy now.") | |||||
| if self._optimizer: | |||||
| network = self._amp_build_train_network(network, | |||||
| self._optimizer, | |||||
| self._loss_fn, | |||||
| level=self._amp_level, | |||||
| keep_batchnorm_fp32=self._keep_bn_fp32) | |||||
| else: | |||||
| raise ValueError("_optimizer is none") | |||||
| self._train_one_step = network | |||||
| if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, | |||||
| ParallelMode.AUTO_PARALLEL): | |||||
| network.set_auto_parallel() | |||||
| self.network_end = self._train_one_step.network | |||||
| return network | |||||
| def _amp_build_train_network(self, network, optimizer, loss_fn=None, | |||||
| level='O0', **kwargs): | |||||
| """ | |||||
| Build the mixed precision training cell automatically. | |||||
| Args: | |||||
| network (Cell): Definition of the network. | |||||
| loss_fn (Union[None, Cell]): Definition of the loss_fn. If None, | |||||
| the `network` should have the loss inside. Default: None. | |||||
| optimizer (Optimizer): Optimizer to update the Parameter. | |||||
| level (str): Supports [O0, O2]. Default: "O0". | |||||
| - O0: Do not change. | |||||
| - O2: Cast network to float16, keep batchnorm and `loss_fn` | |||||
| (if set) run in float32, using dynamic loss scale. | |||||
| cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` | |||||
| or `mstype.float32`. If set to `mstype.float16`, use `float16` | |||||
| mode to train. If set, overwrite the level setting. | |||||
| keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, | |||||
| overwrite the level setting. | |||||
| loss_scale_manager (Union[None, LossScaleManager]): If None, not | |||||
| scale the loss, or else scale the loss by LossScaleManager. | |||||
| If set, overwrite the level setting. | |||||
| """ | |||||
| validator.check_value_type('network', network, nn.Cell, None) | |||||
| validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) | |||||
| validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) | |||||
| self._check_kwargs(kwargs) | |||||
| config = dict(_config_level[level], **kwargs) | |||||
| config = edict(config) | |||||
| if config.cast_model_type == mstype.float16: | |||||
| network.to_float(mstype.float16) | |||||
| if config.keep_batchnorm_fp32: | |||||
| _do_keep_batchnorm_fp32(network) | |||||
| if loss_fn: | |||||
| network = _add_loss_network(network, loss_fn, | |||||
| config.cast_model_type) | |||||
| if _get_parallel_mode() in ( | |||||
| ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): | |||||
| network = _VirtualDatasetCell(network) | |||||
| loss_scale = 1.0 | |||||
| if config.loss_scale_manager is not None: | |||||
| print("----model config have loss scale manager !") | |||||
| network = TrainOneStepCell(network, optimizer, sens=loss_scale).set_train() | |||||
| return network | |||||
| class _TupleAdd(nn.Cell): | |||||
| """ | |||||
| Add two tuple of data. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(_TupleAdd, self).__init__() | |||||
| self.add = P.TensorAdd() | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, input1, input2): | |||||
| """Add two tuple of data.""" | |||||
| out = self.hyper_map(self.add, input1, input2) | |||||
| return out | |||||
| class _TupleMul(nn.Cell): | |||||
| """ | |||||
| Mul two tuple of data. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(_TupleMul, self).__init__() | |||||
| self.mul = P.Mul() | |||||
| self.hyper_map = C.HyperMap() | |||||
| def construct(self, input1, input2): | |||||
| """Add two tuple of data.""" | |||||
| out = self.hyper_map(self.mul, input1, input2) | |||||
| #print(out) | |||||
| return out | |||||
| # come from nn.cell_wrapper.TrainOneStepCell | |||||
| class TrainOneStepCell(Cell): | |||||
| r""" | |||||
| Network training package class. | |||||
| Wraps the network with an optimizer. The resulting Cell be trained with input data and label. | |||||
| Backward graph will be created in the construct function to do parameter updating. Different | |||||
| parallel modes are available to run the training. | |||||
| Args: | |||||
| network (Cell): The training network. | |||||
| optimizer (Cell): Optimizer for updating the weights. | |||||
| sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. | |||||
| Inputs: | |||||
| - **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||||
| - **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`. | |||||
| Outputs: | |||||
| Tensor, a scalar Tensor with shape :math:`()`. | |||||
| """ | |||||
| def __init__(self, network, optimizer, sens=1.0): | |||||
| super(TrainOneStepCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_grad() | |||||
| self.network.add_flags(defer_inline=True) | |||||
| self.weights = optimizer.parameters | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) # for mindspore 0.7x | |||||
| self.sens = sens | |||||
| self.reducer_flag = False | |||||
| self.grad_reducer = None | |||||
| self._tuple_add = _TupleAdd() | |||||
| self._tuple_mul = _TupleMul() | |||||
| parallel_mode = _get_parallel_mode() | |||||
| if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||||
| self.reducer_flag = True | |||||
| if self.reducer_flag: | |||||
| mean = _get_gradients_mean() # for mindspore 0.7x | |||||
| degree = _get_device_num() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.do_privacy = False | |||||
| self.grad_mask_tup = () # tuple containing grad_mask(cell) | |||||
| self.de_weight_tup = () # tuple containing de_weight(cell) | |||||
| self._suppress_pri_ctrl = None | |||||
| def link_suppress_ctrl(self, suppress_pri_ctrl): | |||||
| """ | |||||
| Set Suppress Mask for grad_mask_tup and de_weight_tup. | |||||
| Args: | |||||
| suppress_pri_ctrl (SuppressCtrl): SuppressCtrl instance. | |||||
| """ | |||||
| self._suppress_pri_ctrl = suppress_pri_ctrl | |||||
| if self._suppress_pri_ctrl.grads_mask_list: | |||||
| for grad_mask_cell in self._suppress_pri_ctrl.grads_mask_list: | |||||
| self.grad_mask_tup += (grad_mask_cell,) | |||||
| self.do_privacy = True | |||||
| for de_weight_cell in self._suppress_pri_ctrl.de_weight_mask_list: | |||||
| self.de_weight_tup += (de_weight_cell,) | |||||
| else: | |||||
| self.do_privacy = False | |||||
| def construct(self, data, label): | |||||
| """ | |||||
| Construct a compute flow. | |||||
| """ | |||||
| weights = self.weights | |||||
| loss = self.network(data, label) | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| grads = self.grad(self.network, weights)(data, label, sens) | |||||
| new_grads = () | |||||
| m = 0 | |||||
| for grad in grads: | |||||
| if self.do_privacy and self._suppress_pri_ctrl.mask_started: | |||||
| enable_mask, grad_mask = self.grad_mask_tup[m]() | |||||
| enable_de_weight, de_weight_array = self.de_weight_tup[m]() | |||||
| if enable_mask and enable_de_weight: | |||||
| grad_n = self._tuple_add(de_weight_array, self._tuple_mul(grad, grad_mask)) | |||||
| new_grads = new_grads + (grad_n,) | |||||
| else: | |||||
| new_grads = new_grads + (grad,) | |||||
| else: | |||||
| new_grads = new_grads + (grad,) | |||||
| m = m + 1 | |||||
| if self.reducer_flag: | |||||
| new_grads = self.grad_reducer(new_grads) | |||||
| return F.depend(loss, self.optimizer(new_grads)) | |||||
| @@ -1,4 +1,4 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | # you may not use this file except in compliance with the License. | ||||
| @@ -12,6 +12,6 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| """ | """ | ||||
| This package includes unit tests for differential-privacy training and | |||||
| privacy breach estimation. | |||||
| This package includes unit tests for differential-privacy training, | |||||
| suppress-privacy training and privacy breach estimation. | |||||
| """ | """ | ||||
| @@ -0,0 +1,16 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| This package includes unit tests for suppress-privacy training. | |||||
| """ | |||||
| @@ -0,0 +1,85 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ | |||||
| Suppress Privacy model test. | |||||
| """ | |||||
| import pytest | |||||
| import numpy as np | |||||
| from mindspore import nn | |||||
| from mindspore import context | |||||
| from mindspore.train.callback import ModelCheckpoint | |||||
| from mindspore.train.callback import CheckpointConfig | |||||
| from mindspore.train.callback import LossMonitor | |||||
| from mindspore.nn.metrics import Accuracy | |||||
| import mindspore.dataset as ds | |||||
| from ut.python.utils.mock_net import Net as LeNet5 | |||||
| from mindarmour.privacy.sup_privacy import SuppressModel | |||||
| from mindarmour.privacy.sup_privacy import SuppressMasker | |||||
| from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
| from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
| def dataset_generator(batch_size, batches): | |||||
| """mock training data.""" | |||||
| data = np.random.random((batches*batch_size, 1, 32, 32)).astype( | |||||
| np.float32) | |||||
| label = np.random.randint(0, 10, batches*batch_size).astype(np.int32) | |||||
| for i in range(batches): | |||||
| yield data[i*batch_size:(i + 1)*batch_size],\ | |||||
| label[i*batch_size:(i + 1)*batch_size] | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_card | |||||
| @pytest.mark.component_mindarmour | |||||
| def test_suppress_model_with_pynative_mode(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") | |||||
| networks_l5 = LeNet5() | |||||
| epochs = 5 | |||||
| batch_num = 10 | |||||
| batch_size = 32 | |||||
| mask_times = 10 | |||||
| lr = 0.01 | |||||
| masklayers_lenet5 = [] | |||||
| masklayers_lenet5.append(MaskLayerDes("conv1.weight", False, False, -1)) | |||||
| suppress_ctrl_instance = SuppressPrivacyFactory().create(policy="local_train", | |||||
| end_epoch=epochs, | |||||
| batch_num=batch_num, | |||||
| start_epoch=1, | |||||
| mask_times=mask_times, | |||||
| networks=networks_l5, | |||||
| lr=lr, | |||||
| sparse_end=0.50, | |||||
| sparse_start=0.0, | |||||
| mask_layers=masklayers_lenet5) | |||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
| net_opt = nn.SGD(networks_l5.trainable_params(), lr) | |||||
| model_instance = SuppressModel( | |||||
| network=networks_l5, | |||||
| loss_fn=net_loss, | |||||
| optimizer=net_opt, | |||||
| metrics={"Accuracy": Accuracy()}) | |||||
| model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
| suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
| directory="./trained_ckpt_file/", | |||||
| config=config_ck) | |||||
| ds_train = ds.GeneratorDataset(dataset_generator(batch_size, batch_num), ['data', 'label']) | |||||
| model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
| dataset_sink_mode=False) | |||||