Browse Source

Fix FL server running bugs

pull/16134/head
ZPaC 4 years ago
parent
commit
6242f97bc2
10 changed files with 90 additions and 31 deletions
  1. +3
    -12
      mindspore/ccsrc/pipeline/jit/action.cc
  2. +4
    -4
      mindspore/ccsrc/ps/ps_context.cc
  3. +38
    -6
      mindspore/ccsrc/ps/ps_context.h
  4. +6
    -0
      mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h
  5. +15
    -2
      mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h
  6. +1
    -1
      mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc
  7. +2
    -0
      mindspore/ccsrc/ps/server/memory_register.cc
  8. +5
    -0
      mindspore/ccsrc/ps/server/memory_register.h
  9. +1
    -1
      mindspore/ccsrc/ps/server/model_store.cc
  10. +15
    -5
      mindspore/ccsrc/ps/server/parameter_aggregator.cc

+ 3
- 12
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -624,7 +624,6 @@ bool StartServerAction(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr func_graph = res->func_graph();
const std::string &server_mode_ = ps::PSContext::instance()->server_mode(); const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
size_t worker_num = ps::PSContext::instance()->initial_worker_num(); size_t worker_num = ps::PSContext::instance()->initial_worker_num();
size_t server_num = ps::PSContext::instance()->initial_server_num();
uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port(); uint64_t fl_server_port = ps::PSContext::instance()->fl_server_port();


// Update model threshold is a certain ratio of start_fl_job threshold. // Update model threshold is a certain ratio of start_fl_job threshold.
@@ -633,17 +632,9 @@ bool StartServerAction(const ResourcePtr &res) {
float percent_for_update_model = 1; float percent_for_update_model = 1;
size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model)); size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * percent_for_update_model));


std::vector<ps::server::RoundConfig> rounds_config = {
{"startFLJob", false, 3000, false, start_fl_job_threshold},
{"updateModel", false, 3000, false, update_model_threshold},
{"getModel", false, 3000},
{"asyncUpdateModel"},
{"asyncGetModel"},
{"push", false, 3000, true, worker_num},
{"pull", false, 3000, true, worker_num},
{"getWeightsByKey", false, 3000, true, 1},
{"overwriteWeightsByKey", false, 3000, true, server_num},
};
std::vector<ps::server::RoundConfig> rounds_config = {{"startFLJob", false, 3000, true, start_fl_job_threshold},
{"updateModel", false, 3000, true, update_model_threshold},
{"getModel", false, 3000}};


size_t executor_threshold = 0; size_t executor_threshold = 0;
if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) { if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {


+ 4
- 4
mindspore/ccsrc/ps/ps_context.cc View File

@@ -235,7 +235,7 @@ void PSContext::GenerateResetterRound() {
} }


binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) | binary_server_context = (is_parameter_server_mode << 0) | (is_federated_learning_mode << 1) |
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_overwrite_weights_ << 4);
(is_mixed_training_mode << 2) | (secure_aggregation_ << 3) | (worker_upload_weights_ << 4);
if (kServerContextToResetRoundMap.count(binary_server_context) == 0) { if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
resetter_round_ = ResetterRound::kNoNeedToReset; resetter_round_ = ResetterRound::kNoNeedToReset;
} else { } else {
@@ -277,11 +277,11 @@ void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch


uint64_t PSContext::client_batch_size() const { return client_batch_size_; } uint64_t PSContext::client_batch_size() const { return client_batch_size_; }


void PSContext::set_worker_overwrite_weights(uint64_t worker_overwrite_weights) {
worker_overwrite_weights_ = worker_overwrite_weights;
void PSContext::set_worker_upload_weights(uint64_t worker_upload_weights) {
worker_upload_weights_ = worker_upload_weights;
} }


uint64_t PSContext::worker_overwrite_weights() const { return worker_overwrite_weights_; }
uint64_t PSContext::worker_upload_weights() const { return worker_upload_weights_; }


void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; } void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }




+ 38
- 6
mindspore/ccsrc/ps/ps_context.h View File

