- # Copyright 2020-2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """FederatedLearningManager related class and functions."""
-
- from copy import deepcopy
- import numpy as np
- from mindspore import context, nn
- from mindspore.common import Parameter, ParameterTuple
- from mindspore.train.callback import Callback
- from mindspore.ops import operations as P
- from mindspore._checkparam import Validator, Rel
-
-
- class _StartFLJob(nn.Cell):
- """
- StartFLJob for Federated Learning Worker.
- """
- def __init__(self, data_size):
- super(_StartFLJob, self).__init__()
- self.start_fl_job = P.StartFLJob(data_size)
-
- def construct(self):
- succ = self.start_fl_job()
- return succ
-
-
- class _UpdateAndGetModel(nn.Cell):
- """
- Update and Get Model for Federated Learning Worker.
- """
- def __init__(self, weights):
- super(_UpdateAndGetModel, self).__init__()
- self.update_model = P.UpdateModel()
- self.get_model = P.GetModel()
- self.weights = weights
-
- def construct(self):
- self.update_model(self.weights)
- succ = self.get_model(self.weights)
- return succ
-
-
- class _ExchangeKeys(nn.Cell):
- """
- Exchange Keys for Stable PW Encrypt.
- """
- def __init__(self):
- super(_ExchangeKeys, self).__init__()
- self.exchange_keys = P.ExchangeKeys()
-
- def construct(self):
- return self.exchange_keys()
-
-
- class _GetKeys(nn.Cell):
- """
- Get Keys for Stable PW Encrypt.
- """
- def __init__(self):
- super(_GetKeys, self).__init__()
- self.get_keys = P.GetKeys()
-
- def construct(self):
- return self.get_keys()
-
-
- class FederatedLearningManager(Callback):
- """
- Manage Federated Learning during training.
-
- Args:
- model (nn.Cell): A training model.
- sync_frequency (int): Synchronization frequency of parameters in Federated Learning.
- Note:
- In dataset sink mode, the unit of the frequency is the number of epochs.
- Otherwise, the unit of the frequency is the number of steps.
- sync_type (str): Parameter synchronization type in Federated Learning.
- Supports ["fixed", "adaptive"]. Default: "fixed".
-
- - fixed: The frequency of parameter synchronization is fixed.
- - adaptive: The frequency of parameter synchronization changes adaptively.
-
- Note:
- This is an experimental prototype that is subject to change.
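-
- Examples:
- A minimal usage sketch; the network, loss function, optimizer, dataset and the
- federated-learning context are assumed to be configured elsewhere, only the callback
- wiring is shown here.
-
- >>> federated_learning_manager = FederatedLearningManager(network, sync_frequency=100)
- >>> model = Model(network, loss_fn=loss_fn, optimizer=optimizer)
- >>> model.train(epoch=10, train_dataset=dataset, callbacks=[federated_learning_manager])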
- """
-
- def __init__(self, model, sync_frequency, sync_type='fixed', **kwargs):
- super(FederatedLearningManager, self).__init__()
- server_mode = context.get_fl_context("server_mode")
- if server_mode not in ("FEDERATED_LEARNING", "HYBRID_TRAINING"):
- raise ValueError("server_mode must in (\"FEDERATED_LEARNING\", \"HYBRID_TRAINING\")")
- Validator.check_isinstance('model', model, nn.Cell)
- Validator.check_positive_int(sync_frequency)
- Validator.check_string(sync_type, ["fixed", "adaptive"])
- self._model = model
- self._sync_frequency = sync_frequency
- self._next_sync_iter_id = self._sync_frequency
- self._sync_type = sync_type
- self._global_step = 0
- self._data_size = 0
-
- if self._is_adaptive_sync():
- self._as_set_init_state(kwargs)
- self._as_wrap_cell()
-
- def _is_adaptive_sync(self):
- """
- Determine whether adaptive frequency synchronization is required.
- """
- return self._sync_type == "adaptive"
-
- def _as_set_init_state(self, kwargs):
- """
- Set the initial state for adaptive synchronization.
- """
- self._as_prefix = "as_abs_grad."
-
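- # The kwargs below tune the adaptive schedule (defaults are the fallbacks in the get() calls):
- # - min_consistent_rate starts at 1.1, above the reachable maximum of 1.0, so the first
- # measured consistency rate always becomes the running minimum.
- # - ema_alpha is the smoothing factor of the exponential moving averages kept in
- # _as_analyze_gradient.
- # - observation_window_size is how many rounds the minimum may stay unimproved before the
- # synchronization interval is shortened by frequency_increase_ratio.
- # - unchanged_round is the number of initial rounds during which the interval is kept fixed.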
- self._min_consistent_rate = kwargs.get("min_consistent_rate", 1.1)
- Validator.check_non_negative_float(self._min_consistent_rate)
- self._min_consistent_rate_at_round = kwargs.get("min_consistent_rate_at_round", 0)
- Validator.check_non_negative_int(self._min_consistent_rate_at_round)
- self._ema_alpha = kwargs.get("ema_alpha", 0.5)
- Validator.check_float_range(self._ema_alpha, 0.0, 1.0, Rel.INC_NEITHER)
- self._observation_window_size = kwargs.get("observation_window_size", 5)
- Validator.check_positive_int(self._observation_window_size)
- self._frequency_increase_ratio = kwargs.get("frequency_increase_ratio", 2)
- Validator.check_positive_int(self._frequency_increase_ratio)
- self._unchanged_round = kwargs.get("unchanged_round", 0)
- Validator.check_non_negative_int(self._unchanged_round)
-
- self._round_id = 0
- self._last_param = {param.name: deepcopy(param.asnumpy()) for param in self._model.trainable_params()
- if self._as_prefix not in param.name}
- self._model_size = 0
- self._grads_ema = dict()
- self._abs_grads_ema = dict()
- for param in self._model.trainable_params():
- if self._as_prefix not in param.name:
- self._model_size += np.prod(param.shape)
- self._grads_ema[param.name] = np.zeros(param.shape)
- self._abs_grads_ema[param.name] = np.zeros(param.shape)
- self._model_size = float(self._model_size)
-
- def _as_wrap_cell(self):
- """
- Wrap Cell for adaptive synchronization.
- """
- param_list = list()
- for param in self._model.trainable_params():
- new_param = param.clone()
- new_param.name = self._as_prefix + param.name
- param_list.append(new_param)
- for param in param_list:
- self._model.insert_param_to_cell(param.name, param, False)
-
- def _as_set_grads(self):
- """
- Set the absolute value of the gradient for adaptive synchronization.
- """
- abs_grads = dict()
- for param in self._model.trainable_params():
- if self._as_prefix not in param.name:
- abs_grads[self._as_prefix+param.name] = np.abs(param.asnumpy() - self._last_param[param.name])
- for param in self._model.trainable_params():
- if self._as_prefix in param.name:
- param.set_data(Parameter(abs_grads[param.name]))
-
- def _as_analyze_gradient(self):
- """
- Analyze relevant statistics based on gradients for adaptive synchronization.
- """
- worker_num = context.get_fl_context("worker_num")
- ema_alpha = self._ema_alpha
- consistent_rate_sum = 0.0
- grads = dict()
- abs_grads = dict()
- for param in self._model.trainable_params():
- if self._as_prefix in param.name:
- abs_grads[param.name.replace(self._as_prefix, '')] = param.asnumpy() * worker_num
- else:
- grads[param.name] = (param.asnumpy() - self._last_param[param.name]) * worker_num
- for last_p in self._last_param:
- self._grads_ema[last_p] = ema_alpha * self._grads_ema[last_p] + (1 - ema_alpha) * grads[last_p]
- self._abs_grads_ema[last_p] = ema_alpha * self._abs_grads_ema[last_p] + (1 - ema_alpha) * abs_grads[last_p]
- divide_base = np.where(self._abs_grads_ema[last_p] == 0,
- np.ones(self._abs_grads_ema[last_p].shape), self._abs_grads_ema[last_p])
- layer_consistent_rate = np.abs(self._grads_ema[last_p]) / divide_base
- consistent_rate_sum += np.sum(layer_consistent_rate)
-
- consistent_rate = float(consistent_rate_sum / self._model_size)
-
- if self._min_consistent_rate > consistent_rate:
- self._min_consistent_rate = consistent_rate
- self._min_consistent_rate_at_round = self._round_id
- else:
- if self._round_id - self._min_consistent_rate_at_round > self._observation_window_size:
- if self._sync_frequency > 1 and self._round_id > self._unchanged_round:
- self._sync_frequency = (self._sync_frequency + self._frequency_increase_ratio - 1) \
- // self._frequency_increase_ratio
- self._min_consistent_rate = 1.1
- self._min_consistent_rate_at_round = self._round_id
- self._observation_window_size *= self._frequency_increase_ratio
-
- for param in self._model.trainable_params():
- if self._as_prefix not in param.name:
- self._grads_ema[param.name] = np.zeros(param.shape)
- self._abs_grads_ema[param.name] = np.zeros(param.shape)
-
- def _as_set_last_param(self):
- """
- Record the parameter values at the last synchronization for adaptive synchronization.
- """
- self._last_param = {param.name: deepcopy(param.asnumpy()) for param in self._model.trainable_params()
- if self._as_prefix not in param.name}
-
- def step_end(self, run_context):
- """
- Synchronize parameters at the end of a step. If sync_type is "adaptive", the synchronization
- frequency is adjusted adaptively here.
-
- Args:
- run_context (RunContext): Context of the train running.
- """
- self._global_step += 1
- cb_params = run_context.original_args()
- inputs = cb_params.train_dataset_element
- batch_size = inputs[0].shape[0] if isinstance(inputs, (tuple, list)) else inputs.shape[0]
- self._data_size += batch_size
- if context.get_fl_context("ms_role") == "MS_WORKER":
- if self._global_step == self._next_sync_iter_id:
- start_fl_job = _StartFLJob(self._data_size)
- start_fl_job()
- self._data_size = 0
- if self._is_adaptive_sync():
- self._as_set_grads()
- update_and_get_model = _UpdateAndGetModel(ParameterTuple(self._model.trainable_params()))
- update_and_get_model()
- self._next_sync_iter_id = self._global_step + self._sync_frequency
- if self._is_adaptive_sync():
- self._as_analyze_gradient()
- self._round_id += 1
- self._as_set_last_param()
-
- print("sync step is: {}".format(self._global_step))