| @@ -35,7 +35,7 @@ function(ms_build_flatbuffers source_schema_files | |||||
| set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs}) | set(total_schema_dirs -I ${schema_dir} ${total_schema_dirs}) | ||||
| endforeach() | endforeach() | ||||
| foreach(schema ${source_schema_files}) | |||||
| foreach(schema IN LISTS ${source_schema_files}) | |||||
| get_filename_component(filename ${schema} NAME_WE) | get_filename_component(filename ${schema} NAME_WE) | ||||
| if(NOT ${generated_output_dir} STREQUAL "") | if(NOT ${generated_output_dir} STREQUAL "") | ||||
| set(generated_file ${generated_output_dir}/${filename}_generated.h) | set(generated_file ${generated_output_dir}/${filename}_generated.h) | ||||
| @@ -212,7 +212,7 @@ if(ENABLE_GPU) | |||||
| ) | ) | ||||
| endif() | endif() | ||||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||||
| if(ENABLE_CPU AND NOT WIN32) | |||||
| install( | install( | ||||
| TARGETS ps_cache | TARGETS ps_cache | ||||
| DESTINATION ${INSTALL_LIB_DIR} | DESTINATION ${INSTALL_LIB_DIR} | ||||
| @@ -373,7 +373,7 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||||
| target_link_libraries(mindspore mindspore_gvar) | target_link_libraries(mindspore mindspore_gvar) | ||||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) | target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore mindspore_core -Wl,-noall_load) | ||||
| else() | else() | ||||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||||
| if(ENABLE_CPU AND NOT WIN32) | |||||
| target_link_libraries(mindspore proto_input mindspore::protobuf | target_link_libraries(mindspore proto_input mindspore::protobuf | ||||
| mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json) | mindspore::event mindspore::event_pthreads mindspore::event_openssl mindspore::json) | ||||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | ||||
| @@ -75,7 +75,7 @@ if(ENABLE_CPU) | |||||
| endif() | endif() | ||||
| endif() | endif() | ||||
| if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| if(NOT ENABLE_CPU OR WIN32) | |||||
| list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc") | list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/apply_momentum_ps_kernel.cc") | ||||
| list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc") | list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_proxy_kernel.cc") | ||||
| list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc") | list(REMOVE_ITEM CPU_SRC_LIST "cpu/ps/embedding_look_up_ps_kernel.cc") | ||||
| @@ -421,7 +421,7 @@ void AscendSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_gra | |||||
| size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | ||||
| } | } | ||||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| const std::string ¶m_name = input_node->fullname_with_scope(); | const std::string ¶m_name = input_node->fullname_with_scope(); | ||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | if (ps::ps_cache_instance.IsHashTable(param_name)) { | ||||
| continue; | continue; | ||||
| @@ -33,7 +33,7 @@ | |||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #include "debug/dump_proto.h" | #include "debug/dump_proto.h" | ||||
| #include "debug/data_dump/dump_json_parser.h" | #include "debug/data_dump/dump_json_parser.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #endif | #endif | ||||
| @@ -74,7 +74,7 @@ void CPUSession::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderPos | |||||
| void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | void CPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) { | ||||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | auto optimizer = std::make_shared<opt::GraphOptimizer>(); | ||||
| auto pm = std::make_shared<opt::PassManager>(); | auto pm = std::make_shared<opt::PassManager>(); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { | if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && ps::PSContext::instance()->is_ps_mode()) { | ||||
| @@ -174,7 +174,7 @@ void CPUSession::PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| MS_LOG(INFO) << "Bind input output address"; | MS_LOG(INFO) << "Bind input output address"; | ||||
| runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| InitPSParamAndOptim(kernel_graph, inputs); | InitPSParamAndOptim(kernel_graph, inputs); | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/scoped_long_running.h" | #include "utils/scoped_long_running.h" | ||||
| #include "pybind_api/ir/tensor_py.h" | #include "pybind_api/ir/tensor_py.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #endif | #endif | ||||
| @@ -43,7 +43,7 @@ | |||||
| #include "debug/common.h" | #include "debug/common.h" | ||||
| #include "utils/trace_base.h" | #include "utils/trace_base.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #include "ps/constants.h" | #include "ps/constants.h" | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| @@ -2357,7 +2357,7 @@ void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { | |||||
| #endif | #endif | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | ||||
| if (!ps::PSContext::instance()->is_worker()) { | if (!ps::PSContext::instance()->is_worker()) { | ||||
| return; | return; | ||||
| @@ -244,7 +244,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| std::vector<uint32_t> GetAllReduceSplitIndex(); | std::vector<uint32_t> GetAllReduceSplitIndex(); | ||||
| virtual std::string GetCommWorldGroup() { return std::string(); } | virtual std::string GetCommWorldGroup() { return std::string(); } | ||||
| void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph); | void DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; | void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; | ||||
| void GetBatchElements(const AnfNodePtr &kernel_node) const; | void GetBatchElements(const AnfNodePtr &kernel_node) const; | ||||
| void InitPsWorker(const KernelGraphPtr &kernel_graph); | void InitPsWorker(const KernelGraphPtr &kernel_graph); | ||||
| @@ -263,7 +263,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | #if !defined(_WIN32) && !defined(_WIN64) | ||||
| std::shared_ptr<Debugger> debugger_; | std::shared_ptr<Debugger> debugger_; | ||||
| #endif | #endif | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| bool initialized_ps_cache_{false}; | bool initialized_ps_cache_{false}; | ||||
| #endif | #endif | ||||
| }; | }; | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "frontend/parallel/device_matrix.h" | #include "frontend/parallel/device_matrix.h" | ||||
| #include "frontend/parallel/graph_util/generate_graph.h" | #include "frontend/parallel/graph_util/generate_graph.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #endif | #endif | ||||
| @@ -160,7 +160,7 @@ Status GatherPInfo::GetAttrs() { | |||||
| if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { | if (std::find(inputs_shape_[1].begin(), inputs_shape_[1].end(), -1) != inputs_shape_[1].end()) { | ||||
| dynamic_shape_indices_ = true; | dynamic_shape_indices_ = true; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | ||||
| std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | ||||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | ||||
| @@ -617,7 +617,7 @@ Status GatherPInfo::InferBias() { | |||||
| rank = rank % (params_strategy[0] * params_strategy[1]); | rank = rank % (params_strategy[0] * params_strategy[1]); | ||||
| } | } | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); | bias_ = static_cast<int64_t>(ps::PsCacheManager::GetInstance().cache_indices_lower_bound()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -28,7 +28,7 @@ | |||||
| #include "frontend/parallel/strategy.h" | #include "frontend/parallel/strategy.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | #include "frontend/parallel/tensor_layout/tensor_redistribution.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #endif | #endif | ||||
| @@ -192,7 +192,7 @@ Status UniqueInfo::GenerateStrategies(int64_t stage_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | ||||
| GenerateGraph gen_g = GenerateGraph(); | GenerateGraph gen_g = GenerateGraph(); | ||||
| if (gen_g.Init(cnode) != SUCCESS) { | if (gen_g.Init(cnode) != SUCCESS) { | ||||
| @@ -230,7 +230,7 @@ Status UniqueInfo::ComputeReplaceGraph(const CNodePtr &cnode) { | |||||
| #endif | #endif | ||||
| ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { | ReplaceGraphPtr UniqueInfo::replace_graph(const CNodePtr &cnode) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| auto inputs = cnode->inputs(); | auto inputs = cnode->inputs(); | ||||
| if (inputs.empty()) { | if (inputs.empty()) { | ||||
| @@ -51,7 +51,7 @@ class UniqueInfo : public OperatorInfo { | |||||
| Status InferMirrorOps() override; | Status InferMirrorOps() override; | ||||
| Status InferForwardCommunication() override { return SUCCESS; } | Status InferForwardCommunication() override { return SUCCESS; } | ||||
| Status InferAsLossDivisor() override { return SUCCESS; } | Status InferAsLossDivisor() override { return SUCCESS; } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| Status ComputeReplaceGraph(const CNodePtr &cnode); | Status ComputeReplaceGraph(const CNodePtr &cnode); | ||||
| #endif | #endif | ||||
| @@ -47,14 +47,14 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -46,7 +46,7 @@ | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "mindspore/core/utils/parallel_node_check.h" | #include "mindspore/core/utils/parallel_node_check.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #endif | #endif | ||||
| @@ -3553,7 +3553,7 @@ static void HandleFullySplitParameters(const FuncGraphPtr &root) { | |||||
| } | } | ||||
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { | if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -295,7 +295,7 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") | |||||
| target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) | target_link_libraries(_c_dataengine PRIVATE _c_mindrecord ${MINDRECORD_LINK_OBJECT} mindspore::sqlite) | ||||
| else() | else() | ||||
| target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) | target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) | ||||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||||
| if(ENABLE_CPU AND NOT WIN32) | |||||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | if(${ENABLE_IBVERBS} STREQUAL "ON") | ||||
| target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | ||||
| endif() | endif() | ||||
| @@ -1,7 +1,8 @@ | |||||
| add_subdirectory(perf EXCLUDE_FROM_ALL) | add_subdirectory(perf EXCLUDE_FROM_ALL) | ||||
| include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | ||||
| set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache") | ||||
| ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) | |||||
| set(FBS_FILES de_tensor.fbs) | |||||
| ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU}) | |||||
| file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) | ||||
| @@ -43,7 +43,7 @@ | |||||
| #include "vm/transform.h" | #include "vm/transform.h" | ||||
| #include "parse/python_adapter.h" | #include "parse/python_adapter.h" | ||||
| #include "frontend/optimizer/py_pass_manager.h" | #include "frontend/optimizer/py_pass_manager.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/parameter_server.h" | #include "ps/parameter_server.h" | ||||
| #include "ps/scheduler.h" | #include "ps/scheduler.h" | ||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| @@ -606,7 +606,7 @@ bool ExecuteAction(const ResourcePtr &res) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| bool StartPSWorkerAction(const ResourcePtr &res) { | bool StartPSWorkerAction(const ResourcePtr &res) { | ||||
| ps::Worker::GetInstance().Run(); | ps::Worker::GetInstance().Run(); | ||||
| return true; | return true; | ||||
| @@ -782,7 +782,7 @@ std::vector<ActionItem> VmPipeline() { | |||||
| actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | ||||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | actions.emplace_back(std::make_pair("validate", ValidateAction)); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_worker()) { | if (ps::PSContext::instance()->is_worker()) { | ||||
| actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); | actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); | ||||
| } | } | ||||
| @@ -796,7 +796,7 @@ std::vector<ActionItem> VmPipeline() { | |||||
| return actions; | return actions; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| std::vector<ActionItem> PServerPipeline() { | std::vector<ActionItem> PServerPipeline() { | ||||
| auto actions = CommonPipeline(); | auto actions = CommonPipeline(); | ||||
| actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | actions.emplace_back(std::make_pair("optimize", VmOptimizeAction)); | ||||
| @@ -34,7 +34,7 @@ | |||||
| #else | #else | ||||
| #include "runtime/device/gpu/distribution/collective_fake_init.h" | #include "runtime/device/gpu/distribution/collective_fake_init.h" | ||||
| #endif | #endif | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #endif | #endif | ||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| @@ -42,7 +42,7 @@ | |||||
| #include "pipeline/jit/pipeline_split.h" | #include "pipeline/jit/pipeline_split.h" | ||||
| #include "pipeline/jit/static_analysis/auto_monad.h" | #include "pipeline/jit/static_analysis/auto_monad.h" | ||||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | #include "frontend/optimizer/irpass/gradient_eliminate.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #endif | #endif | ||||
| @@ -407,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { | |||||
| } | } | ||||
| bool AddCacheEmbeddingPass(const ResourcePtr &res) { | bool AddCacheEmbeddingPass(const ResourcePtr &res) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_ps_mode()) { | if (ps::PSContext::instance()->is_ps_mode()) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -49,7 +49,7 @@ | |||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| #include "utils/info.h" | #include "utils/info.h" | ||||
| #include "load_mindir/load_model.h" | #include "load_mindir/load_model.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/constants.h" | #include "ps/constants.h" | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| @@ -528,7 +528,7 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri | |||||
| std::string backend = MsContext::GetInstance()->backend_policy(); | std::string backend = MsContext::GetInstance()->backend_policy(); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_server()) { | if (ps::PSContext::instance()->is_server()) { | ||||
| resource->results()[kBackend] = compile::CreateBackend(); | resource->results()[kBackend] = compile::CreateBackend(); | ||||
| return PServerPipeline(); | return PServerPipeline(); | ||||
| @@ -961,7 +961,7 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba | |||||
| bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, | bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, | ||||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | ||||
| const std::vector<int64_t> &input_indexes, bool need_run) { | const std::vector<int64_t> &input_indexes, bool need_run) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { | if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1027,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| } | } | ||||
| ConfigManager::GetInstance().set_iter_num(size); | ConfigManager::GetInstance().set_iter_num(size); | ||||
| // PS cache does not support loop sink. | // PS cache does not support loop sink. | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | ||||
| ConfigManager::GetInstance().set_iter_num(1); | ConfigManager::GetInstance().set_iter_num(1); | ||||
| @@ -1150,7 +1150,7 @@ void FinalizeBackend() { | |||||
| void ClearResAtexit() { | void ClearResAtexit() { | ||||
| MS_LOG(DEBUG) << "Pipeline clear all resource"; | MS_LOG(DEBUG) << "Pipeline clear all resource"; | ||||
| pynative::ClearPyNativeSession(); | pynative::ClearPyNativeSession(); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { | if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { | ||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps::ps_cache_instance.Finalize(); | ps::ps_cache_instance.Finalize(); | ||||
| @@ -1,6 +1,13 @@ | |||||
| file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | file(GLOB_RECURSE _PS_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") | ||||
| if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| set(SERVER_FLATBUFFER_OUTPUT "${CMAKE_BINARY_DIR}/schema") | |||||
| set(FBS_FILES | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/cipher.fbs | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../schema/fl_job.fbs | |||||
| ) | |||||
| ms_build_flatbuffers(FBS_FILES ${CMAKE_CURRENT_SOURCE_DIR}../../schema generated_fbs_files ${SERVER_FLATBUFFER_OUTPUT}) | |||||
| if(NOT ENABLE_CPU OR WIN32) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc") | list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info_builder.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | list(REMOVE_ITEM _PS_SRC_FILES "optimizer_info.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | list(REMOVE_ITEM _PS_SRC_FILES "scheduler.cc") | ||||
| @@ -12,11 +19,6 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_client.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_message_handler.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_server.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | ||||
| @@ -39,18 +41,32 @@ if(NOT ENABLE_GPU) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | ||||
| endif() | endif() | ||||
| if(WIN32 OR NOT ENABLE_CPU) | |||||
| if(NOT ENABLE_CPU OR WIN32) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/communicator_base.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_communicator.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_communicator.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/http_msg_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/communicator/tcp_msg_handler.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/apply_momentum_kernel.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/aggregation_kernel_factory.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/dense_grad_accum_kernel.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/optimizer_kernel_factory.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel_factory.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/round_kernel.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/round/start_fl_job_kernel.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/kernel/params_info.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/consistent_hash_ring.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/iteration_timer.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_storage.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/local_meta_store.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/memory_register.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/parameter_aggregator.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc") | list(REMOVE_ITEM _PS_SRC_FILES "server/executor.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/collective_ops_impl.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_count_service.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/distributed_metadata_store.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/iteration.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/model_store.cc") | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "server/round.cc") | |||||
| endif() | endif() | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | ||||
| @@ -59,3 +75,5 @@ add_subdirectory(ps_cache) | |||||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | ||||
| add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | ||||
| add_dependencies(_mindspore_ps_obj generated_fbs_files) | |||||
| target_link_libraries(_mindspore_ps_obj mindspore::flatbuffers) | |||||
| @@ -34,13 +34,26 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace core { | namespace core { | ||||
// User-level commands that can be sent through the TCP communicator.
enum class TcpUserCommand {
  kPush,
  kPull,
  kCount,
  kReachThreshold,
  kResetCount,
  kGetMetadata,
  kUpdateMetadata,
  kCounterEvent
};

// Message-type string associated with each TcpUserCommand.
const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
  {TcpUserCommand::kPush, "push"},
  {TcpUserCommand::kPull, "pull"},
  {TcpUserCommand::kCount, "count"},
  {TcpUserCommand::kReachThreshold, "countReachThreshold"},
  {TcpUserCommand::kResetCount, "resetCnt"},
  {TcpUserCommand::kGetMetadata, "getMetadata"},
  {TcpUserCommand::kUpdateMetadata, "updateMetadata"},
  {TcpUserCommand::kCounterEvent, "counterEvent"},
};
| class TcpCommunicator : public CommunicatorBase { | class TcpCommunicator : public CommunicatorBase { | ||||
| @@ -0,0 +1,155 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
// proto3 schema of the mindspore.ps package: collective payloads, distributed counter messages,
// metadata get/update messages and the metadata value types wrapped by PBMetadata.
| syntax = "proto3"; | |||||
| package mindspore.ps; | |||||
// Opaque byte payload.
| message CollectiveData { | |||||
| bytes data = 1; | |||||
| } | |||||
// Count request: the counter `name` plus the `id` being reported.
| message CountRequest { | |||||
| string name = 1; | |||||
| string id = 2; | |||||
| } | |||||
// Count reply: a boolean result and a textual reason.
| message CountResponse { | |||||
| bool result = 1; | |||||
| string reason = 2; | |||||
| } | |||||
// Query, by counter name, for whether the counter has reached its threshold.
| message CountReachThresholdRequest { | |||||
| string name = 1; | |||||
| } | |||||
| message CountReachThresholdResponse { | |||||
| bool is_enough = 1; | |||||
| } | |||||
// Names the counter to reset.
| message ResetCounterRequest { | |||||
| string name = 1; | |||||
| } | |||||
// Metadata update: metadata `name` and its value as raw bytes.
| message UpdateMetadataRequest { | |||||
| string name = 1; | |||||
| bytes value = 2; | |||||
| } | |||||
// Metadata lookup by name; the reply carries the value as raw bytes.
| message GetMetadataRequest { | |||||
| string name = 1; | |||||
| } | |||||
| message GetMetadataResponse { | |||||
| bytes value = 1; | |||||
| } | |||||
// Kind of counter event carried by CounterEvent.
| enum CounterEventType { | |||||
| FIRST_CNT = 0; | |||||
| LAST_CNT = 1; | |||||
| } | |||||
// Counter event: event type, counter name and an opaque data payload.
| message CounterEvent { | |||||
| CounterEventType type = 1; | |||||
| string name = 2; | |||||
| bytes data = 3; | |||||
| } | |||||
// The messages below are the value types selectable in the PBMetadata oneof, plus their building blocks.
| message FLId { | |||||
| string fl_id = 1; | |||||
| } | |||||
| message UpdateModelClientList { | |||||
| repeated string fl_id = 1; | |||||
| } | |||||
| message DeviceMeta { | |||||
| string fl_name = 1; | |||||
| string fl_id = 2; | |||||
| uint64 data_size = 3; | |||||
| } | |||||
// String-keyed map of DeviceMeta entries; the field name suggests the key is an fl_id.
| message FLIdToDeviceMeta { | |||||
| map<string, DeviceMeta> fl_id_to_meta = 1; | |||||
| } | |||||
| message UpdateModelThreshold { | |||||
| uint64 threshold = 1; | |||||
| } | |||||
// String-keyed map of SharesPb (SharesPb is declared further down in this file).
| message ClientShares { | |||||
| map<string, SharesPb> client_secret_shares = 1; | |||||
| } | |||||
| message PairClientShares { | |||||
| string fl_id = 1; | |||||
| SharesPb client_shares = 2; | |||||
| } | |||||
// String-keyed map of KeysPb (KeysPb is declared further down in this file).
| message ClientKeys { | |||||
| map<string, KeysPb> client_keys = 1; | |||||
| } | |||||
| message ClientNoises { | |||||
| OneClientNoises one_client_noises = 1; | |||||
| } | |||||
| message PairClientKeys { | |||||
| string fl_id = 1; | |||||
| KeysPb client_keys = 2; | |||||
| } | |||||
| message OneClientNoises { | |||||
| repeated float noise = 1; | |||||
| } | |||||
| message ClientShareStr { | |||||
| string fl_id = 1; | |||||
| bytes share = 2; // todo: verify the correctness | |||||
| int32 index = 3; | |||||
| } | |||||
| message SharesPb { | |||||
| repeated ClientShareStr clientsharestrs = 1; | |||||
| } | |||||
| message KeysPb { | |||||
| repeated bytes key = 1; | |||||
| } | |||||
// Tagged union of metadata values: exactly one member of the oneof is set at a time.
| message PBMetadata { | |||||
| oneof value { | |||||
| DeviceMeta device_meta = 1; | |||||
| FLIdToDeviceMeta device_metas = 2; | |||||
| FLId fl_id = 3; | |||||
| UpdateModelClientList client_list = 4; | |||||
| UpdateModelThreshold update_model_threshold = 5; | |||||
| PairClientShares pair_client_shares = 6; | |||||
| ClientShares client_shares = 7; | |||||
| PairClientKeys pair_client_keys = 8; | |||||
| ClientKeys client_keys = 9; | |||||
| OneClientNoises one_client_noises = 10; | |||||
| ClientNoises client_noises = 11; | |||||
| } | |||||
| } | |||||
// A PBMetadata value together with its name.
| message PBMetadataWithName { | |||||
| string name = 1; | |||||
| PBMetadata metadata = 2; | |||||
| } | |||||
| @@ -60,4 +60,4 @@ message EmbeddingTableLookup { | |||||
| uint64 key = 2; | uint64 key = 2; | ||||
| repeated int32 keys = 3; | repeated int32 keys = 3; | ||||
| repeated float values = 4; | repeated float values = 4; | ||||
| } | |||||
| } | |||||
| @@ -1,4 +1,4 @@ | |||||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||||
| if(ENABLE_CPU AND NOT WIN32) | |||||
| file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc") | file(GLOB_RECURSE _PS_CACHE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ps_data/*.cc") | ||||
| set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | set_property(SOURCE ${_PS_CACHE_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | ||||
| add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES}) | add_library(ps_cache SHARED ${_PS_CACHE_SRC_FILES}) | ||||
| @@ -18,7 +18,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "backend/kernel_compiler/kernel.h" | #include "backend/kernel_compiler/kernel.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | ||||
| #endif | #endif | ||||
| @@ -68,7 +68,7 @@ void PSContext::Reset() { | |||||
| is_worker_ = false; | is_worker_ = false; | ||||
| is_pserver_ = false; | is_pserver_ = false; | ||||
| is_sched_ = false; | is_sched_ = false; | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps_cache_instance.Finalize(); | ps_cache_instance.Finalize(); | ||||
| set_cache_enable(false); | set_cache_enable(false); | ||||
| @@ -108,46 +108,62 @@ int PSContext::ps_rank_id() const { return rank_id_; } | |||||
| void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | ||||
| size_t vocab_size) const { | size_t vocab_size) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); | ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | ||||
| size_t cache_vocab_size, size_t embedding_size) const { | size_t cache_vocab_size, size_t embedding_size) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); | ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { | void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); | ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { | void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); | ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { | void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); | ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::set_cache_enable(bool cache_enable) const { | void PSContext::set_cache_enable(bool cache_enable) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::set_rank_id(int rank_id) const { | void PSContext::set_rank_id(int rank_id) const { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| ps_cache_instance.set_rank_id(rank_id); | ps_cache_instance.set_rank_id(rank_id); | ||||
| #endif | #endif | ||||
| } | } | ||||
| void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; } | |||||
| const std::string &PSContext::fl_name() const { return fl_name_; } | |||||
| void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; } | |||||
| uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; } | |||||
| void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; } | |||||
| uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; } | |||||
| void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; } | |||||
| uint64_t PSContext::client_batch_size() const { return client_batch_size_; } | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,6 +60,19 @@ class PSContext { | |||||
| void set_cache_enable(bool cache_enable) const; | void set_cache_enable(bool cache_enable) const; | ||||
| void set_rank_id(int rank_id) const; | void set_rank_id(int rank_id) const; | ||||
| // Setter and getter for federated learning. | |||||
| void set_fl_name(const std::string &fl_name); | |||||
| const std::string &fl_name() const; | |||||
| void set_fl_iteration_num(uint64_t fl_iteration_num); | |||||
| uint64_t fl_iteration_num() const; | |||||
| void set_client_epoch_num(uint64_t client_epoch_num); | |||||
| uint64_t client_epoch_num() const; | |||||
| void set_client_batch_size(uint64_t client_batch_size); | |||||
| uint64_t client_batch_size() const; | |||||
| private: | private: | ||||
| PSContext() | PSContext() | ||||
| : ps_enabled_(false), | : ps_enabled_(false), | ||||
| @@ -80,6 +93,12 @@ class PSContext { | |||||
| uint32_t server_num_; | uint32_t server_num_; | ||||
| std::string scheduler_host_; | std::string scheduler_host_; | ||||
| uint16_t scheduler_port_; | uint16_t scheduler_port_; | ||||
| // Members for federated learning. | |||||
| std::string fl_name_; | |||||
| uint64_t fl_iteration_num_; | |||||
| uint64_t client_epoch_num_; | |||||
| uint64_t client_batch_size_; | |||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,223 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/collective_ops_impl.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| void CollectiveOpsImpl::Initialize(const std::shared_ptr<core::ServerNode> &server_node) { | |||||
| MS_EXCEPTION_IF_NULL(server_node); | |||||
| server_node_ = server_node; | |||||
| local_rank_ = server_node_->rank_id(); | |||||
| server_num_ = PSContext::instance()->initial_server_num(); | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| bool CollectiveOpsImpl::RingAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||||
| int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T)); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||||
| return false; | |||||
| } | |||||
| uint32_t rank_size = server_num_; | |||||
| uint32_t local_rank_ = server_node_->rank_id(); | |||||
| size_t chunk_size = count / rank_size; | |||||
| size_t remainder_size = count % rank_size; | |||||
| std::vector<size_t> chunk_sizes(rank_size, chunk_size); | |||||
| // The rest of the data should be assigned to each chunk. | |||||
| for (size_t i = 0; i < remainder_size; i++) { | |||||
| chunk_sizes[i]++; | |||||
| } | |||||
| // Store offsets to get every data chunk's address. | |||||
| std::vector<size_t> chunk_offset; | |||||
| for (size_t i = 0; i < rank_size; i++) { | |||||
| size_t ofs = | |||||
| std::accumulate(chunk_sizes.begin(), chunk_sizes.begin() + i, static_cast<size_t>(0), std::plus<size_t>()); | |||||
| chunk_offset.push_back(ofs); | |||||
| } | |||||
| T *output_buff = reinterpret_cast<T *>(recvbuff); | |||||
| uint32_t send_to_rank = (local_rank_ + 1) % rank_size; | |||||
| uint32_t recv_from_rank = (local_rank_ - 1 + rank_size) % rank_size; | |||||
| MS_LOG(DEBUG) << "AllReduce count:" << count << ", rank_size:" << rank_size << ", local_rank_:" << local_rank_ | |||||
| << ", chunk_size:" << chunk_size << ", remainder_size:" << remainder_size | |||||
| << ", chunk_sizes:" << chunk_sizes << ", send_to_rank:" << send_to_rank | |||||
| << ", recv_from_rank:" << recv_from_rank; | |||||
| // Ring ReduceScatter. | |||||
| MS_LOG(DEBUG) << "Start Ring ReduceScatter."; | |||||
| std::unique_ptr<T[]> tmp_recv_chunk = std::make_unique<T[]>(chunk_sizes[0]); | |||||
| for (size_t i = 0; i < rank_size - 1; i++) { | |||||
| // Step 1: Async send data to next rank. | |||||
| size_t send_chunk_index = (local_rank_ - i + rank_size) % rank_size; | |||||
| T *send_chunk = output_buff + chunk_offset[send_chunk_index]; | |||||
| auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, | |||||
| chunk_sizes[send_chunk_index] * sizeof(T)); | |||||
| // Step 2: Async receive data to next rank and wait until it's done. | |||||
| size_t recv_chunk_index = (local_rank_ - i - 1 + rank_size) % rank_size; | |||||
| T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; | |||||
| MS_LOG(DEBUG) << "Ring ReduceScatter send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank | |||||
| << ", send count:" << chunk_sizes[send_chunk_index] | |||||
| << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; | |||||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); | |||||
| if (!server_node_->CollectiveWait(recv_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| memcpy_s(tmp_recv_chunk.get(), chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); | |||||
| // Step 3: Reduce the data so we can overlap the time cost of send. | |||||
| for (size_t j = 0; j < chunk_sizes[recv_chunk_index]; j++) { | |||||
| recv_chunk[j] += tmp_recv_chunk[j]; | |||||
| } | |||||
| // Step 4: Wait until send is done. | |||||
| if (!server_node_->Wait(send_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "End Ring ReduceScatter."; | |||||
| // Ring AllGather. | |||||
| MS_LOG(DEBUG) << "Start Ring AllGather."; | |||||
| for (size_t i = 0; i < rank_size - 1; i++) { | |||||
| size_t send_chunk_index = (local_rank_ - i + 1 + rank_size) % rank_size; | |||||
| T *send_chunk = output_buff + chunk_offset[send_chunk_index]; | |||||
| auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, send_to_rank, send_chunk, | |||||
| chunk_sizes[send_chunk_index] * sizeof(T)); | |||||
| size_t recv_chunk_index = (local_rank_ - i + rank_size) % rank_size; | |||||
| T *recv_chunk = output_buff + chunk_offset[recv_chunk_index]; | |||||
| MS_LOG(DEBUG) << "Ring AllGather send_to_rank:" << send_to_rank << ", recv_from_rank:" << recv_from_rank | |||||
| << ", send count:" << chunk_sizes[send_chunk_index] | |||||
| << ", recv count:" << chunk_sizes[recv_chunk_index] << ", iteration:" << i; | |||||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, recv_from_rank, &recv_str); | |||||
| if (!server_node_->CollectiveWait(recv_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| memcpy_s(recv_chunk, chunk_sizes[recv_chunk_index] * sizeof(T), recv_str->data(), recv_str->size()); | |||||
| if (!server_node_->Wait(send_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "End Ring AllGather."; | |||||
| return true; | |||||
| } | |||||
| template <typename T> | |||||
| bool CollectiveOpsImpl::ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||||
| uint32_t rank_size = server_num_; | |||||
| uint32_t local_rank_ = server_node_->rank_id(); | |||||
| MS_LOG(DEBUG) << "Reduce Broadcast AllReduce rank_size:" << rank_size << ", local_rank_:" << local_rank_ | |||||
| << ", count:" << count; | |||||
| int ret = memcpy_s(recvbuff, count * sizeof(T), sendbuff, count * sizeof(T)); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||||
| return false; | |||||
| } | |||||
| T *output_buff = reinterpret_cast<T *>(recvbuff); | |||||
| // Reduce data to rank 0 process. | |||||
| MS_LOG(DEBUG) << "Start Reduce to rank 0 process."; | |||||
| if (local_rank_ == 0) { | |||||
| std::unique_ptr<T[]> tmp_recv_buff = std::make_unique<T[]>(count); | |||||
| for (uint32_t i = 1; i < rank_size; i++) { | |||||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||||
| MS_LOG(DEBUG) << "Reduce rank 0 receive from rank " << i; | |||||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, i, &recv_str); | |||||
| if (!server_node_->CollectiveWait(recv_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| memcpy_s(tmp_recv_buff.get(), count * sizeof(T), recv_str->data(), recv_str->size()); | |||||
| for (size_t j = 0; j < count; j++) { | |||||
| output_buff[j] += tmp_recv_buff[j]; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Reduce send data to rank 0 process."; | |||||
| auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, 0, sendbuff, count * sizeof(T)); | |||||
| if (!server_node_->Wait(send_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| MS_LOG(DEBUG) << "End Reduce."; | |||||
| // Broadcast data to not 0 rank process. | |||||
| MS_LOG(DEBUG) << "Start broadcast from rank 0 to other processes."; | |||||
| if (local_rank_ == 0) { | |||||
| for (uint32_t i = 1; i < rank_size; i++) { | |||||
| MS_LOG(DEBUG) << "Broadcast data to process " << i; | |||||
| auto send_req_id = server_node_->CollectiveSendAsync(core::NodeRole::SERVER, i, output_buff, count * sizeof(T)); | |||||
| if (!server_node_->Wait(send_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << send_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| MS_LOG(DEBUG) << "Broadcast receive from rank 0."; | |||||
| std::shared_ptr<std::vector<unsigned char>> recv_str; | |||||
| auto recv_req_id = server_node_->CollectiveReceiveAsync(core::NodeRole::SERVER, 0, &recv_str); | |||||
| if (!server_node_->CollectiveWait(recv_req_id, 1)) { | |||||
| MS_LOG(ERROR) << "CollectiveWait " << recv_req_id << " failed."; | |||||
| return false; | |||||
| } | |||||
| memcpy_s(output_buff, count * sizeof(T), recv_str->data(), recv_str->size()); | |||||
| } | |||||
| MS_LOG(DEBUG) << "End broadcast."; | |||||
| return true; | |||||
| } | |||||
| template <typename T> | |||||
| bool CollectiveOpsImpl::AllReduce(const void *sendbuff, void *recvbuff, size_t count) { | |||||
| // The collective communication API does not support calling Send and Recv concurrently with multiple threads; | |||||
| std::unique_lock<std::mutex> lock(mtx_); | |||||
| if (sendbuff == nullptr || recvbuff == nullptr) { | |||||
| MS_LOG(ERROR) << "AllReduce sendbuff or recvbuff is nullptr."; | |||||
| return false; | |||||
| } | |||||
| uint32_t rank_size = server_num_; | |||||
| if (count >= rank_size) { | |||||
| return RingAllReduce<T>(sendbuff, recvbuff, count); | |||||
| } else { | |||||
| return ReduceBroadcastAllReduce<T>(sendbuff, recvbuff, count); | |||||
| } | |||||
| } | |||||
// Explicit instantiations of the three AllReduce member templates for float, size_t and int.
// The template definitions live in this .cc file, so only these element types are instantiated here.
| template bool CollectiveOpsImpl::RingAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::RingAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::RingAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::ReduceBroadcastAllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::AllReduce<float>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::AllReduce<size_t>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| template bool CollectiveOpsImpl::AllReduce<int>(const void *sendbuff, void *recvbuff, size_t count); | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <functional> | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/ps_context.h" | |||||
| #include "ps/core/server_node.h" | |||||
| #include "ps/server/common.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // CollectiveOpsImpl is the collective communication API of the server. | |||||
| // For now, it implements two AllReduce algorithms: RingAllReduce and BroadcastAllReduce. Elastic AllReduce is also | |||||
| // supported for the elastic scaling feature of the server. | |||||
| class CollectiveOpsImpl { | |||||
| public: | |||||
| static CollectiveOpsImpl &GetInstance() { | |||||
| static CollectiveOpsImpl instance; | |||||
| return instance; | |||||
| } | |||||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node); | |||||
| template <typename T> | |||||
| bool AllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||||
| private: | |||||
| CollectiveOpsImpl() = default; | |||||
| ~CollectiveOpsImpl() = default; | |||||
| CollectiveOpsImpl(const CollectiveOpsImpl &) = delete; | |||||
| CollectiveOpsImpl &operator=(const CollectiveOpsImpl &) = delete; | |||||
| // Implementation of RingAllReduce. | |||||
| template <typename T> | |||||
| bool RingAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||||
| // Implementation of BroadcastAllReduce. | |||||
| template <typename T> | |||||
| bool ReduceBroadcastAllReduce(const void *sendbuff, void *recvbuff, size_t count); | |||||
| std::shared_ptr<core::ServerNode> server_node_; | |||||
| uint32_t local_rank_; | |||||
| uint32_t server_num_; | |||||
| // The mutex to ensure that collective communication is threadsafe. | |||||
| std::mutex mtx_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_COLLECTIVE_OPS_IMPL_H_ | |||||
| @@ -24,13 +24,17 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <functional> | #include <functional> | ||||
| #include "proto/ps.pb.h" | #include "proto/ps.pb.h" | ||||
| #include "proto/fl.pb.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "ir/dtype/type_id.h" | #include "ir/dtype/type_id.h" | ||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | #include "backend/kernel_compiler/cpu/cpu_kernel.h" | ||||
| #include "schema/fl_job_generated.h" | |||||
| #include "schema/cipher_generated.h" | |||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #include "ps/core/communicator/http_message_handler.h" | #include "ps/core/communicator/http_message_handler.h" | ||||
| #include "ps/core/communicator/tcp_server.h" | #include "ps/core/communicator/tcp_server.h" | ||||
| #include "ps/core/communicator/message_handler.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -40,13 +44,15 @@ enum ServerMode { PARAMETER_SERVER = 0, FL_SERVER }; | |||||
| enum CommType { HTTP = 0, TCP }; | enum CommType { HTTP = 0, TCP }; | ||||
| enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; | enum AggregationType { FedAvg = 0, FedAdam, FedAdagarg, FedMeta, qffl, DenseGradAccum, SparseGradAccum }; | ||||
| using kernel::Address; | |||||
| using kernel::AddressPtr; | |||||
| using kernel::CPUKernel; | |||||
| using mindspore::kernel::Address; | |||||
| using mindspore::kernel::AddressPtr; | |||||
| using mindspore::kernel::CPUKernel; | |||||
| using FBBuilder = flatbuffers::FlatBufferBuilder; | |||||
| using TimeOutCb = std::function<void(void)>; | using TimeOutCb = std::function<void(void)>; | ||||
| using StopTimerCb = std::function<void(void)>; | using StopTimerCb = std::function<void(void)>; | ||||
| using FinishIterCb = std::function<void(void)>; | using FinishIterCb = std::function<void(void)>; | ||||
| using FinalizeCb = std::function<void(void)>; | using FinalizeCb = std::function<void(void)>; | ||||
| using MessageCallback = std::function<void(const std::shared_ptr<core::MessageHandler> &)>; | |||||
| // Information about whether server kernel will reuse kernel node memory from the front end. | // Information about whether server kernel will reuse kernel node memory from the front end. | ||||
| // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". | // Key refers to the server kernel's parameter name, like "weights", "grad", "learning_rate". | ||||
| @@ -0,0 +1,298 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/distributed_count_service.h" | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| void DistributedCountService::Initialize(const std::shared_ptr<core::ServerNode> &server_node, | |||||
| uint32_t counting_server_rank) { | |||||
| server_node_ = server_node; | |||||
| MS_EXCEPTION_IF_NULL(server_node_); | |||||
| communicator_ = | |||||
| std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); | |||||
| MS_EXCEPTION_IF_NULL(communicator_); | |||||
| local_rank_ = server_node_->rank_id(); | |||||
| server_num_ = PSContext::instance()->initial_server_num(); | |||||
| counting_server_rank_ = counting_server_rank; | |||||
| RegisterCallback(); | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::RegisterCounter(const std::string &name, size_t global_threshold_count, | |||||
| const CounterHandlers &counter_handlers) { | |||||
| if (!counter_handlers.first_count_handler || !counter_handlers.last_count_handler) { | |||||
| MS_LOG(EXCEPTION) << "First count handler or last count handler is not set."; | |||||
| return; | |||||
| } | |||||
| if (global_threshold_count_.count(name) != 0) { | |||||
| MS_LOG(ERROR) << "Counter for " << name << " is already set."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " register counter for " << name << " count:" << global_threshold_count; | |||||
| // If the server is the leader server, it needs to set the counter handlers and do the real counting. | |||||
| if (local_rank_ == counting_server_rank_) { | |||||
| global_current_count_[name] = {}; | |||||
| global_threshold_count_[name] = global_threshold_count; | |||||
| mutex_[name]; | |||||
| } | |||||
| counter_handlers_[name] = counter_handlers; | |||||
| return; | |||||
| } | |||||
| bool DistributedCountService::Count(const std::string &name, const std::string &id) { | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " reports count for " << name << " of " << id; | |||||
| if (local_rank_ == counting_server_rank_) { | |||||
| if (global_threshold_count_.count(name) == 0) { | |||||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||||
| return false; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| if (global_current_count_[name].size() >= global_threshold_count_[name]) { | |||||
| MS_LOG(ERROR) << "Count for " << name << " is already enough. Threshold count is " | |||||
| << global_threshold_count_[name]; | |||||
| return false; | |||||
| } | |||||
| MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; | |||||
| global_current_count_[name].insert(id); | |||||
| TriggerCounterEvent(name); | |||||
| } else { | |||||
| // If this server is a follower server, it needs to send CountRequest to the leader server. | |||||
| CountRequest report_count_req; | |||||
| report_count_req.set_name(name); | |||||
| report_count_req.set_id(id); | |||||
| std::shared_ptr<std::vector<unsigned char>> report_cnt_rsp_msg = nullptr; | |||||
| if (!communicator_->SendPbRequest(report_count_req, counting_server_rank_, core::TcpUserCommand::kCount, | |||||
| &report_cnt_rsp_msg)) { | |||||
| MS_LOG(ERROR) << "Sending reporting count message to leader server failed for " << name; | |||||
| return false; | |||||
| } | |||||
| CountResponse count_rsp; | |||||
| count_rsp.ParseFromArray(report_cnt_rsp_msg->data(), report_cnt_rsp_msg->size()); | |||||
| if (!count_rsp.result()) { | |||||
| MS_LOG(ERROR) << "Reporting count failed:" << count_rsp.reason(); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool DistributedCountService::CountReachThreshold(const std::string &name) { | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " query whether count reaches threshold for " << name; | |||||
| if (local_rank_ == counting_server_rank_) { | |||||
| if (global_threshold_count_.count(name) == 0) { | |||||
| MS_LOG(ERROR) << "Counter for " << name << " is not set."; | |||||
| return false; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| return global_current_count_[name].size() == global_threshold_count_[name]; | |||||
| } else { | |||||
| CountReachThresholdRequest count_reach_threashold_req; | |||||
| count_reach_threashold_req.set_name(name); | |||||
| std::shared_ptr<std::vector<unsigned char>> query_cnt_enough_rsp_msg = nullptr; | |||||
| if (!communicator_->SendPbRequest(count_reach_threashold_req, counting_server_rank_, | |||||
| core::TcpUserCommand::kReachThreshold, &query_cnt_enough_rsp_msg)) { | |||||
| MS_LOG(ERROR) << "Sending querying whether count reaches threshold message to leader server failed for " << name; | |||||
| return false; | |||||
| } | |||||
| CountReachThresholdResponse count_reach_threashold_rsp; | |||||
| count_reach_threashold_rsp.ParseFromArray(query_cnt_enough_rsp_msg->data(), query_cnt_enough_rsp_msg->size()); | |||||
| return count_reach_threashold_rsp.is_enough(); | |||||
| } | |||||
| } | |||||
| void DistributedCountService::ResetCounter(const std::string &name) { | |||||
| if (local_rank_ == counting_server_rank_) { | |||||
| MS_LOG(INFO) << "Leader server reset count for " << name; | |||||
| global_current_count_[name].clear(); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::RegisterCallback() { | |||||
| if (local_rank_ == counting_server_rank_) { | |||||
| communicator_->RegisterMsgCallBack( | |||||
| "count", std::bind(&DistributedCountService::HandleCountRequest, this, std::placeholders::_1)); | |||||
| communicator_->RegisterMsgCallBack( | |||||
| "countReachThreshold", | |||||
| std::bind(&DistributedCountService::HandleCountReachThresholdRequest, this, std::placeholders::_1)); | |||||
| } | |||||
| // The callback of first/last event must be set in both leader server and follower servers. | |||||
| communicator_->RegisterMsgCallBack( | |||||
| "counterEvent", std::bind(&DistributedCountService::HandleCounterEvent, this, std::placeholders::_1)); | |||||
| } | |||||
| void DistributedCountService::HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| CountRequest report_count_req; | |||||
| report_count_req.ParseFromArray(message->data(), message->len()); | |||||
| const std::string &name = report_count_req.name(); | |||||
| const std::string &id = report_count_req.id(); | |||||
| CountResponse count_rsp; | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| // If leader server has no counter for the name registered, return an error. | |||||
| if (global_threshold_count_.count(name) == 0) { | |||||
| std::string reason = "Counter for " + name + " is not registered."; | |||||
| count_rsp.set_result(false); | |||||
| count_rsp.set_reason(reason); | |||||
| MS_LOG(ERROR) << reason; | |||||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||||
| return; | |||||
| } | |||||
| // If leader server already has enough count for the name, return an error. | |||||
| if (global_current_count_[name].size() >= global_threshold_count_[name]) { | |||||
| std::string reason = | |||||
| "Count for " + name + " is already enough. Threshold count is " + std::to_string(global_threshold_count_[name]); | |||||
| count_rsp.set_result(false); | |||||
| count_rsp.set_reason(reason); | |||||
| MS_LOG(ERROR) << reason; | |||||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||||
| return; | |||||
| } | |||||
| // Insert the id for the counter, which means the count for the name is increased. | |||||
| MS_LOG(INFO) << "Leader server increase count for " << name << " of " << id; | |||||
| global_current_count_[name].insert(id); | |||||
| TriggerCounterEvent(name); | |||||
| count_rsp.set_result(true); | |||||
| count_rsp.set_reason("success"); | |||||
| communicator_->SendResponse(count_rsp.SerializeAsString().data(), count_rsp.SerializeAsString().size(), message); | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| CountReachThresholdRequest count_reach_threashold_req; | |||||
| count_reach_threashold_req.ParseFromArray(message->data(), message->len()); | |||||
| const std::string &name = count_reach_threashold_req.name(); | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| if (global_threshold_count_.count(name) == 0) { | |||||
| MS_LOG(ERROR) << "Counter for " << name << " is not registered."; | |||||
| return; | |||||
| } | |||||
| CountReachThresholdResponse count_reach_threashold_rsp; | |||||
| count_reach_threashold_rsp.set_is_enough(global_current_count_[name].size() == global_threshold_count_[name]); | |||||
| communicator_->SendResponse(count_reach_threashold_rsp.SerializeAsString().data(), | |||||
| count_reach_threashold_rsp.SerializeAsString().size(), message); | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| // Respond as soon as possible so the leader server won't wait for each follower servers to finish calling the | |||||
| // callbacks. | |||||
| std::string couter_event_rsp_msg = "success"; | |||||
| communicator_->SendResponse(couter_event_rsp_msg.data(), couter_event_rsp_msg.size(), message); | |||||
| CounterEvent counter_event; | |||||
| counter_event.ParseFromArray(message->data(), message->len()); | |||||
| const auto &type = counter_event.type(); | |||||
| const auto &name = counter_event.name(); | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " do counter event " << type << " for " << name; | |||||
| if (type == CounterEventType::FIRST_CNT) { | |||||
| counter_handlers_[name].first_count_handler(message); | |||||
| } else if (type == CounterEventType::LAST_CNT) { | |||||
| counter_handlers_[name].last_count_handler(message); | |||||
| } else { | |||||
| MS_LOG(ERROR) << "DistributedCountService event type " << type << " is invalid."; | |||||
| return; | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::TriggerCounterEvent(const std::string &name) { | |||||
| MS_LOG(INFO) << "Current count for " << name << " is " << global_current_count_[name].size() | |||||
| << ", threshold count is " << global_threshold_count_[name]; | |||||
| // The threshold count may be 1 so the first and last count event should be both activated. | |||||
| if (global_current_count_[name].size() == 1) { | |||||
| TriggerFirstCountEvent(name); | |||||
| } | |||||
| if (global_current_count_[name].size() == global_threshold_count_[name]) { | |||||
| TriggerLastCountEvent(name); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::TriggerFirstCountEvent(const std::string &name) { | |||||
| MS_LOG(INFO) << "Activating first count event for " << name; | |||||
| CounterEvent first_count_event; | |||||
| first_count_event.set_type(CounterEventType::FIRST_CNT); | |||||
| first_count_event.set_name(name); | |||||
| // Broadcast to all follower servers. | |||||
| for (uint32_t i = 1; i < server_num_; i++) { | |||||
| if (!communicator_->SendPbRequest(first_count_event, i, core::TcpUserCommand::kCounterEvent)) { | |||||
| MS_LOG(ERROR) << "Activating first count event to server " << i << " failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| // Leader server directly calls the callback. | |||||
| counter_handlers_[name].first_count_handler(nullptr); | |||||
| return; | |||||
| } | |||||
| void DistributedCountService::TriggerLastCountEvent(const std::string &name) { | |||||
| MS_LOG(INFO) << "Activating last count event for " << name; | |||||
| CounterEvent last_count_event; | |||||
| last_count_event.set_type(CounterEventType::LAST_CNT); | |||||
| last_count_event.set_name(name); | |||||
| // Broadcast to all follower servers. | |||||
| for (uint32_t i = 1; i < server_num_; i++) { | |||||
| if (!communicator_->SendPbRequest(last_count_event, i, core::TcpUserCommand::kCounterEvent)) { | |||||
| MS_LOG(ERROR) << "Activating last count event to server " << i << " failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| // Leader server directly calls the callback. | |||||
| counter_handlers_[name].last_count_handler(nullptr); | |||||
| return; | |||||
| } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,126 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/core/server_node.h" | |||||
| #include "ps/core/communicator/tcp_communicator.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // The callbacks for the first count and last count event. | |||||
| typedef struct { | |||||
| MessageCallback first_count_handler; | |||||
| MessageCallback last_count_handler; | |||||
| } CounterHandlers; | |||||
| // DistributedCountService is used for counting in the server cluster dimension. It's used for counting of rounds, | |||||
| // aggregation counting, etc. | |||||
| // The counting could be called by any server, but only one server has the information | |||||
| // of the cluster count and we mark this server as the counting server. Other servers must communicate with this | |||||
| // counting server to increase/query count number. | |||||
| // On the first count or last count event, DistributedCountService on the counting server triggers the event on other | |||||
| // servers by sending counter event commands. This is for the purpose of keeping server cluster's consistency. | |||||
| class DistributedCountService { | |||||
| public: | |||||
| static DistributedCountService &GetInstance() { | |||||
| static DistributedCountService instance; | |||||
| return instance; | |||||
| } | |||||
| // Initialize counter service with the server node because communication is needed. | |||||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node, uint32_t counting_server_rank); | |||||
| // Register counter to the counting server for the name with its threshold count in server cluster dimension and | |||||
| // first/last count event callbacks. | |||||
| void RegisterCounter(const std::string &name, size_t global_threshold_count, const CounterHandlers &counter_handlers); | |||||
| // Report a count to the counting server. Parameter 'id' is in case of repeated counting. | |||||
| bool Count(const std::string &name, const std::string &id); | |||||
| // Query whether the count reaches the threshold count for the name. If the count is the same as the threshold count, | |||||
| // this method returns true. | |||||
| bool CountReachThreshold(const std::string &name); | |||||
| // Reset the count of the name to 0. | |||||
| void ResetCounter(const std::string &name); | |||||
| // Returns the server rank because in some cases the callers use this rank as the 'id' for method | |||||
| // Count. | |||||
| uint32_t local_rank() { return local_rank_; } | |||||
| private: | |||||
| DistributedCountService() = default; | |||||
| ~DistributedCountService() = default; | |||||
| DistributedCountService(const DistributedCountService &) = delete; | |||||
| DistributedCountService &operator=(const DistributedCountService &) = delete; | |||||
| // Register callbacks of the counting server to handle messages sent by the other servers. | |||||
| void RegisterCallback(); | |||||
| // Callback for the reporting count message from other servers. Only counting server will call this method. | |||||
| void HandleCountRequest(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Callback for the querying whether threshold count is reached message from other servers. Only counting | |||||
| // server will call this method. | |||||
| void HandleCountReachThresholdRequest(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Callback for the first/last event message from the counting server. Only other servers will call this | |||||
| // method. | |||||
| void HandleCounterEvent(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Call the callbacks when the first/last count event is triggered. | |||||
| void TriggerCounterEvent(const std::string &name); | |||||
| void TriggerFirstCountEvent(const std::string &name); | |||||
| void TriggerLastCountEvent(const std::string &name); | |||||
| // Members for the communication between counting server and other servers. | |||||
| std::shared_ptr<core::ServerNode> server_node_; | |||||
| std::shared_ptr<core::TcpCommunicator> communicator_; | |||||
| uint32_t local_rank_; | |||||
| uint32_t server_num_; | |||||
| // Only one server will be set to do the real counting. | |||||
| uint32_t counting_server_rank_; | |||||
| // Key: name, e.g, startFLJob, updateModel, push. | |||||
| // Value: a set of id without repeatation because each work may report multiple times. | |||||
| std::unordered_map<std::string, std::set<std::string>> global_current_count_; | |||||
| // Key: name, e.g, StartFLJobCount. | |||||
| // Value: global threshold count in the server cluster dimension for this name. | |||||
| std::unordered_map<std::string, size_t> global_threshold_count_; | |||||
| // First/last count event callbacks of the name. | |||||
| std::unordered_map<std::string, CounterHandlers> counter_handlers_; | |||||
| // Because the count is increased/queried conccurently, we must ensure the operations are threadsafe. | |||||
| std::unordered_map<std::string, std::mutex> mutex_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_COUNT_SERVICE_H_ | |||||
| @@ -0,0 +1,201 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/distributed_metadata_store.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| void DistributedMetadataStore::Initialize(const std::shared_ptr<core::ServerNode> &server_node) { | |||||
| server_node_ = server_node; | |||||
| MS_EXCEPTION_IF_NULL(server_node); | |||||
| communicator_ = | |||||
| std::dynamic_pointer_cast<core::TcpCommunicator>(server_node_->GetOrCreateTcpComm("", 0, 0, 0, nullptr)); | |||||
| MS_EXCEPTION_IF_NULL(communicator_); | |||||
| local_rank_ = server_node_->rank_id(); | |||||
| server_num_ = PSContext::instance()->initial_server_num(); | |||||
| InitHashRing(); | |||||
| RegisterCallback(); | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::RegisterMetadata(const std::string &name, const PBMetadata &meta) { | |||||
| if (router_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||||
| return; | |||||
| } | |||||
| uint32_t stored_rank = router_->Find(name); | |||||
| if (local_rank_ == stored_rank) { | |||||
| if (metadata_.count(name) != 0) { | |||||
| MS_LOG(ERROR) << "The metadata for " << name << " is already registered."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " register storage for metadata " << name; | |||||
| metadata_[name] = meta; | |||||
| mutex_[name]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::ResetMetadata(const std::string &name) { | |||||
| if (router_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||||
| return; | |||||
| } | |||||
| uint32_t stored_rank = router_->Find(name); | |||||
| if (local_rank_ == stored_rank) { | |||||
| if (metadata_.count(name) == 0) { | |||||
| MS_LOG(ERROR) << "The metadata for " << name << " is not registered."; | |||||
| return; | |||||
| } | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " reset metadata for " << name; | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| PBMetadata empty_meta; | |||||
| metadata_[name] = empty_meta; | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::UpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||||
| if (router_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||||
| return; | |||||
| } | |||||
| uint32_t stored_rank = router_->Find(name); | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " update value for " << name << " which is stored in rank " << stored_rank; | |||||
| if (local_rank_ == stored_rank) { | |||||
| if (!DoUpdateMetadata(name, meta)) { | |||||
| MS_LOG(ERROR) << "Updating meta data failed."; | |||||
| return; | |||||
| } | |||||
| } else { | |||||
| PBMetadataWithName metadata_with_name; | |||||
| metadata_with_name.set_name(name); | |||||
| *metadata_with_name.mutable_metadata() = meta; | |||||
| if (!communicator_->SendPbRequest(metadata_with_name, stored_rank, core::TcpUserCommand::kUpdateMetadata)) { | |||||
| MS_LOG(ERROR) << "Sending updating metadata message to server " << stored_rank << " failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| PBMetadata DistributedMetadataStore::GetMetadata(const std::string &name) { | |||||
| if (router_ == nullptr) { | |||||
| MS_LOG(ERROR) << "The consistent hash ring is not initialized yet."; | |||||
| return {}; | |||||
| } | |||||
| uint32_t stored_rank = router_->Find(name); | |||||
| MS_LOG(INFO) << "Rank " << local_rank_ << " get metadata for " << name << " which is stored in rank " << stored_rank; | |||||
| if (local_rank_ == stored_rank) { | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| return metadata_[name]; | |||||
| } else { | |||||
| GetMetadataRequest get_metadata_req; | |||||
| get_metadata_req.set_name(name); | |||||
| PBMetadata get_metadata_rsp; | |||||
| std::shared_ptr<std::vector<unsigned char>> get_meta_rsp_msg = nullptr; | |||||
| if (!communicator_->SendPbRequest(get_metadata_req, stored_rank, core::TcpUserCommand::kGetMetadata, | |||||
| &get_meta_rsp_msg)) { | |||||
| MS_LOG(ERROR) << "Sending getting metadata message to server " << stored_rank << " failed."; | |||||
| return get_metadata_rsp; | |||||
| } | |||||
| get_metadata_rsp.ParseFromArray(get_meta_rsp_msg->data(), get_meta_rsp_msg->size()); | |||||
| return get_metadata_rsp; | |||||
| } | |||||
| } | |||||
| void DistributedMetadataStore::InitHashRing() { | |||||
| router_ = std::make_shared<ConsistentHashRing>(32); | |||||
| MS_EXCEPTION_IF_NULL(router_); | |||||
| for (uint32_t i = 0; i < server_num_; i++) { | |||||
| bool ret = router_->Insert(i); | |||||
| if (!ret) { | |||||
| MS_LOG(EXCEPTION) << "Add node " << i << " to router of meta storage failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::RegisterCallback() { | |||||
| communicator_->RegisterMsgCallBack( | |||||
| "updateMetadata", std::bind(&DistributedMetadataStore::HandleUpdateMetadataRequest, this, std::placeholders::_1)); | |||||
| communicator_->RegisterMsgCallBack( | |||||
| "getMetadata", std::bind(&DistributedMetadataStore::HandleGetMetadataRequest, this, std::placeholders::_1)); | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| PBMetadataWithName meta_with_name; | |||||
| meta_with_name.ParseFromArray(message->data(), message->len()); | |||||
| const std::string &name = meta_with_name.name(); | |||||
| MS_LOG(INFO) << "Update metadata for " << name; | |||||
| std::string update_meta_rsp_msg; | |||||
| if (!DoUpdateMetadata(name, meta_with_name.metadata())) { | |||||
| update_meta_rsp_msg = "Updating meta data failed."; | |||||
| } else { | |||||
| update_meta_rsp_msg = "Success"; | |||||
| } | |||||
| communicator_->SendResponse(update_meta_rsp_msg.data(), update_meta_rsp_msg.size(), message); | |||||
| return; | |||||
| } | |||||
| void DistributedMetadataStore::HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| GetMetadataRequest get_metadata_req; | |||||
| get_metadata_req.ParseFromArray(message->data(), message->len()); | |||||
| const std::string &name = get_metadata_req.name(); | |||||
| MS_LOG(INFO) << "Getting metadata for " << name; | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| PBMetadata stored_meta = metadata_[name]; | |||||
| std::string getting_meta_rsp_msg = stored_meta.SerializeAsString(); | |||||
| communicator_->SendResponse(getting_meta_rsp_msg.data(), getting_meta_rsp_msg.size(), message); | |||||
| return; | |||||
| } | |||||
| bool DistributedMetadataStore::DoUpdateMetadata(const std::string &name, const PBMetadata &meta) { | |||||
| std::unique_lock<std::mutex> lock(mutex_[name]); | |||||
| metadata_[name] = meta; | |||||
| return true; | |||||
| } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,101 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <unordered_map> | |||||
| #include "proto/ps.pb.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/core/server_node.h" | |||||
| #include "ps/core/communicator/tcp_communicator.h" | |||||
| #include "ps/server/consistent_hash_ring.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // This class is used for distributed metadata storage using consistent hash. All metadata is distributedly | |||||
| // stored in all servers. Caller doesn't need to know which server stores the metadata. It only needs to know what kind | |||||
| // of operations should be done to the metadata. | |||||
| // The metadata stored in the server is in protobuffer format because it's easy for serializing and communicating. The | |||||
| // type of the protobuffer struct is decided by the caller using protobuffer's API. | |||||
| class DistributedMetadataStore { | |||||
| public: | |||||
| static DistributedMetadataStore &GetInstance() { | |||||
| static DistributedMetadataStore instance; | |||||
| return instance; | |||||
| } | |||||
| // Initialize metadata storage with the server node because communication is needed. | |||||
| void Initialize(const std::shared_ptr<core::ServerNode> &server_node); | |||||
| // Register metadata for the name with the initial value. This method should be only called once for each name. | |||||
| void RegisterMetadata(const std::string &name, const PBMetadata &meta); | |||||
| // Reset the metadata value for the name. | |||||
| void ResetMetadata(const std::string &name); | |||||
| // Update the metadata for the name. | |||||
| void UpdateMetadata(const std::string &name, const PBMetadata &meta); | |||||
| // Get the metadata for the name. | |||||
| PBMetadata GetMetadata(const std::string &name); | |||||
| private: | |||||
| DistributedMetadataStore() = default; | |||||
| ~DistributedMetadataStore() = default; | |||||
| DistributedMetadataStore(const DistributedMetadataStore &) = delete; | |||||
| DistributedMetadataStore &operator=(const DistributedMetadataStore &) = delete; | |||||
| // Initialize the consistent hash ring for distributed storage. | |||||
| void InitHashRing(); | |||||
| // Register callbacks for the server to handle update/get metadata messages from other servers. | |||||
| void RegisterCallback(); | |||||
| // Callback for updating metadata request sent to the server. | |||||
| void HandleUpdateMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Callback for getting metadata request sent to the server. | |||||
| void HandleGetMetadataRequest(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Do updating metadata in the server where the metadata for the name is stored. | |||||
| bool DoUpdateMetadata(const std::string &name, const PBMetadata &meta); | |||||
| // Members for the communication between servers. | |||||
| std::shared_ptr<core::ServerNode> server_node_; | |||||
| std::shared_ptr<core::TcpCommunicator> communicator_; | |||||
| uint32_t local_rank_; | |||||
| uint32_t server_num_; | |||||
| // Consistent hash ring. This is used for DistributedMetadataStore to find which server node the meta data is stored. | |||||
| std::shared_ptr<ConsistentHashRing> router_; | |||||
| // We store metadata which is serialized by ProtoBuffer so that data storage and data transmission API is easy to use. | |||||
| // Key: data name. | |||||
| // Value: ProtoBuffer Struct. | |||||
| std::unordered_map<std::string, PBMetadata> metadata_; | |||||
| // Because the metadata is read/written conccurently, we must ensure the operations are threadsafe. | |||||
| std::unordered_map<std::string, std::mutex> mutex_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_DISTRIBUTED_META_STORE_H_ | |||||
| @@ -169,7 +169,7 @@ bool Executor::HandleOverwriteWeightsByKey(const std::map<std::string, Address> | |||||
| } | } | ||||
| AddressPtr Executor::HandlePull(const std::string ¶m_name) { | AddressPtr Executor::HandlePull(const std::string ¶m_name) { | ||||
| MS_LOG(INFO) << "Handle blocking pull msg for parameter " << param_name; | |||||
| MS_LOG(INFO) << "Handle blocking pull message for parameter " << param_name; | |||||
| if (param_aggrs_.count(param_name) == 0) { | if (param_aggrs_.count(param_name) == 0) { | ||||
| MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; | MS_LOG(WARNING) << "Parameter " << param_name << " is not registered in server."; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -193,11 +193,6 @@ AddressPtr Executor::HandlePull(const std::string ¶m_name) { | |||||
| return addr; | return addr; | ||||
| } | } | ||||
| std::map<std::string, AddressPtr> Executor::HandleAsyncGetModel() { | |||||
| std::unique_lock<std::mutex> lock(model_mutex_); | |||||
| return GetModel(); | |||||
| } | |||||
| std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) { | std::map<std::string, AddressPtr> Executor::HandleGetWeightsByKey(const std::vector<std::string> ¶m_names) { | ||||
| std::map<std::string, AddressPtr> weights; | std::map<std::string, AddressPtr> weights; | ||||
| for (const auto ¶m_name : param_names) { | for (const auto ¶m_name : param_names) { | ||||
| @@ -63,10 +63,6 @@ class Executor { | |||||
| // asynchronously. | // asynchronously. | ||||
| bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map); | bool HandleModelUpdateAsync(const std::map<std::string, UploadData> &feature_map); | ||||
| // Called in asynchronous federated learning training mode. Returns whole model in key-value where key refers to the | |||||
| // parameter name. | |||||
| std::map<std::string, AddressPtr> HandleAsyncGetModel(); | |||||
| // Forcibly overwrite specific weights in overwriteWeights message. | // Forcibly overwrite specific weights in overwriteWeights message. | ||||
| bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map); | bool HandleOverwriteWeightsByKey(const std::map<std::string, Address> &feature_map); | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/iteration.h" | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include <numeric> | |||||
| #include "ps/server/model_store.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| Iteration::Iteration() : iteration_num_(1) { LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); } | |||||
| void Iteration::AddRound(const std::shared_ptr<Round> &round) { | |||||
| MS_EXCEPTION_IF_NULL(round); | |||||
| rounds_.push_back(round); | |||||
| } | |||||
| void Iteration::InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, | |||||
| const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb) { | |||||
| if (communicators.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Communicators for rounds is empty."; | |||||
| return; | |||||
| } | |||||
| std::for_each(communicators.begin(), communicators.end(), | |||||
| [&](const std::shared_ptr<core::CommunicatorBase> &communicator) { | |||||
| for (auto &round : rounds_) { | |||||
| if (round == nullptr) { | |||||
| continue; | |||||
| } | |||||
| round->Initialize(communicator, timeout_cb, finish_iteration_cb); | |||||
| } | |||||
| }); | |||||
| // The time window for one iteration, which will be used in some round kernels. | |||||
| size_t iteration_time_window = | |||||
| std::accumulate(rounds_.begin(), rounds_.end(), 0, | |||||
| [](size_t total, const std::shared_ptr<Round> &round) { return total + round->time_window(); }); | |||||
| LocalMetaStore::GetInstance().put_value(kCtxTotalTimeoutDuration, iteration_time_window); | |||||
| return; | |||||
| } | |||||
| void Iteration::ProceedToNextIter() { | |||||
| iteration_num_ = LocalMetaStore::GetInstance().curr_iter_num(); | |||||
| // Store the model for each iteration. | |||||
| const auto &model = Executor::GetInstance().GetModel(); | |||||
| ModelStore::GetInstance().StoreModelByIterNum(iteration_num_, model); | |||||
| for (auto &round : rounds_) { | |||||
| round->Reset(); | |||||
| } | |||||
| iteration_num_++; | |||||
| LocalMetaStore::GetInstance().set_curr_iter_num(iteration_num_); | |||||
| MS_LOG(INFO) << "Proceed to next iteration:" << iteration_num_ << "\n"; | |||||
| } | |||||
| const std::vector<std::shared_ptr<Round>> &Iteration::rounds() { return rounds_; } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,58 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||||
| #include <memory> | |||||
| #include <vector> | |||||
| #include "ps/core/communicator/communicator_base.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/round.h" | |||||
| #include "ps/server/local_meta_store.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // In server's logic, Iteration is the minimum execution unit. For each execution, it consists of multiple kinds of | |||||
| // Rounds, only after all the rounds are finished, this iteration is considered as completed. | |||||
| class Iteration { | |||||
| public: | |||||
| Iteration(); | |||||
| ~Iteration() = default; | |||||
| // Add a round for the iteration. This method will be called multiple times for each round. | |||||
| void AddRound(const std::shared_ptr<Round> &round); | |||||
| // Initialize all the rounds in the iteration. | |||||
| void InitRounds(const std::vector<std::shared_ptr<core::CommunicatorBase>> &communicators, | |||||
| const TimeOutCb &timeout_cb, const FinishIterCb &finish_iteration_cb); | |||||
| // The server proceeds to the next iteration only after the last iteration finishes. | |||||
| void ProceedToNextIter(); | |||||
| const std::vector<std::shared_ptr<Round>> &rounds(); | |||||
| private: | |||||
| std::vector<std::shared_ptr<Round>> rounds_; | |||||
| // Server's current iteration number. | |||||
| size_t iteration_num_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_ITERATION_H_ | |||||
| @@ -0,0 +1,127 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/kernel/round/round_kernel.h" | |||||
| #include <mutex> | |||||
| #include <queue> | |||||
| #include <chrono> | |||||
| #include <thread> | |||||
| #include <utility> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| RoundKernel::RoundKernel() : name_(""), current_count_(0), required_count_(0), error_reason_("") { | |||||
| release_thread_ = std::thread([&]() { | |||||
| while (true) { | |||||
| std::unique_lock<std::mutex> release_lock(release_mtx_); | |||||
| // Detect whether there's any data needs to be released every 100 milliseconds. | |||||
| if (heap_data_to_release_.empty()) { | |||||
| release_lock.unlock(); | |||||
| std::this_thread::sleep_for(std::chrono::milliseconds(100)); | |||||
| continue; | |||||
| } | |||||
| AddressPtr addr_ptr = heap_data_to_release_.front(); | |||||
| heap_data_to_release_.pop(); | |||||
| release_lock.unlock(); | |||||
| std::unique_lock<std::mutex> heap_data_lock(heap_data_mtx_); | |||||
| if (heap_data_.count(addr_ptr) == 0) { | |||||
| MS_LOG(ERROR) << "The data is not stored."; | |||||
| continue; | |||||
| } | |||||
| // Manually release unique_ptr data. | |||||
| heap_data_[addr_ptr].reset(nullptr); | |||||
| heap_data_.erase(heap_data_.find(addr_ptr)); | |||||
| } | |||||
| }); | |||||
| release_thread_.detach(); | |||||
| } | |||||
| void RoundKernel::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; } | |||||
| void RoundKernel::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { return; } | |||||
| void RoundKernel::StopTimer() { | |||||
| if (stop_timer_cb_) { | |||||
| stop_timer_cb_(); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void RoundKernel::FinishIteration() { | |||||
| if (finish_iteration_cb_) { | |||||
| finish_iteration_cb_(); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void RoundKernel::Release(AddressPtr addr_ptr) { | |||||
| if (addr_ptr == nullptr) { | |||||
| MS_LOG(ERROR) << "Data to be released is empty."; | |||||
| return; | |||||
| } | |||||
| std::unique_lock<std::mutex> lock(release_mtx_); | |||||
| heap_data_to_release_.push(addr_ptr); | |||||
| return; | |||||
| } | |||||
| void RoundKernel::set_name(const std::string &name) { name_ = name; } | |||||
| void RoundKernel::set_stop_timer_cb(StopTimerCb timer_stopper) { stop_timer_cb_ = timer_stopper; } | |||||
| void RoundKernel::set_finish_iteration_cb(FinishIterCb finish_iteration_cb) { | |||||
| finish_iteration_cb_ = finish_iteration_cb; | |||||
| } | |||||
| void RoundKernel::GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len) { | |||||
| if (data == nullptr) { | |||||
| MS_LOG(ERROR) << "The data is nullptr."; | |||||
| return; | |||||
| } | |||||
| if (outputs.empty()) { | |||||
| MS_LOG(ERROR) << "Generating output failed. Outputs size is empty."; | |||||
| return; | |||||
| } | |||||
| std::unique_ptr<unsigned char[]> output_data = std::make_unique<unsigned char[]>(len); | |||||
| if (output_data == nullptr) { | |||||
| MS_LOG(ERROR) << "Output data is nullptr."; | |||||
| return; | |||||
| } | |||||
| size_t dst_size = len; | |||||
| int ret = memcpy_s(output_data.get(), dst_size, data, len); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||||
| return; | |||||
| } | |||||
| outputs[0]->addr = output_data.get(); | |||||
| outputs[0]->size = len; | |||||
| std::unique_lock<std::mutex> lock(heap_data_mtx_); | |||||
| heap_data_.insert(std::make_pair(outputs[0], std::move(output_data))); | |||||
| return; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <mutex> | |||||
| #include <queue> | |||||
| #include <utility> | |||||
| #include <chrono> | |||||
| #include <thread> | |||||
| #include <unordered_map> | |||||
| #include "backend/kernel_compiler/common_utils.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/local_meta_store.h" | |||||
| #include "ps/server/distributed_count_service.h" | |||||
| #include "ps/server/distributed_metadata_store.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| // RoundKernel contains the main logic of server handling messages from workers. One iteration has multiple round | |||||
| // kernels to represent the process. They receive and parse messages from the server communication module. After | |||||
| // handling these messages, round kernels allocate response data and send it back. | |||||
| // For example, the main process of federated learning is: | |||||
| // startFLJob round->updateModel round->getModel round. | |||||
| class RoundKernel : virtual public CPUKernel { | |||||
| public: | |||||
| RoundKernel(); | |||||
| virtual ~RoundKernel() = default; | |||||
| // RoundKernel doesn't use InitKernel method of base class CPUKernel to initialize. So implementation of this | |||||
| // inherited method is empty. | |||||
| void InitKernel(const CNodePtr &kernel_node) override {} | |||||
| // Initialize RoundKernel with threshold_count which means that for every iteration, this round needs threshold_count | |||||
| // messages. | |||||
| virtual void InitKernel(size_t threshold_count) = 0; | |||||
| // Launch the round kernel logic to handle the message passed by the communication module. | |||||
| virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) = 0; | |||||
| // The callbacks when first message and last message for this round kernel is received. | |||||
| // These methods is called by class DistributedCountService and triggered by leader server(Rank 0). | |||||
| // virtual void OnFirstCountEvent(std::shared_ptr<core::MessageHandler> message); | |||||
| // virtual void OnLastCnt(std::shared_ptr<core::MessageHandler> message); | |||||
| // Some rounds could be stateful in a iteration. Reset method resets the status of this round. | |||||
| virtual bool Reset() = 0; | |||||
| // The counter event handlers for DistributedCountService. | |||||
| virtual void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||||
| virtual void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Called when this round is finished. This round timer's Stop method will be called. | |||||
| void StopTimer(); | |||||
| // Called after this iteration(including all rounds) is finished. All rounds' Reset method will | |||||
| // be called. | |||||
| void FinishIteration(); | |||||
| // Release the response data allocated inside the round kernel. | |||||
| // Server framework must call this after the response data is sent back. | |||||
| void Release(AddressPtr addr_ptr); | |||||
| // Set round kernel name, which could be used in round kernel's methods. | |||||
| void set_name(const std::string &name); | |||||
| // Set callbacks to be called under certain triggered conditions. | |||||
| void set_stop_timer_cb(StopTimerCb timer_stopper); | |||||
| void set_finish_iteration_cb(FinishIterCb finish_iteration_cb); | |||||
| protected: | |||||
| // Generating response data of this round. The data is allocated on the heap to ensure it's not released before sent | |||||
| // back to worker. | |||||
| void GenerateOutput(const std::vector<AddressPtr> &outputs, void *data, size_t len); | |||||
| // Round kernel's name. | |||||
| std::string name_; | |||||
| // The current received message count for this round in this iteration. | |||||
| size_t current_count_; | |||||
| // The required received message count for this round in one iteration. | |||||
| size_t required_count_; | |||||
| // The reason causes the error in this round kernel. | |||||
| std::string error_reason_; | |||||
| StopTimerCb stop_timer_cb_; | |||||
| FinishIterCb finish_iteration_cb_; | |||||
| // Members below are used for allocating and releasing response data on the heap. | |||||
| // To ensure the performance, we use another thread to release data on the heap. So the operation on the data should | |||||
| // be threadsafe. | |||||
| std::thread release_thread_; | |||||
| // Data needs to be released and its mutex; | |||||
| std::mutex release_mtx_; | |||||
| std::queue<AddressPtr> heap_data_to_release_; | |||||
| std::mutex heap_data_mtx_; | |||||
| std::unordered_map<AddressPtr, std::unique_ptr<unsigned char[]>> heap_data_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_H_ | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| RoundKernelFactory &RoundKernelFactory::GetInstance() { | |||||
| static RoundKernelFactory instance; | |||||
| return instance; | |||||
| } | |||||
| void RoundKernelFactory::Register(const std::string &name, RoundKernelCreator &&creator) { | |||||
| name_to_creator_map_[name] = creator; | |||||
| } | |||||
| std::shared_ptr<RoundKernel> RoundKernelFactory::Create(const std::string &name) { | |||||
| if (name_to_creator_map_.count(name) == 0) { | |||||
| MS_LOG(ERROR) << "Round kernel " << name << " is not registered."; | |||||
| return nullptr; | |||||
| } | |||||
| auto kernel = name_to_creator_map_[name](); | |||||
| kernel->set_name(name); | |||||
| return kernel; | |||||
| } | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,62 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <unordered_map> | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/kernel/round/round_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| using RoundKernelCreator = std::function<std::shared_ptr<RoundKernel>()>; | |||||
| // Kernel factory of round kernels. | |||||
| class RoundKernelFactory { | |||||
| public: | |||||
| static RoundKernelFactory &GetInstance(); | |||||
| void Register(const std::string &name, RoundKernelCreator &&creator); | |||||
| std::shared_ptr<RoundKernel> Create(const std::string &name); | |||||
| private: | |||||
| RoundKernelFactory() = default; | |||||
| ~RoundKernelFactory() = default; | |||||
| RoundKernelFactory(const RoundKernelFactory &) = delete; | |||||
| RoundKernelFactory &operator=(const RoundKernelFactory &) = delete; | |||||
| std::unordered_map<std::string, RoundKernelCreator> name_to_creator_map_; | |||||
| }; | |||||
| class RoundKernelRegister { | |||||
| public: | |||||
| RoundKernelRegister(const std::string &name, RoundKernelCreator &&creator) { | |||||
| RoundKernelFactory::GetInstance().Register(name, std::move(creator)); | |||||
| } | |||||
| }; | |||||
// Registers CLASS as the round kernel named NAME (NAME is stringified and used as the factory key).
// CLASS must derive from RoundKernel; the static_assert names the offending class when it does not.
// The macro already ends with a semicolon, so it is used without a trailing one.
#define REG_ROUND_KERNEL(NAME, CLASS)                                                                    \
  static_assert(std::is_base_of<RoundKernel, CLASS>::value, #CLASS " must be derived from RoundKernel"); \
  static const RoundKernelRegister g_##NAME##_round_kernel_reg(#NAME, []() { return std::make_shared<CLASS>(); });
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_ROUND_ROUND_KERNEL_FACTORY_H_ | |||||
| @@ -0,0 +1,192 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/kernel/round/start_fl_job_kernel.h" | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| void StartFLJobKernel::InitKernel(size_t) { | |||||
| if (LocalMetaStore::GetInstance().has_value(kCtxTotalTimeoutDuration)) { | |||||
| iteration_time_window_ = LocalMetaStore::GetInstance().value<size_t>(kCtxTotalTimeoutDuration); | |||||
| } | |||||
| executor_ = &Executor::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(executor_); | |||||
| if (!executor_->initialized()) { | |||||
| MS_LOG(EXCEPTION) << "Executor must be initialized in server pipeline."; | |||||
| return; | |||||
| } | |||||
| PBMetadata devices_metas; | |||||
| DistributedMetadataStore::GetInstance().RegisterMetadata(kCtxDeviceMetas, devices_metas); | |||||
| return; | |||||
| } | |||||
| bool StartFLJobKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) { | |||||
| MS_LOG(INFO) << "Launching StartFLJobKernel kernel."; | |||||
| if (inputs.size() != 1 || outputs.size() != 1) { | |||||
| MS_LOG(ERROR) << "inputs or outputs size is invalid."; | |||||
| return false; | |||||
| } | |||||
| void *req_data = inputs[0]->addr; | |||||
| const std::shared_ptr<FBBuilder> &fbb = std::make_shared<FBBuilder>(); | |||||
| if (fbb == nullptr || req_data == nullptr) { | |||||
| MS_LOG(ERROR) << "FBBuilder builder or req_data is nullptr."; | |||||
| return false; | |||||
| } | |||||
| if (ReachThresholdForStartFLJob(fbb)) { | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| const schema::RequestFLJob *start_fl_job_req = flatbuffers::GetRoot<schema::RequestFLJob>(req_data); | |||||
| DeviceMeta device_meta = CreateDeviceMetadata(start_fl_job_req); | |||||
| if (!ReadyForStartFLJob(fbb, device_meta)) { | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| // If calling ReportCount before ReadyForStartFLJob, the result will be inconsistent if the device is not selected. | |||||
| if (!CountForStartFLJob(fbb, start_fl_job_req)) { | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return false; | |||||
| } | |||||
| StartFLJob(fbb, device_meta); | |||||
| GenerateOutput(outputs, fbb->GetBufferPointer(), fbb->GetSize()); | |||||
| return true; | |||||
| } | |||||
| bool StartFLJobKernel::Reset() { | |||||
| MS_LOG(INFO) << "Starting fl job kernel reset!"; | |||||
| StopTimer(); | |||||
| DistributedCountService::GetInstance().ResetCounter(name_); | |||||
| DistributedMetadataStore::GetInstance().ResetMetadata(kCtxDeviceMetas); | |||||
| return true; | |||||
| } | |||||
| bool StartFLJobKernel::ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb) { | |||||
| if (DistributedCountService::GetInstance().CountReachThreshold(name_)) { | |||||
| std::string reason = "Current amount for startFLJob has reached the threshold. Please startFLJob later."; | |||||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| DeviceMeta StartFLJobKernel::CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req) { | |||||
| std::string fl_name = start_fl_job_req->fl_name()->str(); | |||||
| std::string fl_id = start_fl_job_req->fl_id()->str(); | |||||
| int data_size = start_fl_job_req->data_size(); | |||||
| MS_LOG(INFO) << "DeviceMeta fl_name:" << fl_name << ", fl_id:" << fl_id << ", data_size:" << data_size; | |||||
| DeviceMeta device_meta; | |||||
| device_meta.set_fl_name(fl_name); | |||||
| device_meta.set_fl_id(fl_id); | |||||
| device_meta.set_data_size(data_size); | |||||
| return device_meta; | |||||
| } | |||||
| bool StartFLJobKernel::ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||||
| bool ret = true; | |||||
| std::string reason = ""; | |||||
| if (device_meta.data_size() < 1) { | |||||
| reason = "FL job data size is not enough."; | |||||
| ret = false; | |||||
| } | |||||
| if (!ret) { | |||||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_NotSelected, reason, false, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| bool StartFLJobKernel::CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, | |||||
| const schema::RequestFLJob *start_fl_job_req) { | |||||
| if (!DistributedCountService::GetInstance().Count(name_, start_fl_job_req->fl_id()->str())) { | |||||
| std::string reason = "startFLJob counting failed."; | |||||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_OutOfTime, reason, false, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_)); | |||||
| MS_LOG(ERROR) << reason; | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| void StartFLJobKernel::StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta) { | |||||
| PBMetadata metadata; | |||||
| *metadata.mutable_device_meta() = device_meta; | |||||
| DistributedMetadataStore::GetInstance().UpdateMetadata(kCtxDeviceMetas, metadata); | |||||
| std::map<std::string, AddressPtr> feature_maps = executor_->GetModel(); | |||||
| BuildStartFLJobRsp(fbb, schema::ResponseCode_SUCCEED, "success", true, | |||||
| std::to_string(CURRENT_TIME_MILLI.count() + iteration_time_window_), feature_maps); | |||||
| return; | |||||
| } | |||||
| void StartFLJobKernel::BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||||
| const std::string &reason, const bool is_selected, | |||||
| const std::string &next_req_time, | |||||
| std::map<std::string, AddressPtr> feature_maps) { | |||||
| auto fbs_reason = fbb->CreateString(reason); | |||||
| auto fbs_next_req_time = fbb->CreateString(next_req_time); | |||||
| auto fbs_fl_name = fbb->CreateString(PSContext::instance()->fl_name()); | |||||
| schema::FLPlanBuilder fl_plan_builder(*(fbb.get())); | |||||
| fl_plan_builder.add_fl_name(fbs_fl_name); | |||||
| fl_plan_builder.add_iterations(PSContext::instance()->fl_iteration_num()); | |||||
| fl_plan_builder.add_epochs(PSContext::instance()->client_epoch_num()); | |||||
| fl_plan_builder.add_mini_batch(PSContext::instance()->client_batch_size()); | |||||
| auto fbs_fl_plan = fl_plan_builder.Finish(); | |||||
| std::vector<flatbuffers::Offset<schema::FeatureMap>> fbs_feature_maps; | |||||
| for (auto feature_map : feature_maps) { | |||||
| auto fbs_weight_fullname = fbb->CreateString(feature_map.first); | |||||
| auto fbs_weight_data = | |||||
| fbb->CreateVector(reinterpret_cast<float *>(feature_map.second->addr), feature_map.second->size / sizeof(float)); | |||||
| auto fbs_feature_map = schema::CreateFeatureMap(*(fbb.get()), fbs_weight_fullname, fbs_weight_data); | |||||
| fbs_feature_maps.push_back(fbs_feature_map); | |||||
| } | |||||
| auto fbs_feature_maps_vector = fbb->CreateVector(fbs_feature_maps); | |||||
| schema::ResponseFLJobBuilder rsp_fl_job_builder(*(fbb.get())); | |||||
| rsp_fl_job_builder.add_retcode(retcode); | |||||
| rsp_fl_job_builder.add_reason(fbs_reason); | |||||
| rsp_fl_job_builder.add_iteration(LocalMetaStore::GetInstance().curr_iter_num()); | |||||
| rsp_fl_job_builder.add_is_selected(is_selected); | |||||
| rsp_fl_job_builder.add_next_req_time(fbs_next_req_time); | |||||
| rsp_fl_job_builder.add_fl_plan_config(fbs_fl_plan); | |||||
| rsp_fl_job_builder.add_feature_map(fbs_feature_maps_vector); | |||||
| auto rsp_fl_job = rsp_fl_job_builder.Finish(); | |||||
| fbb->Finish(rsp_fl_job); | |||||
| return; | |||||
| } | |||||
| REG_ROUND_KERNEL(startFLJob, StartFLJobKernel) | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,74 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/executor.h" | |||||
| #include "ps/server/kernel/round/round_kernel.h" | |||||
| #include "ps/server/kernel/round/round_kernel_factory.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| namespace kernel { | |||||
| class StartFLJobKernel : public RoundKernel { | |||||
| public: | |||||
| StartFLJobKernel() = default; | |||||
| ~StartFLJobKernel() override = default; | |||||
| void InitKernel(size_t threshold_count) override; | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs) override; | |||||
| bool Reset() override; | |||||
| private: | |||||
| // Returns whether the startFLJob count of this iteration has reached the threshold. | |||||
| bool ReachThresholdForStartFLJob(const std::shared_ptr<FBBuilder> &fbb); | |||||
| // The metadata of device will be stored and queried in updateModel round. | |||||
| DeviceMeta CreateDeviceMetadata(const schema::RequestFLJob *start_fl_job_req); | |||||
| // Returns whether the request is valid for startFLJob.For now, the condition is simple. We will add more conditions | |||||
| // to device in later versions. | |||||
| bool ReadyForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta); | |||||
| // Distributed count service counts for startFLJob. | |||||
| bool CountForStartFLJob(const std::shared_ptr<FBBuilder> &fbb, const schema::RequestFLJob *start_fl_job_req); | |||||
| void StartFLJob(const std::shared_ptr<FBBuilder> &fbb, const DeviceMeta &device_meta); | |||||
| // Build response for startFLJob round no matter success or failure. | |||||
| void BuildStartFLJobRsp(const std::shared_ptr<FBBuilder> &fbb, const schema::ResponseCode retcode, | |||||
| const std::string &reason, const bool is_selected, const std::string &next_req_time, | |||||
| std::map<std::string, AddressPtr> feature_maps = {}); | |||||
| // The executor is for getting the initial model for startFLJob request. | |||||
| Executor *executor_; | |||||
| // The time window of one iteration. | |||||
| size_t iteration_time_window_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_KERNEL_START_FL_JOB_KERNEL_H_ | |||||
| @@ -14,30 +14,29 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "ps/server/local_meta_storage.h" | |||||
| #include <string> | |||||
| #include "ps/server/local_meta_store.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace server { | namespace server { | ||||
| void LocalMetaStorage::remove_value(const std::string &name) { | |||||
| void LocalMetaStore::remove_value(const std::string &name) { | |||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| if (key_to_meta_.count(name) != 0) { | if (key_to_meta_.count(name) != 0) { | ||||
| key_to_meta_.erase(key_to_meta_.find(name)); | key_to_meta_.erase(key_to_meta_.find(name)); | ||||
| } | } | ||||
| } | } | ||||
| bool LocalMetaStorage::has_value(const std::string &name) { | |||||
| bool LocalMetaStore::has_value(const std::string &name) { | |||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| return key_to_meta_.count(name) != 0; | return key_to_meta_.count(name) != 0; | ||||
| } | } | ||||
| void LocalMetaStorage::set_curr_iter_num(size_t num) { | |||||
| void LocalMetaStore::set_curr_iter_num(size_t num) { | |||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| curr_iter_num_ = num; | curr_iter_num_ = num; | ||||
| } | } | ||||
| const size_t LocalMetaStorage::curr_iter_num() { | |||||
| const size_t LocalMetaStore::curr_iter_num() { | |||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| return curr_iter_num_; | return curr_iter_num_; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||||
| #include <any> | #include <any> | ||||
| #include <mutex> | #include <mutex> | ||||
| @@ -26,13 +26,13 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace server { | namespace server { | ||||
| // LocalMetaStorage class is used for metadata storage of this server process. | |||||
| // LocalMetaStore class is used for metadata storage of this server process. | |||||
| // For example, the current iteration number, time windows for round kernels, etc. | // For example, the current iteration number, time windows for round kernels, etc. | ||||
| // LocalMetaStorage is threadsafe. | |||||
| class LocalMetaStorage { | |||||
| // LocalMetaStore is threadsafe. | |||||
| class LocalMetaStore { | |||||
| public: | public: | ||||
| static LocalMetaStorage &GetInstance() { | |||||
| static LocalMetaStorage instance; | |||||
| static LocalMetaStore &GetInstance() { | |||||
| static LocalMetaStore instance; | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ class LocalMetaStorage { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| const T &value(const std::string &name) { | |||||
| T value(const std::string &name) { | |||||
| std::unique_lock<std::mutex> lock(mtx_); | std::unique_lock<std::mutex> lock(mtx_); | ||||
| try { | try { | ||||
| T value = std::any_cast<T>(key_to_meta_[name]); | T value = std::any_cast<T>(key_to_meta_[name]); | ||||
| @@ -71,10 +71,10 @@ class LocalMetaStorage { | |||||
| const size_t curr_iter_num(); | const size_t curr_iter_num(); | ||||
| private: | private: | ||||
| LocalMetaStorage() = default; | |||||
| ~LocalMetaStorage() = default; | |||||
| LocalMetaStorage(const LocalMetaStorage &) = delete; | |||||
| LocalMetaStorage &operator=(const LocalMetaStorage &) = delete; | |||||
| LocalMetaStore() = default; | |||||
| ~LocalMetaStore() = default; | |||||
| LocalMetaStore(const LocalMetaStore &) = delete; | |||||
| LocalMetaStore &operator=(const LocalMetaStore &) = delete; | |||||
| // key_to_meta_ stores metadata with key-value format. | // key_to_meta_ stores metadata with key-value format. | ||||
| std::unordered_map<std::string, std::any> key_to_meta_; | std::unordered_map<std::string, std::any> key_to_meta_; | ||||
| @@ -85,4 +85,4 @@ class LocalMetaStorage { | |||||
| } // namespace server | } // namespace server | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORAGE_H_ | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_LOCAL_META_STORE_H_ | |||||
| @@ -0,0 +1,144 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/model_store.h" | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "ps/server/executor.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| void ModelStore::Init(uint32_t max_count) { | |||||
| if (!Executor::GetInstance().initialized()) { | |||||
| MS_LOG(EXCEPTION) << "Server's executor must be initialized before model storage."; | |||||
| return; | |||||
| } | |||||
| max_model_count_ = max_count; | |||||
| iteration_to_model_[kInitIterationNum] = AssignNewModelMemory(); | |||||
| model_size_ = ComputeModelSize(); | |||||
| } | |||||
| bool ModelStore::StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &new_model) { | |||||
| if (iteration_to_model_.count(iteration) != 0) { | |||||
| MS_LOG(WARNING) << "Model for iteration " << iteration << " is already stored"; | |||||
| return false; | |||||
| } | |||||
| if (new_model.empty()) { | |||||
| MS_LOG(ERROR) << "Model feature map is empty."; | |||||
| return false; | |||||
| } | |||||
| std::shared_ptr<MemoryRegister> memory_register; | |||||
| if (iteration_to_model_.size() < max_model_count_) { | |||||
| // If iteration_to_model_.size() is not max_model_count_, need to assign new memory for the model. | |||||
| memory_register = AssignNewModelMemory(); | |||||
| if (memory_register == nullptr) { | |||||
| MS_LOG(ERROR) << "Memory for the new model is nullptr."; | |||||
| return false; | |||||
| } | |||||
| iteration_to_model_[iteration] = memory_register; | |||||
| } else { | |||||
| // If iteration_to_model_ size is already max_model_count_, we need to replace earliest model with the newest model. | |||||
| memory_register = iteration_to_model_.begin()->second; | |||||
| if (memory_register == nullptr) { | |||||
| MS_LOG(ERROR) << "Earliest model is nullptr."; | |||||
| return false; | |||||
| } | |||||
| iteration_to_model_.erase(iteration_to_model_.begin()); | |||||
| } | |||||
| // Copy new model data to the the stored model. | |||||
| auto &stored_model = memory_register->addresses(); | |||||
| for (const auto &weight : new_model) { | |||||
| const std::string &weight_name = weight.first; | |||||
| if (stored_model.count(weight_name) != 0) { | |||||
| MS_LOG(ERROR) << "The stored model has no weight " << weight_name; | |||||
| continue; | |||||
| } | |||||
| void *dst_addr = stored_model[weight_name]->addr; | |||||
| size_t dst_size = stored_model[weight_name]->size; | |||||
| void *src_addr = weight.second->addr; | |||||
| size_t src_size = weight.second->size; | |||||
| int ret = memcpy_s(dst_addr, dst_size, src_addr, src_size); | |||||
| if (ret != 0) { | |||||
| MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")"; | |||||
| return false; | |||||
| } | |||||
| } | |||||
| iteration_to_model_[iteration] = memory_register; | |||||
| return true; | |||||
| } | |||||
| std::map<std::string, AddressPtr> ModelStore::GetModelByIterNum(size_t iteration) { | |||||
| std::map<std::string, AddressPtr> model = {}; | |||||
| if (iteration_to_model_.count(iteration) == 0) { | |||||
| MS_LOG(ERROR) << "Model for iteration " << iteration << " is not stored."; | |||||
| return model; | |||||
| } | |||||
| model = iteration_to_model_[iteration]->addresses(); | |||||
| return model; | |||||
| } | |||||
| const std::map<size_t, std::shared_ptr<MemoryRegister>> &ModelStore::iteration_to_model() const { | |||||
| return iteration_to_model_; | |||||
| } | |||||
| size_t ModelStore::model_size() const { return model_size_; } | |||||
| std::shared_ptr<MemoryRegister> ModelStore::AssignNewModelMemory() { | |||||
| std::map<std::string, AddressPtr> model = Executor::GetInstance().GetModel(); | |||||
| if (model.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Model feature map is empty."; | |||||
| return nullptr; | |||||
| } | |||||
| // Assign new memory for the model. | |||||
| std::shared_ptr<MemoryRegister> memory_register = std::make_shared<MemoryRegister>(); | |||||
| for (const auto &weight : model) { | |||||
| const std::string weight_name = weight.first; | |||||
| size_t weight_size = weight.second->size; | |||||
| auto weight_data = std::make_unique<char[]>(weight_size); | |||||
| if (weight_data == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Assign memory for weight failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memory_register->RegisterArray(weight_name, &weight_data, weight_size); | |||||
| } | |||||
| return memory_register; | |||||
| } | |||||
| size_t ModelStore::ComputeModelSize() { | |||||
| if (iteration_to_model_.empty()) { | |||||
| MS_LOG(EXCEPTION) << "Calculating model size failed: model for iteration 0 is not stored yet. "; | |||||
| return 0; | |||||
| } | |||||
| const auto &model = iteration_to_model_[kInitIterationNum]; | |||||
| MS_EXCEPTION_IF_NULL(model); | |||||
| size_t model_size = std::accumulate(model->addresses().begin(), model->addresses().end(), static_cast<size_t>(0), | |||||
| [](size_t s, const auto &weight) { return s + weight.second->size; }); | |||||
| MS_LOG(INFO) << "Model size in byte is " << model_size; | |||||
| return model_size; | |||||
| } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,78 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/memory_register.h" | |||||
| #include "ps/server/executor.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // The initial iteration number is 0 in server. | |||||
| constexpr size_t kInitIterationNum = 0; | |||||
| // Server framework use ModelStore to store and query models. | |||||
| // ModelStore stores multiple models because worker could get models of the previous iterations. | |||||
| class ModelStore { | |||||
| public: | |||||
| static ModelStore &GetInstance() { | |||||
| static ModelStore instance; | |||||
| return instance; | |||||
| } | |||||
| // Initialize ModelStore with max count of models need to be stored. | |||||
| void Init(uint32_t max_count = 3); | |||||
| // Store the model of the given iteration. The model is acquired from Executor. If the current model count is already | |||||
| // max_model_count_, the earliest model will be replaced. | |||||
| bool StoreModelByIterNum(size_t iteration, const std::map<std::string, AddressPtr> &model); | |||||
| // Get model of the given iteration. | |||||
| std::map<std::string, AddressPtr> GetModelByIterNum(size_t iteration); | |||||
| // Returns all models stored in ModelStore. | |||||
| const std::map<size_t, std::shared_ptr<MemoryRegister>> &iteration_to_model() const; | |||||
| // Returns the model size, which could be calculated at the initializing phase. | |||||
| size_t model_size() const; | |||||
| private: | |||||
| ModelStore() : max_model_count_(0), model_size_(0), iteration_to_model_({}) {} | |||||
| ~ModelStore() = default; | |||||
| ModelStore(const ModelStore &) = delete; | |||||
| ModelStore &operator=(const ModelStore &) = delete; | |||||
| // To store multiple models, new memory must assigned. The max memory size assigned for models is max_model_count_ * | |||||
| // model_size_. | |||||
| std::shared_ptr<MemoryRegister> AssignNewModelMemory(); | |||||
| // Calculate the model size. This method should be called after iteration_to_model_ is initialized. | |||||
| size_t ComputeModelSize(); | |||||
| size_t max_model_count_; | |||||
| size_t model_size_; | |||||
| std::map<size_t, std::shared_ptr<MemoryRegister>> iteration_to_model_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_MODEL_STORE_H_ | |||||
| @@ -25,15 +25,15 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| namespace server { | namespace server { | ||||
| bool ParameterAggregator::Init(const CNodePtr &cnode, size_t required_count) { | |||||
| bool ParameterAggregator::Init(const CNodePtr &cnode, size_t threshold_count) { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| memory_register_ = std::make_shared<MemoryRegister>(); | memory_register_ = std::make_shared<MemoryRegister>(); | ||||
| MS_EXCEPTION_IF_NULL(memory_register_); | MS_EXCEPTION_IF_NULL(memory_register_); | ||||
| required_push_count_ = required_count; | |||||
| required_push_count_ = threshold_count; | |||||
| // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. | // The required_pull_count_ is the count for Pull, which should be the same as required_push_count_. | ||||
| // required_pull_count_ normally used in parameter server training mode. | // required_pull_count_ normally used in parameter server training mode. | ||||
| required_pull_count_ = required_count; | |||||
| required_pull_count_ = threshold_count; | |||||
| MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); | MS_LOG(DEBUG) << "Start initializing kernels for " << AnfAlgo::GetCNodeName(cnode); | ||||
| InitAggregationKernels(cnode); | InitAggregationKernels(cnode); | ||||
| @@ -61,8 +61,8 @@ class ParameterAggregator { | |||||
| ~ParameterAggregator() = default; | ~ParameterAggregator() = default; | ||||
| // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now. | // Initialize ParameterAggregator with a cnode. This cnode is normally a optimizer kernel for now. | ||||
| // The parameter required_count helps ParameterAggregator to judge the current status if it's stateful. | |||||
| bool Init(const CNodePtr &cnode, size_t required_count = 0); | |||||
| // The parameter threshold_count helps ParameterAggregator to judge the current status if it's stateful. | |||||
| bool Init(const CNodePtr &cnode, size_t threshold_count = 0); | |||||
| // Update old data stored in ParameterAggregator with new data. | // Update old data stored in ParameterAggregator with new data. | ||||
| // The data could have many meanings: weights, gradients, learning_rate, momentum, etc. | // The data could have many meanings: weights, gradients, learning_rate, momentum, etc. | ||||
| @@ -0,0 +1,139 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/server/round.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count) | |||||
| : name_(name), | |||||
| check_timeout_(check_timeout), | |||||
| time_window_(time_window), | |||||
| check_count_(check_count), | |||||
| threshold_count_(threshold_count) {} | |||||
| void Round::Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||||
| FinishIterCb finish_iteration_cb) { | |||||
| MS_EXCEPTION_IF_NULL(communicator); | |||||
| communicator_ = communicator; | |||||
| // Register callback for round kernel. | |||||
| communicator_->RegisterMsgCallBack( | |||||
| name_, [&](std::shared_ptr<core::MessageHandler> message) { LaunchRoundKernel(message); }); | |||||
| // Callback when the iteration is finished. | |||||
| finish_iteration_cb_ = [this, finish_iteration_cb](void) -> void { | |||||
| MS_LOG(INFO) << "Round " << name_ << " finished! Proceed to next iteration."; | |||||
| finish_iteration_cb(); | |||||
| }; | |||||
| // Callback for finalizing the server. This can only be called once. | |||||
| finalize_cb_ = [&](void) -> void { communicator_->Stop(); }; | |||||
| if (check_timeout_) { | |||||
| iter_timer_ = std::make_shared<IterationTimer>(); | |||||
| // 1.Set the timeout callback for the timer. | |||||
| iter_timer_->SetTimeOutCallBack([this, timeout_cb](void) -> void { | |||||
| MS_LOG(INFO) << "Round " << name_ << " timeout! Proceed to next iteration."; | |||||
| timeout_cb(); | |||||
| }); | |||||
| // 2.Stopping timer callback which will be set to the round kernel. | |||||
| stop_timer_cb_ = [&](void) -> void { | |||||
| MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer."; | |||||
| iter_timer_->Stop(); | |||||
| }; | |||||
| } | |||||
| // Set counter event callbacks for this round if the round kernel is stateful. | |||||
| if (check_count_) { | |||||
| auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1); | |||||
| auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1); | |||||
| DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_, | |||||
| {first_count_handler, last_count_handler}); | |||||
| } | |||||
| } | |||||
| void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| kernel_ = kernel; | |||||
| kernel_->set_stop_timer_cb(stop_timer_cb_); | |||||
| kernel_->set_finish_iteration_cb(finish_iteration_cb_); | |||||
| return; | |||||
| } | |||||
| void Round::LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| if (message == nullptr) { | |||||
| MS_LOG(ERROR) << "Message is nullptr."; | |||||
| return; | |||||
| } | |||||
| AddressPtr input = std::make_shared<Address>(); | |||||
| AddressPtr output = std::make_shared<Address>(); | |||||
| input->addr = message->data(); | |||||
| input->size = message->len(); | |||||
| bool ret = kernel_->Launch({input}, {}, {output}); | |||||
| if (output->size == 0) { | |||||
| std::string reason = "The output of the round " + name_ + " is empty."; | |||||
| MS_LOG(WARNING) << reason; | |||||
| communicator_->SendResponse(reason.c_str(), reason.size(), message); | |||||
| return; | |||||
| } | |||||
| // Must send response back no matter what value Launch method returns. | |||||
| if (!ret) { | |||||
| MS_LOG(WARNING) << "Launching round kernel of round " << name_ << " failed."; | |||||
| } | |||||
| communicator_->SendResponse(output->addr, output->size, message); | |||||
| kernel_->Release(output); | |||||
| return; | |||||
| } | |||||
| void Round::Reset() { kernel_->Reset(); } | |||||
| const std::string &Round::name() const { return name_; } | |||||
| size_t Round::threshold_count() const { return threshold_count_; } | |||||
| size_t Round::time_window() const { return time_window_; } | |||||
| void Round::OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &) { | |||||
| MS_LOG(INFO) << "Round " << name_ << " first count event is triggered."; | |||||
| // The timer starts only after the first count event is triggered by DistributedCountService. | |||||
| if (check_timeout_) { | |||||
| iter_timer_->Start(std::chrono::milliseconds(time_window_)); | |||||
| } | |||||
| return; | |||||
| } | |||||
| void Round::OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message) { | |||||
| MS_LOG(INFO) << "Round " << name_ << " last count event is triggered."; | |||||
| // Same as the first count event, the timer must be stopped by DistributedCountService. | |||||
| if (check_timeout_) { | |||||
| iter_timer_->Stop(); | |||||
| } | |||||
| // Some kernels override the OnLastCountEvent method. | |||||
| kernel_->OnLastCountEvent(message); | |||||
| return; | |||||
| } | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,95 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||||
| #define MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "ps/core/communicator/communicator_base.h" | |||||
| #include "ps/server/common.h" | |||||
| #include "ps/server/iteration_timer.h" | |||||
| #include "ps/server/distributed_count_service.h" | |||||
| #include "ps/server/kernel/round/round_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| namespace server { | |||||
| // Round helps server to handle network round messages and launch round kernels. One iteration in server consists of | |||||
| // multiple rounds like startFLJob, updateModel, Push, Pull, etc. Some round kernels may be stateful because of counting | |||||
| // and timing. So Round helps register counter and timer so that the round kernels only need to focus on the logic. | |||||
| class Round { | |||||
| public: | |||||
| explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000, | |||||
| bool check_count = false, size_t threshold_count = 8); | |||||
| ~Round() = default; | |||||
| void Initialize(const std::shared_ptr<core::CommunicatorBase> &communicator, TimeOutCb timeout_cb, | |||||
| FinishIterCb finish_iteration_cb); | |||||
| // Bind a round kernel to this Round. This method should be called after Initialize. | |||||
| void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel); | |||||
| // This method is the callback which will be set to the communicator and called after the corresponding round message | |||||
| // is sent to the server. | |||||
| void LaunchRoundKernel(const std::shared_ptr<core::MessageHandler> &message); | |||||
| // Round needs to be reset after each iteration is finished or its timer expires. | |||||
| void Reset(); | |||||
| const std::string &name() const; | |||||
| size_t threshold_count() const; | |||||
| size_t time_window() const; | |||||
| private: | |||||
| // The callbacks which will be set to DistributedCounterService. | |||||
| void OnFirstCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||||
| void OnLastCountEvent(const std::shared_ptr<core::MessageHandler> &message); | |||||
| std::string name_; | |||||
| // Whether this round needs to use timer. Most rounds in federated learning with mobile devices scenario need to set | |||||
| // check_timeout_ to true. | |||||
| bool check_timeout_; | |||||
| // The time window duration for this round when check_timeout_ is set to true. | |||||
| size_t time_window_; | |||||
| // If check_count_ is true, it means the round has to do counting for every round message and the first/last count | |||||
| // event will be triggered. | |||||
| bool check_count_; | |||||
| // The threshold count for this round when check_count_ is set to true. The logic of this round has to check whether | |||||
| // the round message count has reached threshold_count_. | |||||
| size_t threshold_count_; | |||||
| std::shared_ptr<core::CommunicatorBase> communicator_; | |||||
| // The round kernel for this Round. | |||||
| std::shared_ptr<kernel::RoundKernel> kernel_; | |||||
| // Some rounds may need timer to eliminate the long tail effect. | |||||
| std::shared_ptr<IterationTimer> iter_timer_; | |||||
| // The callbacks which will be set to the round kernel. | |||||
| StopTimerCb stop_timer_cb_; | |||||
| FinishIterCb finish_iteration_cb_; | |||||
| FinalizeCb finalize_cb_; | |||||
| }; | |||||
| } // namespace server | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_SERVER_ROUND_H_ | |||||
| @@ -31,7 +31,7 @@ | |||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "debug/env_config_parser.h" | #include "debug/env_config_parser.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #endif | #endif | ||||
| @@ -307,7 +307,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| } | } | ||||
| need_alloc_nodes.push_back(item); | need_alloc_nodes.push_back(item); | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| bool ps_cache_check = false; | bool ps_cache_check = false; | ||||
| #endif | #endif | ||||
| for (auto &item : need_alloc_nodes) { | for (auto &item : need_alloc_nodes) { | ||||
| @@ -320,7 +320,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| DeviceAddressPtr device_address = nullptr; | DeviceAddressPtr device_address = nullptr; | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| const std::string ¶m_name = item->fullname_with_scope(); | const std::string ¶m_name = item->fullname_with_scope(); | ||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | if (ps::ps_cache_instance.IsHashTable(param_name)) { | ||||
| MS_LOG(INFO) << "Parameter(" << param_name << ")" | MS_LOG(INFO) << "Parameter(" << param_name << ")" | ||||
| @@ -1038,7 +1038,7 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st | |||||
| return device_address; | return device_address; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, | ||||
| AnfNodePtr *const first_cache_input_index, | AnfNodePtr *const first_cache_input_index, | ||||
| size_t *const first_cache_size) { | size_t *const first_cache_size) { | ||||
| @@ -142,7 +142,7 @@ class KernelRuntime { | |||||
| void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | ||||
| void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | ||||
| DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, | void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *const first_cache_input_index, | ||||
| size_t *const first_cache_size); | size_t *const first_cache_size); | ||||
| void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | ||||
| @@ -16,14 +16,14 @@ | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| void KernelRuntimeManager::ClearRuntimeResource() { | void KernelRuntimeManager::ClearRuntimeResource() { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps::ps_cache_instance.SyncEmbeddingTable(); | ps::ps_cache_instance.SyncEmbeddingTable(); | ||||
| } | } | ||||
| @@ -125,7 +125,7 @@ void KernelRuntimeManager::ReleaseKernelRuntime(const std::string &device_name, | |||||
| if (runtime == nullptr) { | if (runtime == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #if (ENABLE_CPU && !_WIN32) | |||||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | ||||
| ps::ps_cache_instance.SyncEmbeddingTable(); | ps::ps_cache_instance.SyncEmbeddingTable(); | ||||
| } | } | ||||
| @@ -0,0 +1,123 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| namespace mindspore.schema; | |||||
| table CipherPublicParams { | |||||
| t:int; | |||||
| p:[ubyte]; | |||||
| g:int; | |||||
| prime:[ubyte]; | |||||
| dp_eps:float; | |||||
| dp_delta:float; | |||||
| dp_norm_clip:float; | |||||
| encrypt_type:int; | |||||
| } | |||||
| table ClientPublicKeys { | |||||
| fl_id:string; | |||||
| c_pk:[ubyte]; | |||||
| s_pk: [ubyte]; | |||||
| } | |||||
| table ClientShare { | |||||
| fl_id:string; | |||||
| share:[ubyte]; | |||||
| index:int; | |||||
| } | |||||
| table RequestExchangeKeys{ | |||||
| fl_id:string; | |||||
| c_pk:[ubyte]; | |||||
| s_pk:[ubyte]; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ResponseExchangeKeys{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| next_req_time:string; | |||||
| iteration:int; | |||||
| } | |||||
| table GetExchangeKeys{ | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ReturnExchangeKeys{ | |||||
| retcode:int; | |||||
| iteration:int; | |||||
| remote_publickeys:[ClientPublicKeys]; | |||||
| next_req_time:string; | |||||
| } | |||||
| table RequestShareSecrets{ | |||||
| fl_id:string; | |||||
| encrypted_shares:[ClientShare]; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ResponseShareSecrets{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| next_req_time:string; | |||||
| iteration:int; | |||||
| } | |||||
| table GetShareSecrets{ | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ReturnShareSecrets{ | |||||
| retcode:int; | |||||
| iteration:int; | |||||
| encrypted_shares: [ClientShare]; | |||||
| next_req_time:string; | |||||
| } | |||||
| table GetClientList{ | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ReturnClientList{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| clients:[string]; | |||||
| iteration:int; | |||||
| next_req_time:string; | |||||
| } | |||||
| table SendReconstructSecret{ | |||||
| fl_id:string; | |||||
| reconstruct_secret_shares:[ClientShare]; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ReconstructSecret{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| iteration:int; | |||||
| next_req_time:string; | |||||
| } | |||||
| @@ -0,0 +1,159 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| include "cipher.fbs"; | |||||
| namespace mindspore.schema; | |||||
| file_identifier "FLJ0"; | |||||
| file_extension "fl"; | |||||
| enum ResponseCode: int { | |||||
| SUCCEED=200, | |||||
| SucNotReady=201, | |||||
| RepeatRequest=202, | |||||
| SucNotMatch=204, | |||||
| OutOfTime=300, | |||||
| NotSelected=301, | |||||
| RequestError=400, | |||||
| SystemError=500 | |||||
| } | |||||
| enum AggregationType:byte {FedAvg=0, FedAdam = 1, FedAdagrag=2, FedMeta=3, qffl=4} | |||||
| enum Metrics:byte {accuracy = 0, precision = 1, recall = 2, AUC = 3,f1=4, fbeta=5} | |||||
| enum EarlyStopType:byte {loss_diff = 0, loss_abs = 1, weight_diff = 2} | |||||
| table Aggregation { | |||||
| type:AggregationType; | |||||
| weights:[float]; | |||||
| } | |||||
| table EarlyStop { | |||||
| early_stop_type:EarlyStopType; | |||||
| weight:float; | |||||
| rounds:int; | |||||
| } | |||||
| table FeatureMap{ | |||||
| weight_fullname:string; | |||||
| data:[float]; | |||||
| } | |||||
| table RequestFLJob{ | |||||
| fl_name:string; | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| data_size:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ResponseFLJob { | |||||
| retcode:int; | |||||
| reason:string; | |||||
| iteration:int; | |||||
| is_selected:bool = false; | |||||
| next_req_time:string; | |||||
| fl_plan_config:FLPlan; | |||||
| feature_map:[FeatureMap]; | |||||
| timestamp:string; | |||||
| } | |||||
| table FLPlan { | |||||
| fl_name:string; | |||||
| iterations:int; | |||||
| epochs:int; | |||||
| early_stop:EarlyStop; | |||||
| mini_batch:int; | |||||
| shuffle:bool = false; | |||||
| lr:float; | |||||
| aggregation:Aggregation; | |||||
| metrics:[Metrics]; | |||||
| cipher:CipherPublicParams; | |||||
| } | |||||
| table RequestUpdateModel{ | |||||
| fl_name:string; | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| feature_map:[FeatureMap]; | |||||
| timestamp:string; | |||||
| } | |||||
| table ResponseUpdateModel{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| feature_map:[FeatureMap]; | |||||
| next_req_time:string; | |||||
| timestamp:string; | |||||
| } | |||||
| table RequestAsyncUpdateModel{ | |||||
| fl_name:string; | |||||
| fl_id:string; | |||||
| iteration:int; | |||||
| data_size:int; | |||||
| feature_map:[FeatureMap]; | |||||
| } | |||||
| table ResponseAsyncUpdateModel{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| iteration:int; | |||||
| } | |||||
| table RequestOverwriteWeightsByKey{ | |||||
| iteration:int; | |||||
| feature_map:[FeatureMap]; | |||||
| } | |||||
| table ResponseOverwriteWeightsByKey{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| } | |||||
| table RequestGetModel{ | |||||
| fl_name:string; | |||||
| iteration:int; | |||||
| timestamp:string; | |||||
| } | |||||
| table ResponseGetModel{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| iteration:int; | |||||
| feature_map:[FeatureMap]; | |||||
| timestamp:string; | |||||
| } | |||||
| table RequestAsyncGetModel{ | |||||
| fl_name:string; | |||||
| iteration:int; | |||||
| } | |||||
| table ResponseAsyncGetModel{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| iteration:int; | |||||
| feature_map:[FeatureMap]; | |||||
| } | |||||
| table RequestGetWeightsByKey{ | |||||
| iteration:int; | |||||
| weight_names:[string]; | |||||
| } | |||||
| table ResponseGetWeightsByKey{ | |||||
| retcode:int; | |||||
| reason:string; | |||||
| feature_map:[FeatureMap]; | |||||
| } | |||||
| // FeatureMapList refers to the whole trained model. | |||||
| table FeatureMapList { | |||||
| feature_map:[FeatureMap]; | |||||
| } | |||||