@@ -132,8 +132,8 @@ class PSContext {
uint64_t client_batch_size() const; uint64_t client_batch_size() const;


// Set true if worker will overwrite weights on server. Used in hybrid training. // Set true if worker will overwrite weights on server. Used in hybrid training.
void set_worker_overwrite_weights(uint64_t worker_overwrite_weights);
uint64_t worker_overwrite_weights() const;
void set_worker_upload_weights(uint64_t worker_upload_weights);
uint64_t worker_upload_weights() const;


// Set true if using secure aggregation for federated learning. // Set true if using secure aggregation for federated learning.
void set_secure_aggregation(bool secure_aggregation); void set_secure_aggregation(bool secure_aggregation);
@@ -149,7 +149,19 @@ class PSContext {
worker_num_(0), worker_num_(0),
server_num_(0), server_num_(0),
scheduler_host_(""), scheduler_host_(""),
scheduler_port_(0) {}
scheduler_port_(0),
role_(kEnvRoleOfNotPS),
server_mode_(""),
resetter_round_(ResetterRound::kNoNeedToReset),
fl_server_port_(0),
fl_client_enable_(false),
fl_name_(""),
start_fl_job_threshold_(0),
fl_iteration_num_(0),
client_epoch_num_(0),
client_batch_size_(0),
secure_aggregation_(false),
worker_upload_weights_(false) {}
bool ps_enabled_; bool ps_enabled_;
bool is_worker_; bool is_worker_;
bool is_pserver_; bool is_pserver_;
@@ -160,22 +172,42 @@ class PSContext {
std::string scheduler_host_; std::string scheduler_host_;
uint16_t scheduler_port_; uint16_t scheduler_port_;


// The server process's role.
std::string role_; std::string role_;


// Members for federated learning.
// Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode.
std::string server_mode_; std::string server_mode_;

// The round which will reset the iteration. Used in federated learning for now.
ResetterRound resetter_round_; ResetterRound resetter_round_;

// Http port of federated learning server.
uint16_t fl_server_port_; uint16_t fl_server_port_;

// Whether this process is the federated client. Used in cross-silo scenario of federated learning.
bool fl_client_enable_; bool fl_client_enable_;

// Federated learning job name.
std::string fl_name_; std::string fl_name_;

// The threshold count of startFLJob round. Used in federated learning for now.
size_t start_fl_job_threshold_; size_t start_fl_job_threshold_;

// Iteration number of federated learning, which is the number of interactions between client and server.
uint64_t fl_iteration_num_; uint64_t fl_iteration_num_;

// Client training epoch number. Used in federated learning for now.
uint64_t client_epoch_num_; uint64_t client_epoch_num_;

// Client training data batch size. Used in federated learning for now.
uint64_t client_batch_size_; uint64_t client_batch_size_;
bool worker_overwrite_weights_;


// Federated learning security.
// Whether to use secure aggregation algorithm. Used in federated learning for now.
bool secure_aggregation_; bool secure_aggregation_;

// Whether there's a federated learning worker uploading weights to federated learning server. Used in hybrid training
// mode for now.
bool worker_upload_weights_;
}; };
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore


+ 6
- 0
mindspore/ccsrc/ps/server/kernel/aggregation_kernel.h View File

@@ -58,6 +58,12 @@ class AggregationKernel : public CPUKernel {


virtual bool IsAggregationDone() = 0; virtual bool IsAggregationDone() = 0;


// Some kernels should know the inputs/workspace/outputs addresses at initializing phase. For example, FedAvgKernel.
virtual void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
return;
}

// Setter and getter of kernels parameters information. // Setter and getter of kernels parameters information.
void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; } void set_params_info(const ParamsInfo &params_info) { params_info_ = params_info; }
const std::vector<std::string> &input_names() { return params_info_.inputs_names(); } const std::vector<std::string> &input_names() { return params_info_.inputs_names(); }


+ 15
- 2
mindspore/ccsrc/ps/server/kernel/fed_avg_kernel.h View File

