diff --git a/cmake/package.cmake b/cmake/package.cmake index 5740bbc08d..d6d5e3a73e 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -194,6 +194,14 @@ if (ENABLE_GPU) ) endif () +if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) + install( + TARGETS ps_cache + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) +endif() + if (ENABLE_SERVING OR ENABLE_TESTCASES) file(GLOB_RECURSE LIBEVENT_LIB_LIST ${libevent_LIBPATH}/libevent* diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 2ec0221826..e48e5a7e80 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -308,7 +308,7 @@ elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") else () if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) - target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core) + target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) if (${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(mindspore ibverbs rdmacm) endif() diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc index 98e84b83b1..8ef799b678 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -75,6 +75,9 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { node_so_ = CUST_AICPU_OPS_SO_NAME; node_name_ = kCustRunApi; + } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { + node_so_ = AICPU_OPS_SO_NAME; + node_name_ = kCustRunApi; } else { node_so_ = AICPU_OPS_SO_NAME; } @@ -161,6 +164,9 @@ std::vector AicpuOpKernelMod::GenTask(const std::vector if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { node_so_ = CUST_AICPU_OPS_SO_NAME; node_name_ = kCustRunApi; + } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { + node_so_ = AICPU_OPS_SO_NAME; + node_name_ = kCustRunApi; } else { node_so_ = AICPU_OPS_SO_NAME; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h index 7bd5974a17..2762a2f87f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h @@ -49,6 +49,7 @@ constexpr auto kIdentity = "Identity"; constexpr auto kUpdateCache = "UpdateCache"; constexpr auto kCustRunApi = "RunCpuKernel"; const std::set kCustAiCpuKernelOps{kEditDistance, kIdentity}; +const std::set kCacheKernelOps{kUpdateCache}; struct AicpuParamHead { uint32_t length; // Total length: include cunstom message diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index 5c5f80e2a7..179c4f0a04 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -15,6 +15,7 @@ */ #include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" #include +#include #include "ps/worker.h" namespace mindspore { @@ -38,10 +39,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); } std::vector keys{key_, key_, key_}; - std::vector values; - values.insert(values.end(), input_shape.begin(), input_shape.end()); - values.insert(values.end(), indices_shape.begin(), indices_shape.end()); - values.insert(values.end(), output_shape.begin(), output_shape.end()); + std::vector values; + std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(values), + [](size_t dim) -> float { return SizeToFloat(dim); }); + std::transform(indices_shape.begin(), indices_shape.end(), std::back_inserter(values), + [](size_t dim) -> float { return SizeToFloat(dim); }); + std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(values), + [](size_t dim) -> float { return SizeToFloat(dim); }); MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; std::vector lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc index c3ec61daa1..97694e452a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -72,6 +72,23 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector &inputs, con return Launch(inputs, workspace, outputs); } +void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, + const float *update_vals, size_t ids_size) { + size_t copy_lens = outer_dim_size_ * sizeof(float); + for (size_t i = 0; i < ids_size; ++i) { + int index = lookup_ids[i] - offset_; + if (index >= 0 && index < SizeToInt(first_dim_size_)) { + auto ret = + memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; + } + } else { + MS_LOG(EXCEPTION) << "UpdateEmbeddings index invalid."; + } + } +} + const std::vector &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } const std::vector &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h index 251aaba5e3..66e27cca99 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -35,7 +35,8 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) override; - + void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, + size_t ids_size) override; const std::vector &input_sizes() const override; const std::vector &output_sizes() const override; const std::vector &workspace_sizes() const override; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h index a2845af863..c36466d580 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -38,7 +38,8 @@ class PServerKernel { virtual void ReInit(const std::vector> &) {} virtual bool Execute(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs) = 0; - + virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, + size_t ids_size) {} virtual const std::vector &input_sizes() const = 0; virtual const std::vector &output_sizes() const = 0; virtual const std::vector &workspace_sizes() const = 0; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index d7e3692698..5881c5f68a 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -56,6 +56,7 @@ #include "toolchain/adx_datadump_server.h" #if ENABLE_CPU && ENABLE_D #include "ps/util.h" +#include "ps/ps_cache/ps_cache_manager.h" #endif namespace mindspore { @@ -487,11 +488,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { // adjust kernel AdjustKernel(root_graph); #if ENABLE_CPU && ENABLE_D - if (ps::Util::IsParamServerMode()) { - CheckPSModeConsistence(root_graph); - // Assign parameter keys. - AssignParamKey(root_graph); - } + InitPsWorker(root_graph); #endif // assign stream AssignStream(NOT_NULL(root_graph)); @@ -568,6 +565,9 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) { } // adjust execution order because merge child graph and other special operations AdjustKernel(graph); +#if ENABLE_CPU && ENABLE_D + InitPsWorker(graph); +#endif // Reorder optimizer order auto execution_order = graph->execution_order(); Reorder(&execution_order); @@ -644,6 +644,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector &kernel_graph, auto input_node = input_nodes[i]; MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { +#if ENABLE_CPU && ENABLE_GPU + const std::string ¶m_name = input_node->fullname_with_scope(); + if (ps::ps_cache_instance.IsHashTable(param_name)) { + continue; + } +#endif auto pk_node = input_node->cast(); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); auto tensor_address = std::dynamic_pointer_cast(tensor->device_address()); @@ -306,16 +313,11 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr HardwareOptimize(graph); // Graph kernel fusion optimization GraphKernelOptimize(graph); - -#if ENABLE_CPU && ENABLE_GPU - if (ps::Util::IsParamServerMode()) { - CheckPSModeConsistence(graph); - // Assign parameter keys. - AssignParamKey(graph); - } -#endif // Start gpu kernel runtime StartKernelRT(); +#if ENABLE_CPU && ENABLE_GPU + InitPsWorker(graph); +#endif // Assign CUDA streams AssignStream(graph); // Dump .pb graph before remove nop nodes @@ -380,6 +382,12 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vectorexecution_order().size(); int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1; for (int64_t i = 0; i < loopsize; i++) { +#if ENABLE_CPU && ENABLE_GPU + std::string channel_name; + if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) { + ps::ps_cache_instance.IncreaseGraphStep(channel_name); + } +#endif Execute(kernel_graph); } // In pynative mode, device addresses of tensors in value nodes need be clean. diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 1b361eee77..4fb20f4d61 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -41,8 +41,10 @@ #include "utils/trace_base.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) -#include "ps/worker.h" +#include "ps/ps_cache/ps_cache_manager.h" +#include "ps/common.h" #include "ps/util.h" +#include "abstract/abstract_value.h" #endif namespace mindspore { @@ -1125,6 +1127,12 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); } if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + const std::string ¶m_name = input_node->fullname_with_scope(); + if (ps::ps_cache_instance.IsHashTable(param_name)) { + continue; + } +#endif auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); MS_EXCEPTION_IF_NULL(device_address); if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size, @@ -1715,8 +1723,64 @@ void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptrexecution_order()) { + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + if (kernel_name == kGetNextOpName) { + auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); + MS_EXCEPTION_IF_NULL(prim); + *channel_name = GetValue(prim->GetAttr("shared_name")); + return true; + } + } + return false; +} + #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) -void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { +void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { + if (!ps::Util::IsRoleOfWorker()) { + return; + } + CheckPSModeConsistence(kernel_graph); + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + if (!ps::ps_cache_instance.initialized_ps_cache()) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto devcie_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(devcie_target, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + auto context = runtime_instance->context(); + const auto &kernels = kernel_graph->execution_order(); + if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") { + GetBatchElements(kernels[0]); + ps::ps_cache_instance.Initialize(); + } + ps::ps_cache_instance.DoProcessData(device_id_, context); + } + } else { + // Assign parameter keys. + AssignParamKey(kernel_graph); + } +} + +void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const { + auto shapes = AnfAlgo::GetNodeAttr>>(kernel_node, "shapes"); + auto types = AnfAlgo::GetNodeAttr>(kernel_node, "types"); + if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) { + MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size " + << types; + } + size_t batch_elements = 1; + const auto &shape = shapes[0]; + for (size_t i = 0; i < shape.size(); ++i) { + batch_elements *= shape[i]; + } + ps::ps_cache_instance.set_batch_elements(batch_elements); +} + +void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const { auto input_nodes = kernel_graph->inputs(); for (const auto &input_node : input_nodes) { if (!input_node->isa()) { @@ -1725,8 +1789,9 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { auto pk_node = input_node->cast(); MS_EXCEPTION_IF_NULL(pk_node); auto param_info_ptr = pk_node->param_info(); - if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) { - const std::string ¶m_name = pk_node->fullname_with_scope(); + const std::string ¶m_name = pk_node->fullname_with_scope(); + if (param_info_ptr != nullptr && param_info_ptr->init_in_server() && + !ps::ps_cache_instance.IsHashTable(param_name)) { MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name << "] in server, this parameter is used by kernel which executes in device"; } @@ -1734,10 +1799,6 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { } void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { - if (!ps::Util::IsRoleOfWorker()) { - MS_LOG(INFO) << "Not parameter server mode."; - return; - } MS_EXCEPTION_IF_NULL(kernel_graph); std::vector node_list = TopoSort(kernel_graph->get_return()); for (auto &node : node_list) { @@ -1775,16 +1836,8 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, return; } std::vector inputs(inputs_const); - size_t input_ctrl_size = 1; MS_EXCEPTION_IF_NULL(kernel_graph); - if (kernel_graph->input_ctrl_tensors()) { - input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); - } auto input_nodes = kernel_graph->inputs(); - if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { - MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() - << ", input_ctrl_size:" << input_ctrl_size; - } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); for (size_t i = 0; i < inputs.size(); ++i) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 1f17a56abb..9d6b64903b 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -99,9 +99,9 @@ class SessionBasic : public std::enable_shared_from_this { // get graph id in child graphs by ME front anf node pointer virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const; virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } - void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph); void AssignParamKey(const KernelGraphPtr &kernel_graph); void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector &inputs_const); + bool IsGetNextGraph(const GraphId &graph_id, std::string *channel_name); virtual bool CheckModelInputs(uint32_t graph_id, const std::vector &inputs, std::string *error_msg) const { return true; @@ -195,6 +195,11 @@ class SessionBasic : public std::enable_shared_from_this { AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list); void UpdateGraphDynamicShapeAttr(const NotNull &root_graph); void UpdateAllGraphDynamicShapeAttr(const std::vector &all_graphs); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; + void GetBatchElements(const AnfNodePtr &kernel_node) const; + void InitPsWorker(const KernelGraphPtr &kernel_graph); +#endif std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; @@ -207,6 +212,9 @@ class SessionBasic : public std::enable_shared_from_this { #if !defined(_WIN32) && !defined(_WIN64) std::shared_ptr debugger_; #endif +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + bool initialized_ps_cache_{false}; +#endif }; using SessionPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index adba2b3bcc..980673ad48 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -24,6 +24,9 @@ #include "frontend/parallel/device_matrix.h" #include "frontend/parallel/graph_util/generate_graph.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#endif namespace mindspore { namespace parallel { @@ -514,6 +517,12 @@ Status GatherV2PInfo::InferBias() { if (repeated_calc_num_ > 1) { rank = rank / repeated_calc_num_; } +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + bias_ = 0; + return SUCCESS; + } +#endif bias_ = rank / params_strategy.at(1) * slice_size_; return SUCCESS; } diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index cfd6f396e1..cb347102f8 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -46,10 +46,18 @@ #include "ir/anf.h" #include "ir/param_info.h" #include "ir/tensor.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "ps/util.h" +#endif namespace mindspore { namespace parallel { bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { + return false; + } +#endif MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 7fe85d60e0..e0e344fecb 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -44,6 +44,9 @@ #include "utils/comm_manager.h" #include "utils/ms_context.h" #include "utils/symbolic.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "ps/util.h" +#endif using mindspore::tensor::Tensor; @@ -3036,6 +3039,11 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) { } bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { + return false; + } +#endif MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index a19c57d196..64461cd175 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -202,6 +202,7 @@ else () if (${ENABLE_IBVERBS} STREQUAL "ON") target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) endif () + target_link_libraries(_c_dataengine PRIVATE ps_cache) endif () endif () diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 594614266f..327942d48a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -322,6 +322,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, con bool profiling, int32_t *push_time) { std::vector items; double start_time; + bool ps_data_prefetch = false; for (int i = 0; i < data_size.size(); i++) { device::DataItemGpu data_item; data_item.data_len_ = data_size[i]; @@ -334,6 +335,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector &data_size, con if (profiling) { start_time = ProfilingTime::GetCurMilliSecond(); } + // Data prefetch only when PS mode enables cache. + if ((!ps_data_prefetch) && (items.size() > 0)) { + ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); + ps_data_prefetch = true; + } BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); if (profiling) { double end_time = ProfilingTime::GetCurMilliSecond(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h index 5099cb20e5..7804e123a3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.h @@ -24,6 +24,7 @@ #include "minddata/dataset/engine/datasetops/pipeline_op.h" #include "minddata/dataset/engine/datasetops/repeat_op.h" #include "minddata/dataset/util/status.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #ifdef ENABLE_TDTQUE #include "minddata/dataset/util/queue.h" diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index bcac339d64..78c8964e86 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -17,6 +17,8 @@ #include "utils/ms_utils.h" #include "minddata/dataset/engine/perf/profiling.h" #include "minddata/dataset/util/log_adapter.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" + namespace mindspore { namespace dataset { static std::shared_ptr instance_ptr_ = nullptr; @@ -48,6 +50,10 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe if (profiling) { start_time = ProfilingTime::GetCurMilliSecond(); } + // Data prefetch only when PS mode enables cache. + if (items.size() > 0) { + ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); + } if (tdt::TdtHostPushData(channel_name, items) != 0) { return FAILED; } diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index a6b677a50a..01b2af52e4 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -308,7 +308,14 @@ PYBIND11_MODULE(_c_expression, m) { .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.") .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.") .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.") - .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id."); + .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.") + .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.") + .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize, + "Insert hash table size with new parameter name.") + .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.") + .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") + .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") + .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not."); (void)py::class_>(m, "OpInfoLoaderPy") .def(py::init()) diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 0e85d79fdb..bb2cc4910f 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -52,6 +52,7 @@ #include "ps/common.h" #include "ps/util.h" #include "ps/worker.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #endif #if (ENABLE_GE || ENABLE_D) @@ -921,6 +922,11 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, bool need_run) { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) { + return true; + } +#endif MS_LOG(INFO) << "Start InitDataSet Entry"; ShapeVector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), @@ -966,7 +972,17 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { backend->Link(runner.graph_id); } - ConfigManager::GetInstance().set_iter_num(size); + // PS mode does not support loop sink. +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (ps::Util::IsRoleOfWorker()) { + ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); + ConfigManager::GetInstance().set_iter_num(1); + } else { +#endif + ConfigManager::GetInstance().set_iter_num(size); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + } +#endif if (!(*runner.run)) { // empty function @@ -981,7 +997,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc } MS_LOG(DEBUG) << "InitDataSetVm End."; return true; -} +} // namespace pipeline void ResetOpId() { mindspore::id_generator::reset_id(); } diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index c45c7d281f..2700a89d63 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -14,22 +14,20 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") + list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") endif () if (NOT ENABLE_D) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc") - list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") endif() if (NOT ENABLE_GPU) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") - list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") endif() list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") add_subdirectory(ps_cache) - set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) diff --git a/mindspore/ccsrc/ps/common.h b/mindspore/ccsrc/ps/common.h index 4b854020bc..dab5567976 100644 --- a/mindspore/ccsrc/ps/common.h +++ b/mindspore/ccsrc/ps/common.h @@ -64,6 +64,7 @@ constexpr int64_t kInitWeightToOptimIdCmd = 11; constexpr int64_t kInitOptimInputsShapeCmd = 12; constexpr int64_t kInitKeyToPushNodeIdCmd = 13; constexpr int64_t kInitEmbeddingsCmd = 20; +constexpr int64_t kUpdateEmbeddingsCmd = 21; constexpr int64_t kCheckReadyForPushCmd = 25; constexpr int64_t kCheckReadyForPullCmd = 26; constexpr int64_t kEmbeddingLookupCmd = 30; diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index cbb7e5be6a..1ce1d096ae 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -51,6 +51,8 @@ #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "ps/random_normal/random_normal.h" namespace mindspore { namespace ps { @@ -100,6 +102,7 @@ class ParameterServer { void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); + void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); ParameterServer *ps_; @@ -118,13 +121,15 @@ class ParameterServer { void InitWeight(const Key &key, const WeightPtr &weight); void InitGrad(const Key &key, const GradPtr &grad); void InitEmbeddingTable(const Key &key, - const std::shared_ptr>>> &shapes); + const std::shared_ptr>>> &shapes, + const ParamInitInfo ¶m_init_info); bool HasWeight(const Key &key); void Finalize(); void UpdateWeights(); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); WeightPtr weight(const Key &key); void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); + void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); bool ReadyForUpdateWeights(); bool ReadyForPush(const Key &key); bool ReadyForPull(const Key &key); @@ -193,6 +198,7 @@ void ParameterServer::ServerHandler::Init() { handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; + handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings; handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; } @@ -302,7 +308,17 @@ void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta for (int64_t k = 0; k < lens[2]; k++) { output_shape->push_back(static_cast(req_data.vals[index++])); } - ps_->InitEmbeddingTable(key, shapes); + ParamInitInfo param_init_info; + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + param_init_info.param_type_ = static_cast(lens[3]); + if (param_init_info.param_type_ == kWeight) { + param_init_info.global_seed_ = static_cast(lens[4]); + param_init_info.op_seed_ = static_cast(lens[5]); + } else if (param_init_info.param_type_ == kAccumulation) { + param_init_info.init_val_ = req_data.vals[index]; + } + } + ps_->InitEmbeddingTable(key, shapes, param_init_info); } template @@ -338,6 +354,18 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); } +template +void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, + const ::ps::KVPairs &req_data, + ::ps::KVPairs *res) { + std::unique_lock lock(ps_->mutex()); + MS_EXCEPTION_IF_NULL(res); + const Key &key = req_data.keys[0]; + const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size()); + const Values &update_vals = req_data.vals; + ps_->UpdateEmbeddings(key, lookup_ids, update_vals); +} + template void ParameterServer::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { @@ -476,7 +504,8 @@ void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { template void ParameterServer::InitEmbeddingTable( - const Key &key, const std::shared_ptr>>> &shapes) { + const Key &key, const std::shared_ptr>>> &shapes, + const ParamInitInfo ¶m_init_info) { MS_EXCEPTION_IF_NULL(shapes); if (weights_.count(key) == 0) { std::shared_ptr lookup = @@ -493,8 +522,18 @@ void ParameterServer::InitEmbeddingTable( T *embedding_data = embedding->data(); std::default_random_engine engine; std::normal_distribution random(0, 0.01); - for (size_t i = 0; i < total_dims; i++) { - embedding_data[i] = random(engine); + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + if (param_init_info.param_type_ == kWeight) { + InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data); + } else if (param_init_info.param_type_ == kAccumulation) { + for (size_t i = 0; i < total_dims; i++) { + embedding_data[i] = param_init_info.init_val_; + } + } + } else { + for (size_t i = 0; i < total_dims; i++) { + embedding_data[i] = random(engine); + } } weights_[key] = embedding; tokens_[key] = 0; @@ -673,6 +712,23 @@ void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, res->lens.push_back(res->vals.size()); } +template +void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) { + if (weights_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding table key " << key; + return; + } + if (embedding_lookup_ops_.count(key) == 0) { + MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; + return; + } + WeightPtr table_ptr = weights_[key]; + MS_EXCEPTION_IF_NULL(table_ptr); + std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; + MS_EXCEPTION_IF_NULL(table_lookup_op); + table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size()); +} + template inline bool ParameterServer::ReadyForUpdateWeights() { return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index f2c66a3da5..f70539485e 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -70,9 +70,16 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; } auto &hash_table_info = iter->second; + if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { + return; + } hash_table_info.param_init_info_.param_type_ = kWeight; hash_table_info.param_init_info_.global_seed_ = global_seed; hash_table_info.param_init_info_.op_seed_ = op_seed; + if (CheckFinishInsertInitInfo()) { + finish_insert_init_info_ = true; + insert_init_info_.notify_one(); + } } void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) { @@ -81,8 +88,26 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; } auto &hash_table_info = iter->second; + if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { + return; + } hash_table_info.param_init_info_.param_type_ = kAccumulation; hash_table_info.param_init_info_.init_val_ = init_val; + if (CheckFinishInsertInitInfo()) { + finish_insert_init_info_ = true; + insert_init_info_.notify_one(); + } +} + +bool PsCacheManager::CheckFinishInsertInitInfo() const { + for (const auto &item : hash_tables_) { + const auto &hash_table_info = item.second; + const auto ¶m_init_info = hash_table_info.param_init_info_; + if (param_init_info.param_type_ == kUnKnown) { + return false; + } + } + return true; } void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) { @@ -113,35 +138,49 @@ void PsCacheManager::Initialize() { } embedding_device_cache_ = std::make_shared(batch_elements_, cache_vocab_size_); embedding_host_cache_ = std::make_shared(batch_elements_, host_cache_vocab_size_); - InitParameterServer(); + AddEmbeddingTable(); AllocMemForHashTable(); SetLocalIdRank(); initialized_ps_cache_ = true; } -void PsCacheManager::InitParameterServer() { +void PsCacheManager::AddEmbeddingTable() const { for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; size_t key = worker.SetParamKey(param_name); size_t row_count = item.second.vocab_size; - std::vector keys{key, key, key, key}; + // if worker role + worker.AddEmbeddingTable(key, row_count); + } +} + +void PsCacheManager::InitParameterServer() { + std::unique_lock locker(data_mutex_); + insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); + + for (const auto &item : hash_tables_) { + const auto ¶m_name = item.first; + size_t key = worker.SetParamKey(param_name); + std::vector keys{key, key, key, key, key, key}; std::vector values{ SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; std::vector lens{2, 2, 3}; const auto &hash_table_info = item.second; const auto ¶m_init_info = hash_table_info.param_init_info_; if (param_init_info.param_type_ == kWeight) { - lens.push_back(0); - values.push_back(SizeToFloat(param_init_info.global_seed_)); - values.push_back(SizeToFloat(param_init_info.op_seed_)); - } else if (param_init_info.param_type_ == kAccumulation) { lens.push_back(1); - values.push_back(param_init_info.init_val_); + } else if (param_init_info.param_type_ == kAccumulation) { + lens.push_back(2); } + values.push_back(param_init_info.init_val_); + lens.push_back(param_init_info.global_seed_); + lens.push_back(param_init_info.op_seed_); // if worker role - worker.AddEmbeddingTable(key, row_count); worker.InitPSEmbeddingTable(keys, values, lens); } + + finish_init_parameter_server_ = true; + data_prase_.notify_one(); } void PsCacheManager::AllocMemForHashTable() { @@ -208,10 +247,538 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { if (graph_step_ >= UINT64_MAX) { MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; } + if (graph_step_ == 0) { + std::unique_lock locker(data_mutex_); + data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); + } graph_step_++; set_channel_name(channel_name); PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); data_prase_.notify_one(); } + +void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { + if (!initialized_ps_cache_) { + MS_LOG(EXCEPTION) << "PS cache does not init."; + } + auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); + process_data_thread.detach(); +} + +void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + embedding_device_cache_->cache_->InitDevice(device_id, context); + InitParameterServer(); + while (true) { + ProcessData(); + } +} + +void PsCacheManager::ProcessData() { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + struct timeval start_time, end_time; + const uint64_t kUSecondInSecond = 1000000; + (void)gettimeofday(&start_time, nullptr); + auto channel = channel_name(); + if (channel.empty()) { + std::unique_lock locker(data_mutex_); + data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); + } + auto data = PsDataPrefetch::GetInstance().data(channel_name_); + if (data == nullptr) { + MS_LOG(INFO) << "No data process, channel name:" << channel_name_; + std::unique_lock locker(data_mutex_); + (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); + return; + } + IncreaseStep(); + auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); + auto batch_ids = reinterpret_cast(data); + auto batch_ids_len = data_size / sizeof(int); + std::unique_ptr hash_index(new int[batch_ids_len]); + if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { + MS_LOG(EXCEPTION) << "Process data memset failed."; + } + // Get hash swap in/out index and ids. + ParseData(batch_ids, batch_ids_len, hash_index.get()); + for (const auto &item : hash_tables_) { + auto key = worker.GetParamKey(item.first); + auto hash_info = item.second; + HashSwapHostToServer(key, hash_info); + HashSwapDeviceToHost(hash_info); + HashSwapServerToHost(key, hash_info); + HashSwapHostToDevice(hash_info); + } + // Replace the batch_ids by hash index for getNext-op getting hash index as input. + if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { + MS_LOG(EXCEPTION) << "Process data memcpy failed."; + } + embedding_device_cache_->cache_->SynchronizeStream(); + // Finish the data process and notify data prefetch. + PsDataPrefetch::GetInstance().FinalizeData(channel_name_); + (void)gettimeofday(&end_time, nullptr); + uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); + cost += static_cast(end_time.tv_usec - start_time.tv_usec); + MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ + << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ + << ", time cost:" << cost / 1000 << "ms)."; +} + +void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { + MS_EXCEPTION_IF_NULL(batch_ids); + MS_EXCEPTION_IF_NULL(hash_index); + for (size_t i = 0; i < batch_ids_len; i++) { + bool need_swap_host_to_device = true; + bool need_swap_device_to_host = true; + auto id = batch_ids[i]; + if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) { + hash_index[i] = -1; + continue; + } + hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); + if (need_swap_host_to_device) { + ParseHostDataHostToDevice(id); + } + if (need_swap_device_to_host) { + ParseHostDataDeviceToHost(id); + } + } + // Each 1000 step prints ps cache hit rate. + if (data_step_ % 1000 == 0) { + statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; + auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; + MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; + } +} + +void PsCacheManager::WaitGraphRun() { + MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; + std::unique_lock locker(data_mutex_); + if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { + MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ + << ", graph running step:" << graph_running_step_ << ")."; + } + set_current_graph_step(); +} + +int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { + MS_EXCEPTION_IF_NULL(need_swap_device_to_host); + MS_EXCEPTION_IF_NULL(need_swap_host_to_device); + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); + int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); + int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); + int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); + MS_EXCEPTION_IF_NULL(device_to_host_index); + MS_EXCEPTION_IF_NULL(device_to_host_ids); + MS_EXCEPTION_IF_NULL(host_to_device_index); + MS_EXCEPTION_IF_NULL(host_to_device_ids); + + auto device_hash_map = embedding_device_cache_->device_hash_map_; + MS_EXCEPTION_IF_NULL(device_hash_map); + int index = 0; + auto iter = device_hash_map->id_iter(id); + if (device_hash_map->IsIdExist(iter)) { + *need_swap_device_to_host = false; + *need_swap_host_to_device = false; + index = iter->second; + if (device_hash_map->hash_step(index) != data_step_) { + statistics_info_.hash_hit_count_++; + device_hash_map->set_hash_step(index, data_step_); + } + } else { + auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; + while (true) { + index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, + &(statistics_info_.device_to_host_size_)); + if (index == INVALID_INDEX_VALUE) { + WaitGraphRun(); + continue; + } + host_to_device_index[statistics_info_.host_to_device_size_] = index; + host_to_device_ids[statistics_info_.host_to_device_size_] = id; + statistics_info_.host_to_device_size_++; + *need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size; + break; + } + } + return index; +} + +void PsCacheManager::ParseHostDataHostToDevice(size_t id) { + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); + int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); + int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); + int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); + int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); + MS_EXCEPTION_IF_NULL(host_to_server_index); + MS_EXCEPTION_IF_NULL(host_to_server_ids); + MS_EXCEPTION_IF_NULL(server_to_host_index); + MS_EXCEPTION_IF_NULL(server_to_host_ids); + MS_EXCEPTION_IF_NULL(host_to_device_index); + + auto host_hash_map = embedding_host_cache_->host_hash_map_; + MS_EXCEPTION_IF_NULL(host_hash_map); + auto iter = host_hash_map->id_iter(id); + if (host_hash_map->IsIdExist(iter)) { + auto index = iter->second; + if (host_hash_map->hash_step(index) != data_step_) { + host_hash_map->set_hash_step(index, data_step_); + } + host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; + } else { + while (true) { + auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, + graph_running_step_, &statistics_info_.host_to_server_size_); + if (index == INVALID_INDEX_VALUE) { + WaitGraphRun(); + continue; + } + host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; + server_to_host_index[statistics_info_.server_to_host_size_] = index; + server_to_host_ids[statistics_info_.server_to_host_size_++] = id; + break; + } + } +} + +void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); + int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); + int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); + int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); + MS_EXCEPTION_IF_NULL(device_to_host_ids); + MS_EXCEPTION_IF_NULL(host_to_server_index); + MS_EXCEPTION_IF_NULL(host_to_server_ids); + MS_EXCEPTION_IF_NULL(device_to_host_index); + + auto host_hash_map = embedding_host_cache_->host_hash_map_; + MS_EXCEPTION_IF_NULL(host_hash_map); + int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; + auto iter = host_hash_map->id_iter(swap_device_to_host_id); + if (host_hash_map->IsIdExist(iter)) { + auto index = iter->second; + if (host_hash_map->hash_step(index) != data_step_) { + host_hash_map->set_hash_step(index, data_step_); + } + device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; + } else { + while (true) { + auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, + graph_running_step_, &statistics_info_.host_to_server_size_); + if (index == INVALID_INDEX_VALUE) { + WaitGraphRun(); + continue; + } + device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; + break; + } + } +} + +void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, + const float *input_addr, const int *indices_addr, float *output_addr) { + auto type_size = sizeof(float); + size_t lens = outer_dim_size * type_size; + for (size_t i = 0; i < indices_lens; ++i) { + int index = indices_addr[i]; + if (index >= 0 && index < SizeToInt(first_dim_size)) { + size_t pos = index * outer_dim_size; + auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; + } + } else { + auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + } + } + output_addr += outer_dim_size; + } +} + +void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, + const int *indices_addr, float *output_addr) { + size_t first_dim_size = host_cache_vocab_size_; + size_t outer_dim_size = embedding_size; + + size_t thread_num = indices_lens / 10000 + 1; + thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; + std::thread threads[kMaxThreadNum]; + size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num; + size_t i; + size_t task_offset = 0; + MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens; + for (i = 0; i < thread_num; i++) { + if (task_offset >= indices_lens) { + break; + } + MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; + threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size, + hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size); + task_offset += task_proc_lens; + if (task_offset + task_proc_lens > indices_lens) { + task_proc_lens = indices_lens - task_offset; + } + } + for (size_t j = 0; j < i; j++) { + threads[j].join(); + } +} + +void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, + float *insert_data, float *hash_table_addr) { + size_t first_dim_size = host_cache_vocab_size_; + size_t thread_num = insert_indices_size / 10000 + 1; + thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; + std::thread threads[kMaxThreadNum]; + size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num; + size_t i; + size_t task_offset = 0; + + auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, + int *insert_indices, float *insert_data, float *hash_table_addr) { + auto type_size = sizeof(float); + size_t lens = outer_dim_size * type_size; + for (size_t i = 0; i < insert_indices_size; ++i) { + int index = insert_indices[i]; + if (index >= 0 && index < SizeToInt(first_dim_size)) { + auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; + } + } + } + }; + + for (i = 0; i < thread_num; i++) { + if (task_offset >= insert_indices_size) { + break; + } + MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; + threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size, + insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr); + task_offset += task_proc_lens; + if (task_offset + task_proc_lens > insert_indices_size) { + task_proc_lens = insert_indices_size - task_offset; + } + } + + for (size_t j = 0; j < i; j++) { + threads[j].join(); + } +} + +void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); + auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); + auto swap_indices_size = statistics_info_.host_to_device_size_; + if (swap_indices_size == 0) { + return; + } + auto embedding_size = hash_info.embedding_size; + auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); + auto hash_table_size = hash_info.device_address.size; + auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); + auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); + LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, + swap_out_data.get()); + embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, + swap_out_data.get(), + swap_indices_size * embedding_size * sizeof(float)); + embedding_device_cache_->cache_->CopyHostMemToDevice( + embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); + embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, + embedding_device_cache_->hash_swap_index_addr_, hash_table_size, + embedding_size, swap_indices_size); +} + +void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + auto swap_indices_size = statistics_info_.device_to_host_size_; + auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); + auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); + if (swap_indices_size == 0) { + return; + } + auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); + auto hash_table_size = hash_info.device_address.size; + auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); + auto embedding_size = hash_info.embedding_size; + auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); + embedding_device_cache_->cache_->CopyHostMemToDevice( + embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); + embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, + embedding_device_cache_->hash_swap_index_addr_, hash_table_size, + embedding_size, swap_indices_size); + embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), + embedding_device_cache_->hash_swap_value_addr_, + swap_indices_size * embedding_size * sizeof(float)); + embedding_device_cache_->cache_->SynchronizeStream(); + InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, + swap_out_data.get(), host_hash_table_addr); +} + +void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); + auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); + auto swap_indices_size = statistics_info_.host_to_server_size_; + if (swap_indices_size == 0) { + return; + } + ::ps::SArray lookup_ids(swap_indices_size, 0); + ::ps::SArray swap_out_data; + auto embedding_size = hash_info.embedding_size; + swap_out_data.resize(swap_indices_size * embedding_size); + auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); + LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, + swap_out_data.data()); + + auto copy_len = swap_indices_size * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); +} + +void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { + MS_EXCEPTION_IF_NULL(embedding_host_cache_); + auto swap_indices_size = statistics_info_.server_to_host_size_; + auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); + auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); + if (swap_indices_size == 0) { + return; + } + auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); + auto embedding_size = hash_info.embedding_size; + ::ps::SArray lengths{swap_indices_size}; + ::ps::SArray lookup_result(swap_indices_size * embedding_size, 0); + ::ps::SArray lookup_ids(swap_indices_size, 0); + auto copy_len = swap_indices_size * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); + InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), + host_hash_table_addr); +} + +void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, + const HashTableInfo &hash_info) { + MS_EXCEPTION_IF_NULL(swap_out_index); + MS_EXCEPTION_IF_NULL(swap_out_data); + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + auto swap_out_index_size = statistics_info_.device_to_host_size_; + if (swap_out_index_size == 0) { + return; + } + auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); + auto hash_table_size = hash_info.device_address.size; + auto embedding_size = hash_info.embedding_size; + swap_out_data->resize(swap_out_index_size * embedding_size); + embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, + swap_out_index_size * sizeof(int)); + embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, + embedding_device_cache_->hash_swap_index_addr_, hash_table_size, + embedding_size, swap_out_index_size); + embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), + embedding_device_cache_->hash_swap_value_addr_, + swap_out_index_size * embedding_size * sizeof(float)); + embedding_device_cache_->cache_->RecordEvent(); +} + +void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, + size_t key) { + MS_EXCEPTION_IF_NULL(swap_in_ids); + MS_EXCEPTION_IF_NULL(swap_in_index); + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + auto swap_in_ids_size = statistics_info_.host_to_device_size_; + if (swap_in_ids_size == 0) { + return; + } + auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); + auto hash_table_size = hash_info.device_address.size; + auto embedding_size = hash_info.embedding_size; + // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). + ::ps::SArray lengths{swap_in_ids_size}; + ::ps::SArray lookup_result(swap_in_ids_size * embedding_size, 0); + ::ps::SArray lookup_ids(swap_in_ids_size, 0); + auto copy_len = swap_in_ids_size * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); + // Hash swap-in in device. + embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, + lookup_result.data(), + swap_in_ids_size * embedding_size * sizeof(float)); + embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, + swap_in_ids_size * sizeof(int)); + embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, + embedding_device_cache_->hash_swap_index_addr_, hash_table_size, + embedding_size, swap_in_ids_size); +} + +void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key) { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + MS_EXCEPTION_IF_NULL(swap_out_ids); + auto swap_out_ids_size = statistics_info_.device_to_host_size_; + if (swap_out_ids_size == 0) { + return; + } + ::ps::SArray lookup_ids(swap_out_ids_size, 0); + auto copy_len = swap_out_ids_size * sizeof(int); + auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); + if (ret != EOK) { + MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + } + // Need synchronize event to ensure that the swap-out in device is completed. + embedding_device_cache_->cache_->SynchronizeEvent(); + worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); +} + +void PsCacheManager::DumpHashTables() const { + MS_EXCEPTION_IF_NULL(embedding_device_cache_); + MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + for (const auto &item : hash_tables_) { + const auto ¶m_name = item.first; + size_t cache_vocab_size = item.second.cache_vocab_size; + size_t embedding_size = item.second.embedding_size; + size_t vocab_size = item.second.vocab_size; + MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size + << " || " << vocab_size << " || " << reinterpret_cast(item.second.device_address.addr) + << " || " << reinterpret_cast(item.second.host_address.get()); + float *output = new float[item.second.device_address.size / 4]; + embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, + item.second.device_address.size); + embedding_device_cache_->cache_->SynchronizeStream(); + for (size_t i = 0; i < cache_vocab_size; i++) { + for (size_t j = 0; j < embedding_size; j++) { + std::cout << output[i * embedding_size + j] << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + delete[] output; + } +} } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index c36a50b2ee..a97459511e 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -49,6 +49,7 @@ struct HashTableInfo { size_t vocab_size{0}; Address device_address{nullptr, 0}; std::shared_ptr host_address{nullptr}; + ParamInitInfo param_init_info_; }; struct EmbeddingDeviceCache { @@ -158,6 +159,8 @@ class PsCacheManager { void UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key); void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, const int *indices_addr, float *output_addr); + bool CheckFinishInsertInitInfo() const; + void AddEmbeddingTable() const; bool initialized_ps_cache_{false}; std::string channel_name_; @@ -167,6 +170,7 @@ class PsCacheManager { size_t data_step_{0}; std::mutex data_mutex_; std::condition_variable data_prase_; + std::condition_variable insert_init_info_; std::map hash_tables_; std::shared_ptr embedding_device_cache_; @@ -178,6 +182,8 @@ class PsCacheManager { size_t batch_elements_{0}; PsCacheStatisticsInfo statistics_info_; std::pair range_bound_; + std::atomic_bool finish_insert_init_info_{false}; + std::atomic_bool finish_init_parameter_server_{false}; }; static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h index f4c00bfb90..044e6f834b 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h @@ -26,7 +26,7 @@ namespace mindspore { namespace ps { -class PsDataPrefetch { +class EXPORT PsDataPrefetch { public: EXPORT static PsDataPrefetch &GetInstance() { static PsDataPrefetch instance; diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 03de6a77f1..32c6ce7524 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -17,6 +17,11 @@ #include "ps/ps_context.h" #include "utils/log_adapter.h" #include "utils/ms_utils.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "backend/kernel_compiler/kernel.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "ps/ps_cache/ps_cache_manager.h" +#endif namespace mindspore { namespace ps { @@ -80,5 +85,43 @@ bool PSContext::is_role_sched() const { return is_sched_; } void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } int PSContext::ps_rank_id() const { return rank_id_; } + +void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, + size_t vocab_size) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); +#endif +} + +void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, + size_t cache_vocab_size, size_t embedding_size) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); +#endif +} + +void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); +#endif +} + +void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); +#endif +} + +void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); +#endif +} + +void PSContext::set_cache_enable(bool cache_enable) const { +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); +#endif +} } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.h b/mindspore/ccsrc/ps/ps_context.h index ce5a98658d..070a4df464 100644 --- a/mindspore/ccsrc/ps/ps_context.h +++ b/mindspore/ccsrc/ps/ps_context.h @@ -44,6 +44,14 @@ class PSContext { bool is_role_sched() const; void SetPSRankId(int rank_id); int ps_rank_id() const; + void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, + size_t vocab_size) const; + void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, + size_t cache_vocab_size, size_t embedding_size) const; + void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const; + void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; + void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; + void set_cache_enable(bool cache_enable) const; private: PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} diff --git a/mindspore/ccsrc/ps/random_normal/random_normal.cc b/mindspore/ccsrc/ps/random_normal/random_normal.cc new file mode 100644 index 0000000000..d15930fd55 --- /dev/null +++ b/mindspore/ccsrc/ps/random_normal/random_normal.cc @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ps/random_normal/random_normal.h" +#include +#include +#include +#include "utils/convert_utils_base.h" +#include "pybind_api/random_normal/random_cpu_kernel.h" + +namespace mindspore { +namespace ps { +bool InitRandomNormal(float mean, float stddev, std::vector out_shape, size_t global_seed, size_t op_seed, + float *output_data) { + if (out_shape.size() == 0) { + std::cout << "output data shape is error" << std::endl; + } + int64_t total_count = 1; + for (uint32_t i = 0; i < out_shape.size(); i++) { + total_count *= SizeToLong(out_shape[i]); + } + uint32_t thread_num = 16; + if (total_count <= thread_num) { + thread_num = 1; + } + float *start_ptr = output_data; + if (start_ptr == nullptr) { + std::cout << "start_ptr is nullptr" << std::endl; + return false; + } + int64_t batchSize = total_count / thread_num; + std::vector threads(thread_num); + int64_t seed = SizeToLong(global_seed); + int64_t seed2 = SizeToLong(op_seed); + seed = (seed == 0 && seed2 == 0) ? clock() : seed; + PhiloxGenerator generator = PhiloxGenerator(seed, seed2); + if (thread_num != 1) { + for (uint32_t i = 0; i < thread_num - 1; i++) { + float *offset_ptr = start_ptr + batchSize * i; + threads[i] = + std::thread(FillRandoms>, generator, offset_ptr, batchSize, i); + } + float *offset_ptr = start_ptr + batchSize * (thread_num - 1); + threads[thread_num - 1] = std::thread(FillRandoms>, generator, + offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1); + } else { + threads[0] = + std::thread(FillRandoms>, generator, start_ptr, total_count, 0); + } + for (uint32_t i = 0; i < thread_num; i++) { + threads[i].join(); + } + for (int64_t i = 0; i < total_count; i++) { + output_data[i] *= stddev; + } + return true; +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/random_normal/random_normal.h b/mindspore/ccsrc/ps/random_normal/random_normal.h new file mode 100644 index 0000000000..9e432f73d6 --- /dev/null +++ b/mindspore/ccsrc/ps/random_normal/random_normal.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ +#define MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ +#include + +namespace mindspore { +namespace ps { +bool InitRandomNormal(float mean, float stddev, std::vector out_shape, size_t global_seed, size_t op_seed, + float *output_data); +} // namespace ps +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ diff --git a/mindspore/ccsrc/ps/util.h b/mindspore/ccsrc/ps/util.h index eda5bc1bd9..2a832d00f2 100644 --- a/mindspore/ccsrc/ps/util.h +++ b/mindspore/ccsrc/ps/util.h @@ -26,6 +26,15 @@ namespace mindspore { namespace ps { +enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 }; + +struct ParamInitInfo { + ParamType param_type_{kUnKnown}; + size_t global_seed_{0}; + size_t op_seed_{0}; + float init_val_{0}; +}; + class Util { public: static bool IsParamServerMode(); diff --git a/mindspore/ccsrc/ps/worker.h b/mindspore/ccsrc/ps/worker.h index f4dacabe68..a3b1930c5d 100644 --- a/mindspore/ccsrc/ps/worker.h +++ b/mindspore/ccsrc/ps/worker.h @@ -32,6 +32,7 @@ #include "ps/common.h" #include "ps/worker_proxy.h" #include "utils/shape_utils.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" namespace mindspore { namespace ps { @@ -47,15 +48,19 @@ class Worker { void Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes); void Pull(const size_t key, void *dev_addr, const size_t size); size_t SetParamKey(const std::string ¶m_name); + size_t GetParamKey(const std::string ¶m_name); void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); bool GetParamInitInServer(const std::string ¶m_name); void SetKeyOptimId(size_t key, const std::string &optimizer_name); void SetOptimInputShapes(size_t key, const ShapeVector &shape); void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes); + void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes); void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *lookup_result, int64_t cmd); + void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &vals); + bool running() { return running_; } void Finalize(); private: @@ -65,7 +70,6 @@ class Worker { Worker &operator=(const Worker &) = delete; bool IsKeyInit(const size_t key); - size_t GetParamKey(const std::string ¶m_name); void InitPSOptimId(const size_t param_key); void InitPSOptimInputShapes(const size_t key); void InitPSParamData(const std::vector &keys, void *origin_addr, size_t size); @@ -187,6 +191,12 @@ void Worker::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const : kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); } +template +void Worker::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &vals) { + kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals); +} + template void Worker::Finalize() { if (running_) { @@ -286,7 +296,7 @@ size_t Worker::GetParamKey(const std::string ¶m_name) { size_t key = kInvalidKey; if (param_to_key_.find(param_name) != param_to_key_.end()) { key = param_to_key_[param_name]; - MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key; + MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key; } return key; } @@ -310,8 +320,7 @@ void Worker::InitPSOptimId(const size_t param_key) { } template -void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, - const ShapeVector &sizes) { +void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes) { bool has_init = IsKeyInit(keys[0]); if (has_init) { MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; @@ -319,7 +328,7 @@ void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vecto } ::ps::SArray shapes_val; for (auto dim : shapes) { - shapes_val.push_back(static_cast(dim)); + shapes_val.push_back(dim); } std::vector sizes_int; (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), @@ -337,9 +346,6 @@ void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: const std::string ¶m_name = pk_node->fullname_with_scope(); void *param_data = tensor->data_c(); size_t param_size = LongToSize(tensor->data().nbytes()); - if (param_size > INT_MAX) { - MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size; - } size_t param_key = GetParamKey(param_name); if (param_key == kInvalidKey) { @@ -357,11 +363,17 @@ void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name << ", whether init in server: " << init_in_server; kv_worker_->AddKeyToServerId(param_key); - if (!init_in_server) { - InitPSParamData({param_key}, param_data, param_size); + if (!PsDataPrefetch::GetInstance().cache_enable()) { + if (!init_in_server) { + if (param_size > INT_MAX) { + MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " + << param_size; + } + InitPSParamData({param_key}, param_data, param_size); + } + InitPSOptimId(param_key); + InitPSOptimInputShapes(param_key); } - InitPSOptimId(param_key); - InitPSOptimInputShapes(param_key); } } diff --git a/mindspore/ccsrc/ps/worker_proxy.h b/mindspore/ccsrc/ps/worker_proxy.h index e0962765ec..051308eda1 100644 --- a/mindspore/ccsrc/ps/worker_proxy.h +++ b/mindspore/ccsrc/ps/worker_proxy.h @@ -45,6 +45,7 @@ class WorkerProxy : public ::ps::KVWorker { explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id) : Worker(app_id, customer_id) { server_num_ = ::ps::NumServers(); + MS_LOG(INFO) << "Server num:" << server_num_; PSContext::instance()->SetPSRankId(::ps::MyRank()); using std::placeholders::_1; using std::placeholders::_2; @@ -60,6 +61,7 @@ class WorkerProxy : public ::ps::KVWorker { broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3, _4, _5); round_robin_slicer_ = std::bind(&WorkerProxy::RoundRobinSlicer, this, _1, _2, _3, _4, _5); worker_init_embedding_slicer_ = std::bind(&WorkerProxy::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); + update_embedding_slicer_ = std::bind(&WorkerProxy::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5); } ~WorkerProxy() override = default; @@ -70,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker { const Callback &cb = nullptr, int64_t priority = 0); int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int64_t priority = 0); + void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &vals, const Callback &cb = nullptr, int64_t priority = 0); bool IsReadyForPush(const Key &key); bool IsReadyForPull(const Key &key); void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, @@ -98,6 +102,9 @@ class WorkerProxy : public ::ps::KVWorker { void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced, const std::map &attrs); + void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs &send, const std::vector<::ps::Range> &, + std::vector>> *sliced, + const std::map &attrs); void ProcessLookupResult(const ::ps::Message &msg); void ProcessResponse(const ::ps::Message &msg); void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs &kvs, @@ -122,6 +129,7 @@ class WorkerProxy : public ::ps::KVWorker { Slicer broadcast_slicer_; Slicer round_robin_slicer_; Slicer worker_init_embedding_slicer_; + Slicer update_embedding_slicer_; std::unordered_map lookup_callbacks_; std::unordered_map general_callbacks_; std::unordered_map expected_result_count_; @@ -195,6 +203,24 @@ int64_t WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, return ts; } +template +void WorkerProxy::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, + const ::ps::SArray &vals, const Callback &cb, int64_t priority) { + int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr); + ::ps::KVPairs kvs; + kvs.keys = keys; + kvs.lens = lookup_ids; + kvs.vals = vals; + kvs.priority = priority; + expected_result_count_[ts] = 0; + Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_); + if (expected_result_count_[ts] < server_num_) { + general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]); + } + general_customer_->WaitRequest(ts); + expected_result_count_.erase(ts); +} + template bool WorkerProxy::IsReadyForPush(const Key &key) { ::ps::SArray result(1, 0); @@ -724,6 +750,47 @@ void WorkerProxy::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KV } } +template +void WorkerProxy::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs &send, + const std::vector<::ps::Range> &, + std::vector>> *sliced, + const std::map &attrs) { + MS_EXCEPTION_IF_NULL(sliced); + T *embedding_vals = send.vals.data(); + int *lookup_ids = send.lens.data(); + size_t val_size = send.vals.size(); + size_t id_size = send.lens.size(); + size_t embedding_dim = val_size / id_size; + + const Key &key = send.keys[0]; + const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); + sliced->resize(ranges.size()); + + for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + auto &kvs = sliced->at(i).second; + kvs.keys.push_back(key); + for (size_t j = 0; j < id_size; j++) { + auto lookup_id = static_cast(lookup_ids[j]); + if (lookup_id >= begin && lookup_id <= end) { + kvs.keys.push_back(lookup_id); + for (size_t k = 0; k < embedding_dim; k++) { + kvs.vals.push_back(embedding_vals[j * embedding_dim + k]); + } + } + } + + if (kvs.keys.size() <= 1) { + sliced->at(i).first = false; + } else { + sliced->at(i).first = true; + expected_result_count_[timestamp] += 1; + } + } +} + template void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { int64_t ts = msg.meta.timestamp; diff --git a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc index 016fdae242..40ef95a4bd 100644 --- a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc +++ b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.cc @@ -18,6 +18,7 @@ #include #include #include "runtime/device/cpu/cpu_device_address.h" +#include "ir/tensor.h" namespace mindspore { bool InitRandomNormal(float mean, float stddev, std::vector out_shape, int64_t seed, int64_t seed2, diff --git a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h index 74eb7130bb..cb517a6259 100644 --- a/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h +++ b/mindspore/ccsrc/pybind_api/random_normal/random_cpu_kernel.h @@ -19,8 +19,7 @@ #include "pybind_api/random_normal/philox_generator.h" #include "pybind11/pybind11.h" #include "pybind_api/api_register.h" -#include "backend/kernel_compiler/cpu/cpu_kernel.h" -#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" +#include "utils/log_adapter.h" namespace py = pybind11; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 4bbea9db8e..d77744a614 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -55,6 +55,7 @@ class AscendKernelRuntime : public KernelRuntime { bool SyncStream() override; void SetContext() override; void CreateContext() override; + void *context() const override { return rt_context_; } protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc index eab93dd3e8..c7e032c8f8 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc @@ -18,6 +18,7 @@ #include "runtime/device/gpu/gpu_memory_allocator.h" #include "utils/ms_context.h" #include "utils/convert_utils.h" +#include "ps/ps_cache/ps_cache_manager.h" namespace mindspore { namespace device { namespace gpu { @@ -38,6 +39,9 @@ void GPUMemoryManager::MallocDeviceMemory() { MS_EXCEPTION_IF_NULL(context_ptr); // If use the dynamic memory pool, then alloc the first memory block to init. if (context_ptr->get_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { + if (ps::ps_cache_instance.initialized_ps_cache()) { + return; + } auto device_addr = MallocMemFromMemPool(1); if (!device_addr) { MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index bdd29ea6e8..98d94a86a0 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -30,6 +30,10 @@ #include "utils/ms_utils.h" #include "utils/shape_utils.h" #include "utils/utils.h" +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +#include "ps/ps_cache/ps_cache_manager.h" +#endif + using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; @@ -331,15 +335,27 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; continue; } + DeviceAddressPtr device_address = nullptr; +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + const std::string ¶m_name = item->fullname_with_scope(); + if (ps::ps_cache_instance.IsHashTable(param_name)) { + const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); + MS_EXCEPTION_IF_NULL(address.addr); + device_address = + CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + AnfAlgo::SetOutputAddr(device_address, index, item.get()); + continue; + } +#endif auto tensor_size = CountNodeDeviceMemorySize(item, index); - auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); - if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { + if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope() << " index: " << index << " size: " << tensor_size; - AnfAlgo::SetOutputAddr(address, index, item.get()); + AnfAlgo::SetOutputAddr(device_address, index, item.get()); } } MS_LOG(INFO) << "AssignStaticMemoryInput end"; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 8ecca3cf26..17ccaae282 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -78,6 +78,7 @@ class KernelRuntime { virtual void ClearGlobalIdleMem() {} virtual void CreateContext() {} virtual void SetContext() {} + virtual void *context() const { return nullptr; } uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { return mem_manager_->MallocMem(type, size, address); } diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 0aae4e2b3f..9529f18a0a 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -15,6 +15,7 @@ """Parameter for cell.""" from copy import copy +import numbers from .._c_expression import ParamInfo from .._c_expression import MetaTensor as MetaTensor_ from . import dtype as mstype @@ -23,7 +24,10 @@ from .tensor import Tensor, MetaTensor from .._checkparam import Validator from ..parallel._tensor import _get_slice_index from ..parallel._auto_parallel_context import auto_parallel_context -from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched +from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table +from ..parallel._ps_context import _reinsert_hash_table_size +from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info +from .seed import _get_global_and_op_seed __all__ = ['Parameter', 'ParameterTuple'] @@ -35,6 +39,18 @@ def _is_in_parallel_mode(): """Get parallel mode.""" return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] +def init_to_value(init): + """Get value of initializer.""" + if isinstance(init, str): + if init == 'zeros': + return 0.0 + if init == 'ones': + return 1.0 + raise ValueError("init should be one of values in 'zeros', 'ones'.") + if isinstance(init, numbers.Number): + return float(init) + raise ValueError("init should be number or string") + class Parameter(MetaTensor_): """ @@ -118,6 +134,8 @@ class Parameter(MetaTensor_): def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False): self._param_info = ParamInfo() + self.init_in_server = False + self.cache_enable = False self.name = name self.requires_grad = requires_grad self.layerwise_parallel = layerwise_parallel @@ -129,7 +147,6 @@ class Parameter(MetaTensor_): self._sliced = False self.is_param_ps = False self._cast_type = None - self.init_in_server = False self._unique = False self.is_in_parallel = _is_in_parallel_mode() if isinstance(default_input, (MetaTensor, Tensor)): @@ -155,7 +172,7 @@ class Parameter(MetaTensor_): if isinstance(data, bool): raise ValueError('Parameter data can not be `bool`') if isinstance(data, MetaTensor): - if _is_in_parallel_mode() or _is_role_worker(): + if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched(): # do not init data while in auto parallel. return (MetaTensor_, data.dtype, data.shape) data = data.to_tensor() @@ -189,18 +206,18 @@ class Parameter(MetaTensor_): init_in_server (bool): Whether trainable parameter updated by parameter server is initialized on server. Default: False. """ - if _is_role_worker() or _is_role_pserver() or _is_role_sched(): - if init_in_server and (not self.name.endswith("embedding_table")): - raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " - "sparse operator support initialization in server.".format(self.name)) - self.is_param_ps = True - self.init_in_server = init_in_server - self._param_info.init_in_server = init_in_server - else: + if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()): raise RuntimeError("Must complete following two steps before calling set_param_ps: \ 1. set_ps_context(enable_ps=True) \ 2. export MS_ROLE environment variable.") + if init_in_server and (not self.name.endswith("embedding_table")): + raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " + "sparse operator support initialization in server.".format(self.name)) + self.is_param_ps = True + self.init_in_server = init_in_server + self._param_info.init_in_server = init_in_server + @property def inited_param(self): @@ -238,6 +255,13 @@ class Parameter(MetaTensor_): format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) else: raise ValueError("The type of the name should be `str` or `None`.") + + if _is_role_worker() and self.cache_enable: + if len(self.shape) != 2: + raise RuntimeError("The dims of parameter '{}' must be 2, but got {}." + .format(self.name, len(self.shape))) + _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1]) + self._param_info.name = name_ @property @@ -297,6 +321,7 @@ class Parameter(MetaTensor_): x.is_init = False x.is_param_ps = self.is_param_ps x.init_in_server = self.init_in_server + x.cache_enable = self.cache_enable if init != 'same': shape = self.shape dtype = self.dtype @@ -431,15 +456,18 @@ class Parameter(MetaTensor_): raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) slice_index = int(_get_slice_index(layout[0], layout[1])) if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): - if _is_role_worker(): + if _is_role_worker() or _is_role_sched(): data = self.init_mode.to_tensor(0, [1]) else: data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) else: data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) else: + if _is_role_worker() and self.cache_enable: + global_seed, op_seed = _get_global_and_op_seed() + _insert_weight_init_info(self.name, global_seed, op_seed) if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): - if _is_role_worker(): + if _is_role_worker() or _is_role_sched(): data = self.init_mode.to_tensor(0, [1]) else: data = self.init_mode.to_tensor() @@ -502,6 +530,16 @@ class ParameterTuple(tuple): x1 = x.clone(init) x1.name = prefix + "." + x1.name new.append(x1) + + if not x1.cache_enable: + continue + if not x1.name.endswith("embedding_table"): + raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of " + "sparse operator support enable cache.".format(x1.name)) + + if _is_role_worker(): + _clone_hash_table(x.name, x1.name) + _insert_accumu_init_info(x1.name, init_to_value(init)) return ParameterTuple(new) def __parameter_tuple__(self): diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py index cf17956870..ad3cd859e9 100644 --- a/mindspore/common/seed.py +++ b/mindspore/common/seed.py @@ -195,6 +195,20 @@ def _get_op_seed(op_seed, kernel_name): return _KERNEL_SEED[(kernel_name, op_seed)] +def _get_global_and_op_seed(): + """Get global_seed and op_seed.""" + global_seed = get_seed() + op_seed = get_seed() + if global_seed == 0: + global_seed = DEFAULT_GRAPH_SEED + elif global_seed is None: + global_seed = 0 + Validator.check_non_negative_int(op_seed, "seed", "init") + temp_seed = _get_op_seed(op_seed, "init") + seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) + return seeds + + def _get_graph_seed(op_seed, kernel_name): """ Get the graph-level seed. diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h index 83451ea336..b5673f5ac0 100644 --- a/mindspore/core/utils/convert_utils_base.h +++ b/mindspore/core/utils/convert_utils_base.h @@ -73,6 +73,8 @@ inline size_t FloatToSize(float u) { } inline float IntToFloat(int32_t v) { return static_cast(v); } +inline float SizeToFloat(size_t v) { return static_cast(v); } + inline double LongToDouble(int64_t v) { return static_cast(v); } inline double FloatToDouble(float v) { return static_cast(v); } diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 017d512a15..3f36ea5c18 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -22,6 +22,7 @@ from mindspore.common.initializer import initializer from mindspore.communication.management import get_group_size from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_parallel_mode +from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker from mindspore._checkparam import Rel from mindspore._checkparam import Validator as validator from mindspore.ops.primitive import constexpr @@ -156,6 +157,7 @@ class EmbeddingLookup(Cell): max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None. Default: None sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. + vocab_cache_size (int): Cache size of the dictionary of embeddings. Inputs: - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. @@ -185,7 +187,7 @@ class EmbeddingLookup(Cell): def __init__(self, vocab_size, embedding_size, param_init='normal', target='CPU', slice_mode='batch_slice', manual_shapes=None, - max_norm=None, sparse=True): + max_norm=None, sparse=True, vocab_cache_size=0): super(EmbeddingLookup, self).__init__() self.target = target if target not in ('CPU', 'DEVICE'): @@ -199,11 +201,23 @@ class EmbeddingLookup(Cell): self.gatherv2 = P.GatherV2() self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) + self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name) self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) - self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), - name='embedding_table') parallel_mode = _get_parallel_mode() is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + self.cache_enable = self.vocab_cache_size > 0 + if self.cache_enable: + if is_auto_parallel: + self.vocab_cache_size = self.vocab_cache_size * get_group_size() + self.vocab_size = self.vocab_cache_size + + self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), + name='embedding_table') + if self.cache_enable: + self.embedding_table.cache_enable = True + _set_cache_enable(True) + if _is_role_worker(): + _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) self.forward_unique = False self.gather_revert = P.GatherV2() self.unique = P.Unique().shard(((1,),)) @@ -222,7 +236,7 @@ class EmbeddingLookup(Cell): self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) elif slice_mode == "table_row_slice" and is_auto_parallel: - if target == 'DEVICE': + if target == 'DEVICE' and not self.cache_enable: indices_shape_size = 1 self.gather_revert.shard(((1, 1), (get_group_size(),))) self.forward_unique = True diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 4b3d139b68..7eb51065ab 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -88,14 +88,14 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") + "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, - beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter): + beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" success = True indices = gradient.indices values = gradient.values - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() shapes = (op_shape(param), op_shape(m), op_shape(v), op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), @@ -158,12 +158,13 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, - beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter): + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") +def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, + beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, + moment1, moment2, ps_parameter, cache_enable): """Apply adam optimizer to the weight parameter using Tensor.""" success = True - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) @@ -338,12 +339,12 @@ class Adam(Optimizer): success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, beta1_power, beta2_power, self.beta1, self.beta2, self.eps), - lr, gradients, params, moment1, moment2, self.ps_parameters) + lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) else: success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, self.use_locking, self.use_nesterov, self._is_device, beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), - gradients, params, moment1, moment2, self.ps_parameters) + gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) return success @Optimizer.target.setter diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index 5e9ebd1934..12fd426b63 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -24,14 +24,14 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", - "RowTensor", "Tensor", "Tensor", "Bool") + "RowTensor", "Tensor", "Tensor", "Bool", "Bool") def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, - gradient, weight, moment, ps_parameter): + gradient, weight, moment, ps_parameter, cache_enable): """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" success = True indices = gradient.indices values = gradient.values - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) success = F.depend(success, pull(push((values, indices), shapes), weight)) @@ -41,12 +41,12 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, le @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", - "Tensor", "Tensor", "Tensor", "Bool") + "Tensor", "Tensor", "Tensor", "Bool", "Bool") def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, - gradient, weight, moment, ps_parameter): + gradient, weight, moment, ps_parameter, cache_enable): """Apply ftrl optimizer to the weight parameter.""" success = True - if ps_parameter: + if ps_parameter and not cache_enable: op_shape = P.Shape() success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) @@ -185,7 +185,7 @@ class FTRL(Optimizer): success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, self.l1, self.l2, self.lr_power, lr), - linear, grads, params, moments, self.ps_parameters) + linear, grads, params, moments, self.ps_parameters, self.cache_enable) return success @Optimizer.target.setter diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index e1e7ad4929..c0b3331a32 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -156,6 +156,8 @@ class Optimizer(Cell): break ps_filter = lambda x: x.is_param_ps self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) + ps_cache_filter = lambda x: x.cache_enable + self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters) self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) self.need_scale = loss_scale != 1.0 self.global_step_increase_tensor = Tensor(1, mstype.int32) diff --git a/mindspore/parallel/_ps_context.py b/mindspore/parallel/_ps_context.py index 999c5adde5..40b4ad6266 100644 --- a/mindspore/parallel/_ps_context.py +++ b/mindspore/parallel/_ps_context.py @@ -117,3 +117,21 @@ def _is_role_pserver(): def _is_role_sched(): return ps_context().is_role_sched() + +def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size): + ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size) + +def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size): + ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size) + +def _insert_weight_init_info(name, global_seed, op_seed): + ps_context().insert_weight_init_info(name, global_seed, op_seed) + +def _insert_accumu_init_info(name, init_val): + ps_context().insert_accumu_init_info(name, init_val) + +def _clone_hash_table(dest_param_name, src_param_name): + ps_context().clone_hash_table(dest_param_name, src_param_name) + +def _set_cache_enable(cache_enable): + ps_context().set_cache_enable(cache_enable) diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 5fbc3913c1..e7ef2e14c5 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -92,6 +92,10 @@ def connect_network_with_dataset(network, dataset_helper): if isinstance(dataset_iter, _DatasetIterNormal): raise RuntimeError("Dataset should be connected with network only in sink mode.") + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + return network + if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ and context.get_context("device_target") == "Ascend" \ @@ -166,14 +170,14 @@ class DatasetHelper: iterclass = _DatasetIterGE else: if context.get_context("mode") == context.GRAPH_MODE: - if context.get_context("device_target") == "Ascend": + ms_role = os.getenv("MS_ROLE") + if ms_role in ("MS_PSERVER", "MS_SCHED"): + iterclass = _DatasetIterPSServer + elif ms_role == "MS_WORKER": + iterclass = _DatasetIterPSWork + elif (context.get_context("device_target") == "Ascend") or \ + (context.get_context("device_target") == "GPU"): iterclass = _DatasetIterMSLoopSink - elif context.get_context("device_target") == "GPU": - ms_role = os.getenv("MS_ROLE") - if ms_role in ("MS_PSERVER", "MS_SCHED"): - iterclass = _DatasetIterPSLite - else: - iterclass = _DatasetIterMSLoopSink elif context.get_context("device_target") == "CPU": raise RuntimeError( "Currently dataset sink mode is not supported when the device target is CPU.") @@ -218,7 +222,10 @@ class _DatasetIter: if not hasattr(dataset, '__transfer_dataset__'): if hasattr(dataset, '__loop_size__'): - self.sink_size = dataset.__loop_size__ + ms_role = os.getenv("MS_ROLE") + # PS mode does not support loop sink and need get the real sink size. + if ms_role != "MS_WORKER": + self.sink_size = dataset.__loop_size__ create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context( "device_target") == "Ascend") dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, @@ -260,8 +267,12 @@ class _DatasetIter: def get_sink_size(self): """get sink_size to device""" sink_size = 1 + ms_role = os.getenv("MS_ROLE") if hasattr(self.dataset, '__loop_size__'): sink_size = self.dataset.__loop_size__ + elif ms_role == "MS_WORKER": + # PS mode does not support loop sink. + sink_size = 1 else: if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \ or context.get_context("device_target") == "GPU": @@ -311,9 +322,6 @@ class _DatasetIterMSLoopSink(_DatasetIter): def __init__(self, dataset, sink_size, epoch_num): super().__init__(dataset, sink_size, epoch_num) self.sink_count = self.get_sink_count(dataset) - ms_role = os.getenv("MS_ROLE") - if ms_role in ("MS_PSERVER", "MS_SCHED"): - self.sink_count = 1 # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. @@ -341,8 +349,8 @@ class _DatasetIterMS(_DatasetIter): self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) -class _DatasetIterPSLite(_DatasetIter): - """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" +class _DatasetIterPSServer(_DatasetIter): + """Iter for context on MS_PSERVER or MS_SCHED""" def __init__(self, dataset, sink_size, epoch_num): super().__init__(dataset, sink_size, epoch_num) @@ -355,6 +363,20 @@ class _DatasetIterPSLite(_DatasetIter): self.op = op +class _DatasetIterPSWork(_DatasetIter): + """Iter for context on MS_WORKER""" + + def __init__(self, dataset, sink_size, epoch_num): + super().__init__(dataset, sink_size, epoch_num) + if sink_size > 0: + self.sink_count = sink_size + else: + self.sink_count = dataset.get_dataset_size() + + def op(): + return tuple() + + self.op = op class _DatasetIterNormal: """Iter for normal(non sink) mode, feed the data from host.""" diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py index 7834baf91e..aeafdc39f7 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/config.py +++ b/model_zoo/official/recommend/wide_and_deep/src/config.py @@ -30,6 +30,7 @@ def argparse_init(): parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") parser.add_argument("--field_size", type=int, default=39, help="The number of features.") parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.") + parser.add_argument("--vocab_cache_size", type=int, default=0, help="The total features of hash table.") parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.") parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128], help="The dimension of all deep layers.") @@ -66,6 +67,7 @@ class WideDeepConfig(): self.eval_batch_size = 16000 self.field_size = 39 self.vocab_size = 200000 + self.vocab_cache_size = 100000 self.emb_dim = 80 self.deep_layer_dim = [1024, 512, 256, 128] self.deep_layer_act = 'relu' @@ -103,6 +105,7 @@ class WideDeepConfig(): self.eval_batch_size = args.eval_batch_size self.field_size = args.field_size self.vocab_size = args.vocab_size + self.vocab_cache_size = args.vocab_cache_size self.emb_dim = args.emb_dim self.deep_layer_dim = args.deep_layer_dim self.deep_layer_act = args.deep_layer_act diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index 6dd469532f..978b1be288 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -147,6 +147,7 @@ class WideDeepModel(nn.Cell): sparse = config.sparse self.field_size = config.field_size self.vocab_size = config.vocab_size + self.vocab_cache_size = config.vocab_cache_size self.emb_dim = config.emb_dim self.deep_layer_dims_list = config.deep_layer_dim self.deep_layer_act = config.deep_layer_act @@ -237,8 +238,20 @@ class WideDeepModel(nn.Cell): self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) self.embedding_table = self.deep_embeddinglookup.embedding_table elif parameter_server: - self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) - self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) + cache_enable = self.vocab_cache_size > 0 + target = 'DEVICE' if cache_enable else 'CPU' + if is_auto_parallel and config.full_batch and cache_enable: + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, + slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, + sparse=sparse, vocab_cache_size=self.vocab_cache_size) + self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, + slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, + sparse=sparse, vocab_cache_size=self.vocab_cache_size) + else: + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, + sparse=sparse, vocab_cache_size=self.vocab_cache_size) + self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, sparse=sparse, + vocab_cache_size=self.vocab_cache_size) self.embedding_table = self.deep_embeddinglookup.embedding_table self.deep_embeddinglookup.embedding_table.set_param_ps() self.wide_embeddinglookup.embedding_table.set_param_ps() @@ -344,7 +357,7 @@ class TrainStepWrap(nn.Cell): """ def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, - sparse=False): + sparse=False, cache_enable=False): super(TrainStepWrap, self).__init__() parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) @@ -361,7 +374,7 @@ class TrainStepWrap(nn.Cell): self.weights_w = ParameterTuple(weights_w) self.weights_d = ParameterTuple(weights_d) - if (sparse and is_auto_parallel) or parameter_server: + if (sparse and is_auto_parallel) or (parameter_server and not cache_enable): self.optimizer_d = LazyAdam( self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, @@ -417,10 +430,17 @@ class TrainStepWrap(nn.Cell): class PredictWithSigmoid(nn.Cell): + """ + Predict definition + """ def __init__(self, network): super(PredictWithSigmoid, self).__init__() self.network = network self.sigmoid = P.Sigmoid() + parallel_mode = context.get_auto_parallel_context("parallel_mode") + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + if is_auto_parallel: + self.sigmoid.shard(((1, 1),)) def construct(self, batch_ids, batch_wts, labels): logits, _, = self.network(batch_ids, batch_wts) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py index bac2061271..5fefc1b36c 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py @@ -39,7 +39,8 @@ def get_WideDeep_net(config): """ WideDeep_net = WideDeepModel(config) loss_net = NetWithLossClass(WideDeep_net, config) - train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server)) + train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), + cache_enable=bool(config.vocab_cache_size > 0)) eval_net = PredictWithSigmoid(WideDeep_net) return train_net, eval_net @@ -81,6 +82,7 @@ def train_and_eval(config): else: dataset_type = DataType.H5 parameter_server = bool(config.parameter_server) + cache_enable = bool(config.vocab_cache_size > 0) print("epochs is {}".format(epochs)) ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size, rank_id=get_rank(), @@ -111,7 +113,7 @@ def train_and_eval(config): callback_list.append(ckpoint_cb) model.train(epochs, ds_train, callbacks=callback_list, - dataset_sink_mode=(not parameter_server)) + dataset_sink_mode=(parameter_server and cache_enable)) if __name__ == "__main__":