Browse Source

Add secure parameters for mindspore federated learning.

tags/v1.3.0
jin-xiulang 4 years ago
parent
commit
ebc71d3306
10 changed files with 134 additions and 14 deletions
  1. +7
    -1
      mindspore/ccsrc/pipeline/jit/init.cc
  2. +41
    -0
      mindspore/ccsrc/ps/ps_context.cc
  3. +32
    -4
      mindspore/ccsrc/ps/ps_context.h
  4. +1
    -1
      mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.cc
  5. +3
    -2
      mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc
  6. +8
    -4
      mindspore/ccsrc/ps/server/server.cc
  7. +11
    -0
      mindspore/context.py
  8. +5
    -1
      mindspore/parallel/_ps_context.py
  9. +13
    -0
      tests/st/fl/mobile/run_mobile_server.py
  10. +13
    -1
      tests/st/fl/mobile/test_mobile_lenet.py

+ 7
- 1
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -397,7 +397,13 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_config_file_path", &PSContext::set_config_file_path,
"Set configuration files required by the communication layer.")
.def("config_file_path", &PSContext::config_file_path,
"Get configuration files required by the communication layer.");
"Get configuration files required by the communication layer.")
.def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
.def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
.def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
"Set dp norm clip for federated learning secure aggregation.")
.def("set_encrypt_type", &PSContext::set_encrypt_type,
"Set encrypt type for federated learning secure aggregation.");

(void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
.def(py::init())


+ 41
- 0
mindspore/ccsrc/ps/ps_context.cc View File

@@ -184,6 +184,47 @@ void PSContext::set_server_mode(const std::string &server_mode) {

const std::string &PSContext::server_mode() const { return server_mode_; }

void PSContext::set_encrypt_type(const std::string &encrypt_type) {
if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
<< kDPEncryptType << " or " << kPWEncryptType;
return;
}
encrypt_type_ = encrypt_type;
}
const std::string &PSContext::encrypt_type() const { return encrypt_type_; }

void PSContext::set_dp_eps(float dp_eps) {
if (dp_eps > 0) {
dp_eps_ = dp_eps;
} else {
MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
return;
}
}

float PSContext::dp_eps() const { return dp_eps_; }

void PSContext::set_dp_delta(float dp_delta) {
if (dp_delta > 0 && dp_delta < 1) {
dp_delta_ = dp_delta;
} else {
MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
return;
}
}
float PSContext::dp_delta() const { return dp_delta_; }

void PSContext::set_dp_norm_clip(float dp_norm_clip) {
if (dp_norm_clip > 0) {
dp_norm_clip_ = dp_norm_clip;
} else {
MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
return;
}
}
float PSContext::dp_norm_clip() const { return dp_norm_clip_; }

void PSContext::set_ms_role(const std::string &role) {
if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";


+ 32
- 4
mindspore/ccsrc/ps/ps_context.h View File

@@ -35,9 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
constexpr char kDPEncryptType[] = "DPEncrypt";
constexpr char kPWEncryptType[] = "PWEncrypt";
constexpr char kNotEncryptType[] = "NotEncrypt";
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";

// Use binary data to represent federated learning server's context so that we can judge which round resets the
// iteration. From right to left, each bit stands for:
@@ -166,6 +166,18 @@ class PSContext {
void set_config_file_path(const std::string &path);
std::string config_file_path() const;

void set_dp_eps(float dp_eps);
float dp_eps() const;

void set_dp_delta(float dp_delta);
float dp_delta() const;

void set_dp_norm_clip(float dp_norm_clip);
float dp_norm_clip() const;

void set_encrypt_type(const std::string &encrypt_type);
const std::string &encrypt_type() const;

private:
PSContext()
: ps_enabled_(false),
@@ -199,7 +211,11 @@ class PSContext {
secure_aggregation_(false),
cluster_config_(nullptr),
scheduler_manage_port_(11202),
config_file_path_("") {}
config_file_path_(""),
dp_eps_(50),
dp_delta_(0.01),
dp_norm_clip_(1.0),
encrypt_type_(kNotEncryptType) {}
bool ps_enabled_;
bool is_worker_;
bool is_pserver_;
@@ -276,6 +292,18 @@ class PSContext {

// The path of the configuration file, used to configure the certification path and persistent storage type, etc.
std::string config_file_path_;

// Epsilon budget of differential privacy mechanism. Used in federated learning for now.
float dp_eps_;

// Delta budget of differential privacy mechanism. Used in federated learning for now.
float dp_delta_;

// Norm clip factor of differential privacy mechanism. Used in federated learning for now.
float dp_norm_clip_;

// Secure mechanism for federated learning. Used in federated learning for now.
std::string encrypt_type_;
};
} // namespace ps
} // namespace mindspore


+ 1
- 1
mindspore/ccsrc/ps/server/kernel/round/reconstruct_secrets_kernel.cc View File

@@ -133,7 +133,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con

void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) {
MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num();
if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt {
if (PSContext::instance()->encrypt_type() == kPWEncryptType) {
while (!Executor::GetInstance().IsAllWeightAggregationDone()) {
std::this_thread::sleep_for(std::chrono::milliseconds(5));
}


+ 3
- 2
mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc View File

@@ -96,8 +96,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize);
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size;

FinishIteration();
if (PSContext::instance()->encrypt_type() != kPWEncryptType) {
FinishIteration();
}
}
}



+ 8
- 4
mindspore/ccsrc/ps/server/server.cc View File

@@ -235,6 +235,10 @@ void Server::InitCipher() {
unsigned char cipher_p[SECRET_MAX_LEN] = {0};
int cipher_g = 1;
unsigned char cipher_prime[PRIME_MAX_LEN] = {0};
float dp_eps = PSContext::instance()->dp_eps();
float dp_delta = PSContext::instance()->dp_delta();
float dp_norm_clip = PSContext::instance()->dp_norm_clip();
std::string encrypt_type = PSContext::instance()->encrypt_type();

mpz_t prim;
mpz_init(prim);
@@ -248,10 +252,10 @@ void Server::InitCipher() {
param.t = cipher_t;
memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN);
memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN);
// param.dp_delta = dp_delta;
// param.dp_eps = dp_eps;
// param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type;
param.dp_delta = dp_delta;
param.dp_eps = dp_eps;
param.dp_norm_clip = dp_norm_clip;
param.encrypt_type = encrypt_type;
cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_,
cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_,
cipher_reconstruct_secrets_up_cnt_);


