| @@ -397,7 +397,13 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_config_file_path", &PSContext::set_config_file_path, | |||
| "Set configuration files required by the communication layer.") | |||
| .def("config_file_path", &PSContext::config_file_path, | |||
| "Get configuration files required by the communication layer."); | |||
| "Get configuration files required by the communication layer.") | |||
| .def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.") | |||
| .def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.") | |||
| .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip, | |||
| "Set dp norm clip for federated learning secure aggregation.") | |||
| .def("set_encrypt_type", &PSContext::set_encrypt_type, | |||
| "Set encrypt type for federated learning secure aggregation."); | |||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | |||
| .def(py::init()) | |||
| @@ -184,6 +184,47 @@ void PSContext::set_server_mode(const std::string &server_mode) { | |||
| const std::string &PSContext::server_mode() const { return server_mode_; } | |||
| void PSContext::set_encrypt_type(const std::string &encrypt_type) { | |||
| if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) { | |||
| MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or " | |||
| << kDPEncryptType << " or " << kPWEncryptType; | |||
| return; | |||
| } | |||
| encrypt_type_ = encrypt_type; | |||
| } | |||
| const std::string &PSContext::encrypt_type() const { return encrypt_type_; } | |||
| void PSContext::set_dp_eps(float dp_eps) { | |||
| if (dp_eps > 0) { | |||
| dp_eps_ = dp_eps; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0."; | |||
| return; | |||
| } | |||
| } | |||
| float PSContext::dp_eps() const { return dp_eps_; } | |||
| void PSContext::set_dp_delta(float dp_delta) { | |||
| if (dp_delta > 0 && dp_delta < 1) { | |||
| dp_delta_ = dp_delta; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1)."; | |||
| return; | |||
| } | |||
| } | |||
| float PSContext::dp_delta() const { return dp_delta_; } | |||
| void PSContext::set_dp_norm_clip(float dp_norm_clip) { | |||
| if (dp_norm_clip > 0) { | |||
| dp_norm_clip_ = dp_norm_clip; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0."; | |||
| return; | |||
| } | |||
| } | |||
| float PSContext::dp_norm_clip() const { return dp_norm_clip_; } | |||
| void PSContext::set_ms_role(const std::string &role) { | |||
| if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) { | |||
| MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context."; | |||
| @@ -35,9 +35,9 @@ constexpr char kEnvRoleOfServer[] = "MS_SERVER"; | |||
| constexpr char kEnvRoleOfWorker[] = "MS_WORKER"; | |||
| constexpr char kEnvRoleOfScheduler[] = "MS_SCHED"; | |||
| constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS"; | |||
| constexpr char kDPEncryptType[] = "DPEncrypt"; | |||
| constexpr char kPWEncryptType[] = "PWEncrypt"; | |||
| constexpr char kNotEncryptType[] = "NotEncrypt"; | |||
| constexpr char kDPEncryptType[] = "DP_ENCRYPT"; | |||
| constexpr char kPWEncryptType[] = "PW_ENCRYPT"; | |||
| constexpr char kNotEncryptType[] = "NOT_ENCRYPT"; | |||
| // Use binary data to represent federated learning server's context so that we can judge which round resets the | |||
| // iteration. From right to left, each bit stands for: | |||
| @@ -166,6 +166,18 @@ class PSContext { | |||
| void set_config_file_path(const std::string &path); | |||
| std::string config_file_path() const; | |||
| void set_dp_eps(float dp_eps); | |||
| float dp_eps() const; | |||
| void set_dp_delta(float dp_delta); | |||
| float dp_delta() const; | |||
| void set_dp_norm_clip(float dp_norm_clip); | |||
| float dp_norm_clip() const; | |||
| void set_encrypt_type(const std::string &encrypt_type); | |||
| const std::string &encrypt_type() const; | |||
| private: | |||
| PSContext() | |||
| : ps_enabled_(false), | |||
| @@ -199,7 +211,11 @@ class PSContext { | |||
| secure_aggregation_(false), | |||
| cluster_config_(nullptr), | |||
| scheduler_manage_port_(11202), | |||
| config_file_path_("") {} | |||
| config_file_path_(""), | |||
| dp_eps_(50), | |||
| dp_delta_(0.01), | |||
| dp_norm_clip_(1.0), | |||
| encrypt_type_(kNotEncryptType) {} | |||
| bool ps_enabled_; | |||
| bool is_worker_; | |||
| bool is_pserver_; | |||
| @@ -276,6 +292,18 @@ class PSContext { | |||
| // The path of the configuration file, used to configure the certification path and persistent storage type, etc. | |||
| std::string config_file_path_; | |||
| // Epsilon budget of differential privacy mechanism. Used in federated learning for now. | |||
| float dp_eps_; | |||
| // Delta budget of differential privacy mechanism. Used in federated learning for now. | |||
| float dp_delta_; | |||
| // Norm clip factor of differential privacy mechanism. Used in federated learning for now. | |||
| float dp_norm_clip_; | |||
| // Secure mechanism for federated learning. Used in federated learning for now. | |||
| std::string encrypt_type_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -133,7 +133,7 @@ bool ReconstructSecretsKernel::Launch(const std::vector<AddressPtr> &inputs, con | |||
| void ReconstructSecretsKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||
| MS_LOG(INFO) << "ITERATION NUMBER IS : " << LocalMetaStore::GetInstance().curr_iter_num(); | |||
| if (true) { // todo: PSContext::instance()->encrypt_type == PWEncrypt { | |||
| if (PSContext::instance()->encrypt_type() == kPWEncryptType) { | |||
| while (!Executor::GetInstance().IsAllWeightAggregationDone()) { | |||
| std::this_thread::sleep_for(std::chrono::milliseconds(5)); | |||
| } | |||
| @@ -96,8 +96,9 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand | |||
| size_t total_data_size = LocalMetaStore::GetInstance().value<size_t>(kCtxFedAvgTotalDataSize); | |||
| MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " | |||
| << total_data_size; | |||
| FinishIteration(); | |||
| if (PSContext::instance()->encrypt_type() != kPWEncryptType) { | |||
| FinishIteration(); | |||
| } | |||
| } | |||
| } | |||
| @@ -235,6 +235,10 @@ void Server::InitCipher() { | |||
| unsigned char cipher_p[SECRET_MAX_LEN] = {0}; | |||
| int cipher_g = 1; | |||
| unsigned char cipher_prime[PRIME_MAX_LEN] = {0}; | |||
| float dp_eps = PSContext::instance()->dp_eps(); | |||
| float dp_delta = PSContext::instance()->dp_delta(); | |||
| float dp_norm_clip = PSContext::instance()->dp_norm_clip(); | |||
| std::string encrypt_type = PSContext::instance()->encrypt_type(); | |||
| mpz_t prim; | |||
| mpz_init(prim); | |||
| @@ -248,10 +252,10 @@ void Server::InitCipher() { | |||
| param.t = cipher_t; | |||
| memcpy_s(param.p, SECRET_MAX_LEN, cipher_p, SECRET_MAX_LEN); | |||
| memcpy_s(param.prime, PRIME_MAX_LEN, cipher_prime, PRIME_MAX_LEN); | |||
| // param.dp_delta = dp_delta; | |||
| // param.dp_eps = dp_eps; | |||
| // param.dp_norm_clip = dp_norm_clip; | |||
| param.encrypt_type = kNotEncryptType; // PSContext::instance()->encrypt_type; | |||
| param.dp_delta = dp_delta; | |||
| param.dp_eps = dp_eps; | |||
| param.dp_norm_clip = dp_norm_clip; | |||
| param.encrypt_type = encrypt_type; | |||
| cipher_init_->Init(param, 0, cipher_initial_client_cnt_, cipher_exchange_secrets_cnt_, cipher_share_secrets_cnt_, | |||
| cipher_get_clientlist_cnt_, cipher_reconstruct_secrets_down_cnt_, | |||
| cipher_reconstruct_secrets_up_cnt_); | |||
| @@ -851,6 +851,17 @@ def set_fl_context(**kwargs): | |||
| client_learning_rate (float): Client training learning rate. Default: 0.001. | |||
| worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with | |||
| server. Default: 65. | |||
| dp_eps (float): Epsilon budget of differential privacy mechanism. The smaller the dp_eps, the better the | |||
| privacy protection effect. Default: 50.0. | |||
| dp_delta (float): Delta budget of differential privacy mechanism, which is usually equals the reciprocal of | |||
| client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01. | |||
| dp_norm_clip (float): A factor used for clipping model's weights for differential mechanism. Its value is | |||
| suggested to be 0.5~2. Default: 1.0. | |||
| encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or | |||
| 'PW_ENCRYPT'. If 'DP_ENCRYPT', differential privacy schema would be applied for clients and the privacy | |||
| protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If | |||
| 'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' model from stealing. | |||
| Default: 'NOT_ENCRYPT'. | |||
| Raises: | |||
| ValueError: If input key is not the attribute in federated learning mode context. | |||
| @@ -68,7 +68,11 @@ _set_ps_context_func_map = { | |||
| "worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration, | |||
| "enable_ps_ssl": ps_context().set_enable_ssl, | |||
| "scheduler_manage_port": ps_context().set_scheduler_manage_port, | |||
| "config_file_path": ps_context().set_config_file_path | |||
| "config_file_path": ps_context().set_config_file_path, | |||
| "dp_eps": ps_context().set_dp_eps, | |||
| "dp_delta": ps_context().set_dp_delta, | |||
| "dp_norm_clip": ps_context().set_dp_norm_clip, | |||
| "encrypt_type": ps_context().set_encrypt_type | |||
| } | |||
| _get_ps_context_func_map = { | |||
| @@ -38,6 +38,11 @@ parser.add_argument("--client_batch_size", type=int, default=32) | |||
| parser.add_argument("--client_learning_rate", type=float, default=0.1) | |||
| parser.add_argument("--local_server_num", type=int, default=-1) | |||
| parser.add_argument("--config_file_path", type=str, default="") | |||
| parser.add_argument("--encrypt_type", type=str, default="NotEncrypt") | |||
| # parameters for encrypt_type='DP_ENCRYPT' | |||
| parser.add_argument("--dp_eps", type=float, default=50.0) | |||
| parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num | |||
| parser.add_argument("--dp_norm_clip", type=float, default=1.0) | |||
| if __name__ == "__main__": | |||
| args, _ = parser.parse_known_args() | |||
| @@ -62,6 +67,10 @@ if __name__ == "__main__": | |||
| client_learning_rate = args.client_learning_rate | |||
| local_server_num = args.local_server_num | |||
| config_file_path = args.config_file_path | |||
| dp_eps = args.dp_eps | |||
| dp_delta = args.dp_delta | |||
| dp_norm_clip = args.dp_norm_clip | |||
| encrypt_type = args.encrypt_type | |||
| if local_server_num == -1: | |||
| local_server_num = server_num | |||
| @@ -95,6 +104,10 @@ if __name__ == "__main__": | |||
| cmd_server += " --client_epoch_num=" + str(client_epoch_num) | |||
| cmd_server += " --client_batch_size=" + str(client_batch_size) | |||
| cmd_server += " --client_learning_rate=" + str(client_learning_rate) | |||
| cmd_server += " --dp_eps=" + str(dp_eps) | |||
| cmd_server += " --dp_delta=" + str(dp_delta) | |||
| cmd_server += " --dp_norm_clip=" + str(dp_norm_clip) | |||
| cmd_server += " --encrypt_type=" + str(encrypt_type) | |||
| cmd_server += " > server.log 2>&1 &" | |||
| import time | |||
| @@ -46,6 +46,10 @@ parser.add_argument("--client_batch_size", type=int, default=32) | |||
| parser.add_argument("--client_learning_rate", type=float, default=0.1) | |||
| parser.add_argument("--scheduler_manage_port", type=int, default=11202) | |||
| parser.add_argument("--config_file_path", type=str, default="") | |||
| # parameters for encrypt_type='DP_ENCRYPT' | |||
| parser.add_argument("--dp_eps", type=float, default=50.0) | |||
| parser.add_argument("--dp_delta", type=float, default=0.01) # 1/worker_num | |||
| parser.add_argument("--dp_norm_clip", type=float, default=1.0) | |||
| args, _ = parser.parse_known_args() | |||
| device_target = args.device_target | |||
| @@ -70,6 +74,10 @@ client_batch_size = args.client_batch_size | |||
| client_learning_rate = args.client_learning_rate | |||
| scheduler_manage_port = args.scheduler_manage_port | |||
| config_file_path = args.config_file_path | |||
| dp_eps = args.dp_eps | |||
| dp_delta = args.dp_delta | |||
| dp_norm_clip = args.dp_norm_clip | |||
| encrypt_type = args.encrypt_type | |||
| ctx = { | |||
| "enable_fl": True, | |||
| @@ -93,7 +101,11 @@ ctx = { | |||
| "client_batch_size": client_batch_size, | |||
| "client_learning_rate": client_learning_rate, | |||
| "scheduler_manage_port": scheduler_manage_port, | |||
| "config_file_path": config_file_path | |||
| "config_file_path": config_file_path, | |||
| "dp_eps": dp_eps, | |||
| "dp_delta": dp_delta, | |||
| "dp_norm_clip": dp_norm_clip, | |||
| "encrypt_type": encrypt_type | |||
| } | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=device_target, save_graphs=False) | |||