@@ -44,6 +44,7 @@ class FedAvgKernel : public AggregationKernel {
public: public:
FedAvgKernel() : participated_(false) {} FedAvgKernel() : participated_(false) {}
~FedAvgKernel() override = default; ~FedAvgKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override { void InitKernel(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node); std::string cnode_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -97,6 +98,7 @@ class FedAvgKernel : public AggregationKernel {
DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler}); DistributedCountService::GetInstance().RegisterCounter(name_, done_count_, {first_cnt_handler, last_cnt_handler});
return; return;
} }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override { const std::vector<AddressPtr> &outputs) override {
std::unique_lock<std::mutex> lock(weight_mutex_); std::unique_lock<std::mutex> lock(weight_mutex_);
@@ -125,14 +127,25 @@ class FedAvgKernel : public AggregationKernel {
name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_)); name_, std::to_string(DistributedCountService::GetInstance().local_rank()) + "_" + std::to_string(accum_count_));
return true; return true;
} }
void Reset() {

void Reset() override {
accum_count_ = 0; accum_count_ = 0;
done_ = false; done_ = false;
participated_ = false; participated_ = false;
DistributedCountService::GetInstance().ResetCounter(name_); DistributedCountService::GetInstance().ResetCounter(name_);
return; return;
} }
bool IsAggregationDone() { return done_; }

bool IsAggregationDone() override { return done_; }

void SetParameterAddress(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
weight_addr_ = inputs[0];
data_size_addr_ = inputs[1];
new_weight_addr_ = inputs[2];
new_data_size_addr_ = inputs[3];
return;
}


private: private:
void GenerateReuseKernelNodeInfo() override { void GenerateReuseKernelNodeInfo() override {


+ 1
- 1
mindspore/ccsrc/ps/server/kernel/round/update_model_kernel.cc View File

@@ -96,7 +96,7 @@ void UpdateModelKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHand
MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is " MS_LOG(INFO) << "Total data size for iteration " << LocalMetaStore::GetInstance().curr_iter_num() << " is "
<< total_data_size; << total_data_size;


FinishIterCb();
FinishIteration();
} }
} }




+ 2
- 0
mindspore/ccsrc/ps/server/memory_register.cc View File

@@ -29,6 +29,8 @@ void MemoryRegister::StoreFloatArray(std::unique_ptr<float[]> *array) { float_ar
void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); } void MemoryRegister::StoreInt32Array(std::unique_ptr<int[]> *array) { int32_arrays_.push_back(std::move(*array)); }


void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); } void MemoryRegister::StoreUint64Array(std::unique_ptr<size_t[]> *array) { uint64_arrays_.push_back(std::move(*array)); }

void MemoryRegister::StoreCharArray(std::unique_ptr<char[]> *array) { char_arrays_.push_back(std::move(*array)); }
} // namespace server } // namespace server
} // namespace ps } // namespace ps
} // namespace mindspore } // namespace mindspore

+ 5
- 0
mindspore/ccsrc/ps/server/memory_register.h View File