+ 11
- 0
mindspore/context.py View File

@@ -851,6 +851,17 @@ def set_fl_context(**kwargs):
client_learning_rate (float): Client training learning rate. Default: 0.001.
worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with
server. Default: 65.
dp_eps (float): Epsilon budget of differential privacy mechanism. The smaller the dp_eps, the better the
privacy protection effect. Default: 50.0.
dp_delta (float): Delta budget of differential privacy mechanism, which is usually equals the reciprocal of
client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is
suggested to be 0.5~2. Default: 1.0.
encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy
protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing.
Default: 'NOT_ENCRYPT'.

Raises:
ValueError: If input key is not the attribute in federated learning mode context.


+ 5
- 1
mindspore/parallel/_ps_context.py View File

@@ -68,7 +68,11 @@ _set_ps_context_func_map = {
"worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
"enable_ps_ssl": ps_context().set_enable_ssl,
"scheduler_manage_port": ps_context().set_scheduler_manage_port,
"config_file_path": ps_context().set_config_file_path
"config_file_path": ps_context().set_config_file_path,
"dp_eps": ps_context().set_dp_eps,
"dp_delta": ps_context().set_dp_delta,
"dp_norm_clip": ps_context().set_dp_norm_clip,
"encrypt_type": ps_context().set_encrypt_type
}

_get_ps_context_func_map = {


+ 13
- 0
tests/st/fl/mobile/run_mobile_server.py View File

@@ -38,6 +38,11 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--local_server_num", type=int, default=-1)
parser.add_argument("--config_file_path", type=str, default="")
parser.add_argument("--encrypt_type", type=str, default="NotEncrypt")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)

if __name__ == "__main__":
args, _ = parser.parse_known_args()
@@ -62,6 +67,10 @@ if __name__ == "__main__":
client_learning_rate = args.client_learning_rate
local_server_num = args.local_server_num
config_file_path = args.config_file_path
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type

if local_server_num == -1:
local_server_num = server_num
@@ -95,6 +104,10 @@ if __name__ == "__main__":
cmd_server += " --client_epoch_num=" + str(client_epoch_num)
cmd_server += " --client_batch_size=" + str(client_batch_size)
cmd_server += " --client_learning_rate=" + str(client_learning_rate)
cmd_server += " --dp_eps=" + str(dp_eps)
cmd_server += " --dp_delta=" + str(dp_delta)
cmd_server += " --dp_norm_clip=" + str(dp_norm_clip)
cmd_server += " --encrypt_type=" + str(encrypt_type)
cmd_server += " > server.log 2>&1 &"

import time


+ 13
- 1
tests/st/fl/mobile/test_mobile_lenet.py View File

@@ -46,6 +46,10 @@ parser.add_argument("--client_batch_size", type=int, default=32)
parser.add_argument("--client_learning_rate", type=float, default=0.1)
parser.add_argument("--scheduler_manage_port", type=int, default=11202)
parser.add_argument("--config_file_path", type=str, default="")
# parameters for encrypt_type='DP_ENCRYPT'
parser.add_argument("--dp_eps", type=float, default=50.0)
parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num
parser.add_argument("--dp_norm_clip", type=float, default=1.0)

args, _ = parser.parse_known_args()
device_target = args.device_target
@@ -70,6 +74,10 @@ client_batch_size = args.client_batch_size
client_learning_rate = args.client_learning_rate
scheduler_manage_port = args.scheduler_manage_port
config_file_path = args.config_file_path
dp_eps = args.dp_eps
dp_delta = args.dp_delta
dp_norm_clip = args.dp_norm_clip
encrypt_type = args.encrypt_type

ctx = {
"enable_fl": True,
@@ -93,7 +101,11 @@ ctx = {
"client_batch_size": client_batch_size,
"client_learning_rate": client_learning_rate,
"scheduler_manage_port": scheduler_manage_port,
"config_file_path": config_file_path
"config_file_path": config_file_path,
"dp_eps": dp_eps,
"dp_delta": dp_delta,
"dp_norm_clip": dp_norm_clip,
"encrypt_type": encrypt_type
}

context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False)


Loading…
Cancel
Save