| @@ -134,6 +134,7 @@ class SuppressCtrl(Cell): | |||
| self.mask_all_steps = (end_epoch - start_epoch + 1)*batch_num # the amount of step contained in all suppress operation | |||
| self.mask_step_interval = self.mask_all_steps/mask_times # the amount of step contaied in one suppress operation | |||
| self.mask_initialized = False # flag means the initialization is done | |||
| self.grad_idx_map = [] | |||
| if self.lr > 0.5: | |||
| msg = "learning rate should not be greater than 0.5, but got {}".format(self.lr) | |||
| @@ -210,6 +211,8 @@ class SuppressCtrl(Cell): | |||
| de_weight_cell.mask_able = False | |||
| self.de_weight_mask_list.append(de_weight_cell) | |||
| self.grad_idx_map.append(-1) | |||
| m = 0 | |||
| for layer in networks.get_parameters(expand=True): | |||
| one_mask_layer = None | |||
| @@ -231,6 +234,7 @@ class SuppressCtrl(Cell): | |||
| de_weight_cell = DeWeightInCell(add_mask_array) | |||
| de_weight_cell.mask_able = True | |||
| self.de_weight_mask_list[one_mask_layer.grad_idx] = de_weight_cell | |||
| self.grad_idx_map[m] = one_mask_layer.grad_idx | |||
| msg = "do mask {}, {}, {}".format(m, one_mask_layer.layer_name, one_mask_layer.grad_idx) | |||
| LOGGER.info(TAG, msg) | |||
| elif one_mask_layer is not None and one_mask_layer.inited: | |||
| @@ -294,7 +298,10 @@ class SuppressCtrl(Cell): | |||
| math.pow((1.0 - (cur_step + 0.0 - self.mask_start_step) / self.mask_all_steps), 3) | |||
| m = 0 | |||
| for layer in networks.get_parameters(expand=True): | |||
| if self.grads_mask_list[m].mask_able: | |||
| grad_idx = self.grad_idx_map[m] | |||
| if grad_idx < 0: | |||
| continue | |||
| if self.grads_mask_list[grad_idx].mask_able: | |||
| weight_array = layer.data.asnumpy() | |||
| weight_avg = np.mean(weight_array) | |||
| weight_array_flat = weight_array.flatten() | |||
| @@ -307,14 +314,14 @@ class SuppressCtrl(Cell): | |||
| msg = "give up this masking .." | |||
| LOGGER.info(TAG, msg) | |||
| return | |||
| if self.grads_mask_list[m].min_num > 0: | |||
| if self.grads_mask_list[grad_idx].min_num > 0: | |||
| sparse_weight_thd, _, actual_stop_pos = self.calc_sparse_thd(weight_array_flat_abs, | |||
| self.cur_sparse, m) | |||
| self.cur_sparse, grad_idx) | |||
| else: | |||
| actual_stop_pos = int(len_array * self.cur_sparse) | |||
| sparse_weight_thd = weight_array_flat_abs[actual_stop_pos] | |||
| self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, m) | |||
| self.update_mask_layer(weight_array_flat, sparse_weight_thd, actual_stop_pos, weight_abs_max, grad_idx) | |||
| msg = "{} len={}, sparse={}, current sparse thd={}, max={}, avg={}, avg_abs={} \n".format( | |||
| layer.name, len_array, actual_stop_pos/len_array, sparse_weight_thd, | |||
| @@ -570,8 +577,12 @@ class MaskLayerDes: | |||
| Args: | |||
| layer_name (str): Layer name, get the name of one layer as following: | |||
| for layer in networks.get_parameters(expand=True): | |||
| if layer.name == "conv": ... | |||
| .. code-block:: | |||
| for layer in networks.get_parameters(expand=True): | |||
| if layer.name == "conv": ... | |||
| grad_idx (int): Grad layer index, get mask layer's index in grad tuple.You can refer to the construct function | |||
| of TrainOneStepCell in mindarmour/privacy/sup_privacy/train/model.py to get the index of some specified | |||
| grad layers (print in PYNATIVE_MODE). | |||