@@ -58,6 +58,9 @@ class MemoryRegister {
} else if (typeid(T) == typeid(size_t)) { } else if (typeid(T) == typeid(size_t)) {
auto uint64_arr = CastUniquePtr<size_t, T>(array); auto uint64_arr = CastUniquePtr<size_t, T>(array);
StoreUint64Array(&uint64_arr); StoreUint64Array(&uint64_arr);
} else if (typeid(T) == typeid(char)) {
auto char_arr = CastUniquePtr<char, T>(array);
StoreCharArray(&char_arr);
} else { } else {
MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name(); MS_LOG(ERROR) << "MemoryRegister does not support type " << typeid(T).name();
return; return;
@@ -72,10 +75,12 @@ class MemoryRegister {
std::vector<std::unique_ptr<float[]>> float_arrays_; std::vector<std::unique_ptr<float[]>> float_arrays_;
std::vector<std::unique_ptr<int[]>> int32_arrays_; std::vector<std::unique_ptr<int[]>> int32_arrays_;
std::vector<std::unique_ptr<size_t[]>> uint64_arrays_; std::vector<std::unique_ptr<size_t[]>> uint64_arrays_;
std::vector<std::unique_ptr<char[]>> char_arrays_;


void StoreInt32Array(std::unique_ptr<int[]> *array); void StoreInt32Array(std::unique_ptr<int[]> *array);
void StoreFloatArray(std::unique_ptr<float[]> *array); void StoreFloatArray(std::unique_ptr<float[]> *array);
void StoreUint64Array(std::unique_ptr<size_t[]> *array); void StoreUint64Array(std::unique_ptr<size_t[]> *array);
void StoreCharArray(std::unique_ptr<char[]> *array);


template <typename T, typename S> template <typename T, typename S>
std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) { std::unique_ptr<T[]> CastUniquePtr(std::unique_ptr<S[]> *array) {


+ 1
- 1
mindspore/ccsrc/ps/server/model_store.cc View File

@@ -68,7 +68,7 @@ bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::strin
auto &stored_model = memory_register->addresses(); auto &stored_model = memory_register->addresses();
for (const auto &weight : new_model) { for (const auto &weight : new_model) {
const std::string &weight_name = weight.first; const std::string &weight_name = weight.first;
if (stored_model.count(weight_name) != 0) {
if (stored_model.count(weight_name) == 0) {
MS_LOG(ERROR) << "The stored model has no weight " << weight_name; MS_LOG(ERROR) << "The stored model has no weight " << weight_name;
continue; continue;
} }


+ 15
- 5
mindspore/ccsrc/ps/server/parameter_aggregator.cc View File

@@ -183,10 +183,10 @@ bool ParameterAggregator::InitAggregationKernels(const CNodePtr &cnode) {
} }


bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) { bool ParameterAggregator::InitOptimizerKernels(const CNodePtr &cnode) {
// if (PSContext::instance()->server_mode() == kServerModeFL) {
// MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
// return false;
// }
if (PSContext::instance()->server_mode() == kServerModeFL) {
MS_LOG(DEBUG) << "Federated learning mode doesn't need optimizer kernel.";
return false;
}
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
const std::string &name = AnfAlgo::GetCNodeName(cnode); const std::string &name = AnfAlgo::GetCNodeName(cnode);
auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode); auto optimizer_kernel = kernel::OptimizerKernelFactory::GetInstance().Create(name, cnode);
@@ -275,6 +275,7 @@ bool ParameterAggregator::GenerateAggregationKernelParams(const std::shared_ptr<
std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs), std::transform(output_names.begin(), output_names.end(), std::back_inserter(aggr_params.outputs),
[&](const std::string &name) { return memory_register->addresses()[name]; }); [&](const std::string &name) { return memory_register->addresses()[name]; });


aggr_kernel->SetParameterAddress(aggr_params.inputs, aggr_params.workspace, aggr_params.outputs);
aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params)); aggregation_kernel_parameters_.push_back(std::make_pair(aggr_kernel, aggr_params));
return true; return true;
} }
@@ -302,7 +303,16 @@ bool ParameterAggregator::GenerateOptimizerKernelParams(const std::shared_ptr<ke
} }


std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) { std::vector<std::string> ParameterAggregator::SelectAggregationAlgorithm(const CNodePtr &cnode) {
std::vector<std::string> aggregation_algorithm = {kFedAvg};
std::vector<std::string> aggregation_algorithm = {};
if (PSContext::instance()->server_mode() == kServerModeFL ||
PSContext::instance()->server_mode() == kServerModeHybrid) {
aggregation_algorithm.push_back("FedAvg");
} else if (PSContext::instance()->server_mode() == kServerModePS) {
aggregation_algorithm.push_back("DenseGradAccum");
} else {
MS_LOG(ERROR) << "Server doesn't support mode " << PSContext::instance()->server_mode();
}

MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm; MS_LOG(INFO) << "Aggregation algorithm selection result: " << aggregation_algorithm;
return aggregation_algorithm; return aggregation_algorithm;
} }


Loading…
Cancel
Save