| @@ -194,6 +194,14 @@ if (ENABLE_GPU) | |||
| ) | |||
| endif () | |||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| install( | |||
| TARGETS ps_cache | |||
| DESTINATION ${INSTALL_LIB_DIR} | |||
| COMPONENT mindspore | |||
| ) | |||
| endif() | |||
| if (ENABLE_SERVING OR ENABLE_TESTCASES) | |||
| file(GLOB_RECURSE LIBEVENT_LIB_LIST | |||
| ${libevent_LIBPATH}/libevent* | |||
| @@ -308,7 +308,7 @@ elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| else () | |||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||
| if (${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| endif() | |||
| @@ -75,6 +75,9 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs | |||
| if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | |||
| node_so_ = CUST_AICPU_OPS_SO_NAME; | |||
| node_name_ = kCustRunApi; | |||
| } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { | |||
| node_so_ = AICPU_OPS_SO_NAME; | |||
| node_name_ = kCustRunApi; | |||
| } else { | |||
| node_so_ = AICPU_OPS_SO_NAME; | |||
| } | |||
| @@ -161,6 +164,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> | |||
| if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | |||
| node_so_ = CUST_AICPU_OPS_SO_NAME; | |||
| node_name_ = kCustRunApi; | |||
| } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { | |||
| node_so_ = AICPU_OPS_SO_NAME; | |||
| node_name_ = kCustRunApi; | |||
| } else { | |||
| node_so_ = AICPU_OPS_SO_NAME; | |||
| } | |||
| @@ -49,6 +49,7 @@ constexpr auto kIdentity = "Identity"; | |||
| constexpr auto kUpdateCache = "UpdateCache"; | |||
| constexpr auto kCustRunApi = "RunCpuKernel"; | |||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | |||
| const std::set<std::string> kCacheKernelOps{kUpdateCache}; | |||
| struct AicpuParamHead { | |||
| uint32_t length; // Total length: include cunstom message | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "ps/worker.h" | |||
| namespace mindspore { | |||
| @@ -38,10 +39,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey); | |||
| } | |||
| std::vector<size_t> keys{key_, key_, key_}; | |||
| std::vector<size_t> values; | |||
| values.insert(values.end(), input_shape.begin(), input_shape.end()); | |||
| values.insert(values.end(), indices_shape.begin(), indices_shape.end()); | |||
| values.insert(values.end(), output_shape.begin(), output_shape.end()); | |||
| std::vector<float> values; | |||
| std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(values), | |||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||
| std::transform(indices_shape.begin(), indices_shape.end(), std::back_inserter(values), | |||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||
| std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(values), | |||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||
| MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape | |||
| << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | |||
| std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), | |||
| @@ -72,6 +72,23 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, con | |||
| return Launch(inputs, workspace, outputs); | |||
| } | |||
| void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, | |||
| const float *update_vals, size_t ids_size) { | |||
| size_t copy_lens = outer_dim_size_ * sizeof(float); | |||
| for (size_t i = 0; i < ids_size; ++i) { | |||
| int index = lookup_ids[i] - offset_; | |||
| if (index >= 0 && index < SizeToInt(first_dim_size_)) { | |||
| auto ret = | |||
| memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "UpdateEmbeddings index invalid."; | |||
| } | |||
| } | |||
| } | |||
| const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } | |||
| const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } | |||
| @@ -35,7 +35,8 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK | |||
| bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, | |||
| size_t ids_size) override; | |||
| const std::vector<size_t> &input_sizes() const override; | |||
| const std::vector<size_t> &output_sizes() const override; | |||
| const std::vector<size_t> &workspace_sizes() const override; | |||
| @@ -38,7 +38,8 @@ class PServerKernel { | |||
| virtual void ReInit(const std::vector<std::vector<size_t>> &) {} | |||
| virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) = 0; | |||
| virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, | |||
| size_t ids_size) {} | |||
| virtual const std::vector<size_t> &input_sizes() const = 0; | |||
| virtual const std::vector<size_t> &output_sizes() const = 0; | |||
| virtual const std::vector<size_t> &workspace_sizes() const = 0; | |||
| @@ -56,6 +56,7 @@ | |||
| #include "toolchain/adx_datadump_server.h" | |||
| #if ENABLE_CPU && ENABLE_D | |||
| #include "ps/util.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -487,11 +488,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||
| // adjust kernel | |||
| AdjustKernel(root_graph); | |||
| #if ENABLE_CPU && ENABLE_D | |||
| if (ps::Util::IsParamServerMode()) { | |||
| CheckPSModeConsistence(root_graph); | |||
| // Assign parameter keys. | |||
| AssignParamKey(root_graph); | |||
| } | |||
| InitPsWorker(root_graph); | |||
| #endif | |||
| // assign stream | |||
| AssignStream(NOT_NULL(root_graph)); | |||
| @@ -568,6 +565,9 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) { | |||
| } | |||
| // adjust execution order because merge child graph and other special operations | |||
| AdjustKernel(graph); | |||
| #if ENABLE_CPU && ENABLE_D | |||
| InitPsWorker(graph); | |||
| #endif | |||
| // Reorder optimizer order | |||
| auto execution_order = graph->execution_order(); | |||
| Reorder(&execution_order); | |||
| @@ -644,6 +644,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens | |||
| #if ENABLE_CPU && ENABLE_D | |||
| // Initialize parameter server | |||
| InitPSParamAndOptim(kernel_graph, inputs); | |||
| std::string channel_name; | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) { | |||
| ps::ps_cache_instance.IncreaseGraphStep(channel_name); | |||
| } | |||
| #endif | |||
| { | |||
| // run task on device | |||
| @@ -21,6 +21,9 @@ | |||
| #include "runtime/device/kernel_runtime_manager.h" | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/scoped_long_running.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace session { | |||
| @@ -67,6 +67,7 @@ | |||
| #include "utils/ms_context.h" | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| #include "ps/util.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -243,6 +244,12 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||
| auto input_node = input_nodes[i]; | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| const std::string ¶m_name = input_node->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| continue; | |||
| } | |||
| #endif | |||
| auto pk_node = input_node->cast<ParameterPtr>(); | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | |||
| auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | |||
| @@ -306,16 +313,11 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| HardwareOptimize(graph); | |||
| // Graph kernel fusion optimization | |||
| GraphKernelOptimize(graph); | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| if (ps::Util::IsParamServerMode()) { | |||
| CheckPSModeConsistence(graph); | |||
| // Assign parameter keys. | |||
| AssignParamKey(graph); | |||
| } | |||
| #endif | |||
| // Start gpu kernel runtime | |||
| StartKernelRT(); | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| InitPsWorker(graph); | |||
| #endif | |||
| // Assign CUDA streams | |||
| AssignStream(graph); | |||
| // Dump .pb graph before remove nop nodes | |||
| @@ -380,6 +382,12 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||
| int kernel_num = kernel_graph->execution_order().size(); | |||
| int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1; | |||
| for (int64_t i = 0; i < loopsize; i++) { | |||
| #if ENABLE_CPU && ENABLE_GPU | |||
| std::string channel_name; | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) { | |||
| ps::ps_cache_instance.IncreaseGraphStep(channel_name); | |||
| } | |||
| #endif | |||
| Execute(kernel_graph); | |||
| } | |||
| // In pynative mode, device addresses of tensors in value nodes need be clean. | |||
| @@ -41,8 +41,10 @@ | |||
| #include "utils/trace_base.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/common.h" | |||
| #include "ps/util.h" | |||
| #include "abstract/abstract_value.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -1125,6 +1127,12 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||
| size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | |||
| } | |||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| const std::string ¶m_name = input_node->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| continue; | |||
| } | |||
| #endif | |||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | |||
| MS_EXCEPTION_IF_NULL(device_address); | |||
| if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size, | |||
| @@ -1715,8 +1723,64 @@ void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<ten | |||
| } | |||
| } | |||
| bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) { | |||
| auto kernel_graph = graphs_[graph_id]; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| for (const auto &kernel_node : kernel_graph->execution_order()) { | |||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||
| if (kernel_name == kGetNextOpName) { | |||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||
| MS_EXCEPTION_IF_NULL(prim); | |||
| *channel_name = GetValue<std::string>(prim->GetAttr("shared_name")); | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | |||
| if (!ps::Util::IsRoleOfWorker()) { | |||
| return; | |||
| } | |||
| CheckPSModeConsistence(kernel_graph); | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (!ps::ps_cache_instance.initialized_ps_cache()) { | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(devcie_target, device_id_); | |||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||
| auto context = runtime_instance->context(); | |||
| const auto &kernels = kernel_graph->execution_order(); | |||
| if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") { | |||
| GetBatchElements(kernels[0]); | |||
| ps::ps_cache_instance.Initialize(); | |||
| } | |||
| ps::ps_cache_instance.DoProcessData(device_id_, context); | |||
| } | |||
| } else { | |||
| // Assign parameter keys. | |||
| AssignParamKey(kernel_graph); | |||
| } | |||
| } | |||
| void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const { | |||
| auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes"); | |||
| auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types"); | |||
| if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size " | |||
| << types; | |||
| } | |||
| size_t batch_elements = 1; | |||
| const auto &shape = shapes[0]; | |||
| for (size_t i = 0; i < shape.size(); ++i) { | |||
| batch_elements *= shape[i]; | |||
| } | |||
| ps::ps_cache_instance.set_batch_elements(batch_elements); | |||
| } | |||
| void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const { | |||
| auto input_nodes = kernel_graph->inputs(); | |||
| for (const auto &input_node : input_nodes) { | |||
| if (!input_node->isa<Parameter>()) { | |||
| @@ -1725,8 +1789,9 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||
| auto pk_node = input_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(pk_node); | |||
| auto param_info_ptr = pk_node->param_info(); | |||
| if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) { | |||
| const std::string ¶m_name = pk_node->fullname_with_scope(); | |||
| const std::string ¶m_name = pk_node->fullname_with_scope(); | |||
| if (param_info_ptr != nullptr && param_info_ptr->init_in_server() && | |||
| !ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name | |||
| << "] in server, this parameter is used by kernel which executes in device"; | |||
| } | |||
| @@ -1734,10 +1799,6 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||
| } | |||
| void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||
| if (!ps::Util::IsRoleOfWorker()) { | |||
| MS_LOG(INFO) << "Not parameter server mode."; | |||
| return; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| @@ -1775,16 +1836,8 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, | |||
| return; | |||
| } | |||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | |||
| size_t input_ctrl_size = 1; | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| if (kernel_graph->input_ctrl_tensors()) { | |||
| input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); | |||
| } | |||
| auto input_nodes = kernel_graph->inputs(); | |||
| if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { | |||
| MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | |||
| << ", input_ctrl_size:" << input_ctrl_size; | |||
| } | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| @@ -99,9 +99,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| // get graph id in child graphs by ME front anf node pointer | |||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const; | |||
| virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | |||
| void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph); | |||
| void AssignParamKey(const KernelGraphPtr &kernel_graph); | |||
| void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const); | |||
| bool IsGetNextGraph(const GraphId &graph_id, std::string *channel_name); | |||
| virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | |||
| std::string *error_msg) const { | |||
| return true; | |||
| @@ -195,6 +195,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); | |||
| void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph); | |||
| void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; | |||
| void GetBatchElements(const AnfNodePtr &kernel_node) const; | |||
| void InitPsWorker(const KernelGraphPtr &kernel_graph); | |||
| #endif | |||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | |||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | |||
| @@ -207,6 +212,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||
| #if !defined(_WIN32) && !defined(_WIN64) | |||
| std::shared_ptr<Debugger> debugger_; | |||
| #endif | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| bool initialized_ps_cache_{false}; | |||
| #endif | |||
| }; | |||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | |||
| @@ -24,6 +24,9 @@ | |||
| #include "frontend/parallel/device_matrix.h" | |||
| #include "frontend/parallel/graph_util/generate_graph.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -514,6 +517,12 @@ Status GatherV2PInfo::InferBias() { | |||
| if (repeated_calc_num_ > 1) { | |||
| rank = rank / repeated_calc_num_; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| bias_ = 0; | |||
| return SUCCESS; | |||
| } | |||
| #endif | |||
| bias_ = rank / params_strategy.at(1) * slice_size_; | |||
| return SUCCESS; | |||
| } | |||
| @@ -46,10 +46,18 @@ | |||
| #include "ir/anf.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/tensor.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | |||
| return false; | |||
| } | |||
| #endif | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | |||
| @@ -44,6 +44,9 @@ | |||
| #include "utils/comm_manager.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #endif | |||
| using mindspore::tensor::Tensor; | |||
| @@ -3036,6 +3039,11 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) { | |||
| } | |||
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | |||
| return false; | |||
| } | |||
| #endif | |||
| MS_EXCEPTION_IF_NULL(root); | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | |||
| @@ -202,6 +202,7 @@ else () | |||
| if (${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | |||
| endif () | |||
| target_link_libraries(_c_dataengine PRIVATE ps_cache) | |||
| endif () | |||
| endif () | |||
| @@ -322,6 +322,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||
| bool profiling, int32_t *push_time) { | |||
| std::vector<device::DataItemGpu> items; | |||
| double start_time; | |||
| bool ps_data_prefetch = false; | |||
| for (int i = 0; i < data_size.size(); i++) { | |||
| device::DataItemGpu data_item; | |||
| data_item.data_len_ = data_size[i]; | |||
| @@ -334,6 +335,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||
| if (profiling) { | |||
| start_time = ProfilingTime::GetCurMilliSecond(); | |||
| } | |||
| // Data prefetch only when PS mode enables cache. | |||
| if ((!ps_data_prefetch) && (items.size() > 0)) { | |||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); | |||
| ps_data_prefetch = true; | |||
| } | |||
| BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | |||
| if (profiling) { | |||
| double end_time = ProfilingTime::GetCurMilliSecond(); | |||
| @@ -24,6 +24,7 @@ | |||
| #include "minddata/dataset/engine/datasetops/pipeline_op.h" | |||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | |||
| #include "minddata/dataset/util/status.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #ifdef ENABLE_TDTQUE | |||
| #include "minddata/dataset/util/queue.h" | |||
| @@ -17,6 +17,8 @@ | |||
| #include "utils/ms_utils.h" | |||
| #include "minddata/dataset/engine/perf/profiling.h" | |||
| #include "minddata/dataset/util/log_adapter.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| namespace mindspore { | |||
| namespace dataset { | |||
| static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr; | |||
| @@ -48,6 +50,10 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe | |||
| if (profiling) { | |||
| start_time = ProfilingTime::GetCurMilliSecond(); | |||
| } | |||
| // Data prefetch only when PS mode enables cache. | |||
| if (items.size() > 0) { | |||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); | |||
| } | |||
| if (tdt::TdtHostPushData(channel_name, items) != 0) { | |||
| return FAILED; | |||
| } | |||
| @@ -308,7 +308,14 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.") | |||
| .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.") | |||
| .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.") | |||
| .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id."); | |||
| .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.") | |||
| .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.") | |||
| .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize, | |||
| "Insert hash table size with new parameter name.") | |||
| .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.") | |||
| .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") | |||
| .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") | |||
| .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not."); | |||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | |||
| .def(py::init()) | |||
| @@ -52,6 +52,7 @@ | |||
| #include "ps/common.h" | |||
| #include "ps/util.h" | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #endif | |||
| #if (ENABLE_GE || ENABLE_D) | |||
| @@ -921,6 +922,11 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba | |||
| bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, | |||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | |||
| const std::vector<int64_t> &input_indexes, bool need_run) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) { | |||
| return true; | |||
| } | |||
| #endif | |||
| MS_LOG(INFO) << "Start InitDataSet Entry"; | |||
| ShapeVector int_input_indexes; | |||
| (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), | |||
| @@ -966,7 +972,17 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| backend->Link(runner.graph_id); | |||
| } | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| // PS mode does not support loop sink. | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfWorker()) { | |||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| } else { | |||
| #endif | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| } | |||
| #endif | |||
| if (!(*runner.run)) { | |||
| // empty function | |||
| @@ -981,7 +997,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| } | |||
| MS_LOG(DEBUG) << "InitDataSetVm End."; | |||
| return true; | |||
| } | |||
| } // namespace pipeline | |||
| void ResetOpId() { mindspore::id_generator::reset_id(); } | |||
| @@ -14,22 +14,20 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| endif () | |||
| if (NOT ENABLE_D) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| endif() | |||
| if (NOT ENABLE_GPU) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||
| endif() | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | |||
| add_subdirectory(ps_cache) | |||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | |||
| add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | |||
| @@ -64,6 +64,7 @@ constexpr int64_t kInitWeightToOptimIdCmd = 11; | |||
| constexpr int64_t kInitOptimInputsShapeCmd = 12; | |||
| constexpr int64_t kInitKeyToPushNodeIdCmd = 13; | |||
| constexpr int64_t kInitEmbeddingsCmd = 20; | |||
| constexpr int64_t kUpdateEmbeddingsCmd = 21; | |||
| constexpr int64_t kCheckReadyForPushCmd = 25; | |||
| constexpr int64_t kCheckReadyForPullCmd = 26; | |||
| constexpr int64_t kEmbeddingLookupCmd = 30; | |||
| @@ -51,6 +51,8 @@ | |||
| #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/random_normal/random_normal.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -100,6 +102,7 @@ class ParameterServer { | |||
| void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||
| ParameterServer *ps_; | |||
| @@ -118,13 +121,15 @@ class ParameterServer { | |||
| void InitWeight(const Key &key, const WeightPtr &weight); | |||
| void InitGrad(const Key &key, const GradPtr &grad); | |||
| void InitEmbeddingTable(const Key &key, | |||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes); | |||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, | |||
| const ParamInitInfo ¶m_init_info); | |||
| bool HasWeight(const Key &key); | |||
| void Finalize(); | |||
| void UpdateWeights(); | |||
| void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); | |||
| WeightPtr weight(const Key &key); | |||
| void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res); | |||
| void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); | |||
| bool ReadyForUpdateWeights(); | |||
| bool ReadyForPush(const Key &key); | |||
| bool ReadyForPull(const Key &key); | |||
| @@ -193,6 +198,7 @@ void ParameterServer<T>::ServerHandler::Init() { | |||
| handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; | |||
| handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; | |||
| handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; | |||
| handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings; | |||
| handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; | |||
| } | |||
| @@ -302,7 +308,17 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta | |||
| for (int64_t k = 0; k < lens[2]; k++) { | |||
| output_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | |||
| } | |||
| ps_->InitEmbeddingTable(key, shapes); | |||
| ParamInitInfo param_init_info; | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| param_init_info.param_type_ = static_cast<ParamType>(lens[3]); | |||
| if (param_init_info.param_type_ == kWeight) { | |||
| param_init_info.global_seed_ = static_cast<size_t>(lens[4]); | |||
| param_init_info.op_seed_ = static_cast<size_t>(lens[5]); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| param_init_info.init_val_ = req_data.vals[index]; | |||
| } | |||
| } | |||
| ps_->InitEmbeddingTable(key, shapes, param_init_info); | |||
| } | |||
| template <typename T> | |||
| @@ -338,6 +354,18 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta | |||
| ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, | |||
| const ::ps::KVPairs<T> &req_data, | |||
| ::ps::KVPairs<T> *res) { | |||
| std::unique_lock<std::mutex> lock(ps_->mutex()); | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| const Key &key = req_data.keys[0]; | |||
| const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size()); | |||
| const Values &update_vals = req_data.vals; | |||
| ps_->UpdateEmbeddings(key, lookup_ids, update_vals); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, | |||
| ::ps::KVPairs<T> *res) { | |||
| @@ -476,7 +504,8 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) { | |||
| template <typename T> | |||
| void ParameterServer<T>::InitEmbeddingTable( | |||
| const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) { | |||
| const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, | |||
| const ParamInitInfo ¶m_init_info) { | |||
| MS_EXCEPTION_IF_NULL(shapes); | |||
| if (weights_.count(key) == 0) { | |||
| std::shared_ptr<PServerKernel> lookup = | |||
| @@ -493,8 +522,18 @@ void ParameterServer<T>::InitEmbeddingTable( | |||
| T *embedding_data = embedding->data(); | |||
| std::default_random_engine engine; | |||
| std::normal_distribution<float> random(0, 0.01); | |||
| for (size_t i = 0; i < total_dims; i++) { | |||
| embedding_data[i] = random(engine); | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (param_init_info.param_type_ == kWeight) { | |||
| InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| for (size_t i = 0; i < total_dims; i++) { | |||
| embedding_data[i] = param_init_info.init_val_; | |||
| } | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < total_dims; i++) { | |||
| embedding_data[i] = random(engine); | |||
| } | |||
| } | |||
| weights_[key] = embedding; | |||
| tokens_[key] = 0; | |||
| @@ -673,6 +712,23 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, | |||
| res->lens.push_back(res->vals.size()); | |||
| } | |||
| template <typename T> | |||
| void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) { | |||
| if (weights_.count(key) == 0) { | |||
| MS_LOG(ERROR) << "Invalid embedding table key " << key; | |||
| return; | |||
| } | |||
| if (embedding_lookup_ops_.count(key) == 0) { | |||
| MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; | |||
| return; | |||
| } | |||
| WeightPtr table_ptr = weights_[key]; | |||
| MS_EXCEPTION_IF_NULL(table_ptr); | |||
| std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key]; | |||
| MS_EXCEPTION_IF_NULL(table_lookup_op); | |||
| table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size()); | |||
| } | |||
| template <typename T> | |||
| inline bool ParameterServer<T>::ReadyForUpdateWeights() { | |||
| return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); | |||
| @@ -70,9 +70,16 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t | |||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | |||
| } | |||
| auto &hash_table_info = iter->second; | |||
| if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { | |||
| return; | |||
| } | |||
| hash_table_info.param_init_info_.param_type_ = kWeight; | |||
| hash_table_info.param_init_info_.global_seed_ = global_seed; | |||
| hash_table_info.param_init_info_.op_seed_ = op_seed; | |||
| if (CheckFinishInsertInitInfo()) { | |||
| finish_insert_init_info_ = true; | |||
| insert_init_info_.notify_one(); | |||
| } | |||
| } | |||
| void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) { | |||
| @@ -81,8 +88,26 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i | |||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | |||
| } | |||
| auto &hash_table_info = iter->second; | |||
| if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { | |||
| return; | |||
| } | |||
| hash_table_info.param_init_info_.param_type_ = kAccumulation; | |||
| hash_table_info.param_init_info_.init_val_ = init_val; | |||
| if (CheckFinishInsertInitInfo()) { | |||
| finish_insert_init_info_ = true; | |||
| insert_init_info_.notify_one(); | |||
| } | |||
| } | |||
| bool PsCacheManager::CheckFinishInsertInitInfo() const { | |||
| for (const auto &item : hash_tables_) { | |||
| const auto &hash_table_info = item.second; | |||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | |||
| if (param_init_info.param_type_ == kUnKnown) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) { | |||
| @@ -113,35 +138,49 @@ void PsCacheManager::Initialize() { | |||
| } | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_); | |||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_); | |||
| InitParameterServer(); | |||
| AddEmbeddingTable(); | |||
| AllocMemForHashTable(); | |||
| SetLocalIdRank(); | |||
| initialized_ps_cache_ = true; | |||
| } | |||
| void PsCacheManager::InitParameterServer() { | |||
| void PsCacheManager::AddEmbeddingTable() const { | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| size_t row_count = item.second.vocab_size; | |||
| std::vector<size_t> keys{key, key, key, key}; | |||
| // if worker role | |||
| worker.AddEmbeddingTable(key, row_count); | |||
| } | |||
| } | |||
| void PsCacheManager::InitParameterServer() { | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| std::vector<size_t> keys{key, key, key, key, key, key}; | |||
| std::vector<float> values{ | |||
| SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; | |||
| std::vector<int64_t> lens{2, 2, 3}; | |||
| const auto &hash_table_info = item.second; | |||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | |||
| if (param_init_info.param_type_ == kWeight) { | |||
| lens.push_back(0); | |||
| values.push_back(SizeToFloat(param_init_info.global_seed_)); | |||
| values.push_back(SizeToFloat(param_init_info.op_seed_)); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| lens.push_back(1); | |||
| values.push_back(param_init_info.init_val_); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| lens.push_back(2); | |||
| } | |||
| values.push_back(param_init_info.init_val_); | |||
| lens.push_back(param_init_info.global_seed_); | |||
| lens.push_back(param_init_info.op_seed_); | |||
| // if worker role | |||
| worker.AddEmbeddingTable(key, row_count); | |||
| worker.InitPSEmbeddingTable(keys, values, lens); | |||
| } | |||
| finish_init_parameter_server_ = true; | |||
| data_prase_.notify_one(); | |||
| } | |||
| void PsCacheManager::AllocMemForHashTable() { | |||
| @@ -208,10 +247,538 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| if (graph_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | |||
| } | |||
| if (graph_step_ == 0) { | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); | |||
| } | |||
| graph_step_++; | |||
| set_channel_name(channel_name); | |||
| PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | |||
| data_prase_.notify_one(); | |||
| } | |||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||
| if (!initialized_ps_cache_) { | |||
| MS_LOG(EXCEPTION) << "PS cache does not init."; | |||
| } | |||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||
| process_data_thread.detach(); | |||
| } | |||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | |||
| InitParameterServer(); | |||
| while (true) { | |||
| ProcessData(); | |||
| } | |||
| } | |||
| void PsCacheManager::ProcessData() { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| struct timeval start_time, end_time; | |||
| const uint64_t kUSecondInSecond = 1000000; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| auto channel = channel_name(); | |||
| if (channel.empty()) { | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); | |||
| } | |||
| auto data = PsDataPrefetch::GetInstance().data(channel_name_); | |||
| if (data == nullptr) { | |||
| MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); | |||
| return; | |||
| } | |||
| IncreaseStep(); | |||
| auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); | |||
| auto batch_ids = reinterpret_cast<int *>(data); | |||
| auto batch_ids_len = data_size / sizeof(int); | |||
| std::unique_ptr<int[]> hash_index(new int[batch_ids_len]); | |||
| if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { | |||
| MS_LOG(EXCEPTION) << "Process data memset failed."; | |||
| } | |||
| // Get hash swap in/out index and ids. | |||
| ParseData(batch_ids, batch_ids_len, hash_index.get()); | |||
| for (const auto &item : hash_tables_) { | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto hash_info = item.second; | |||
| HashSwapHostToServer(key, hash_info); | |||
| HashSwapDeviceToHost(hash_info); | |||
| HashSwapServerToHost(key, hash_info); | |||
| HashSwapHostToDevice(hash_info); | |||
| } | |||
| // Replace the batch_ids by hash index for getNext-op getting hash index as input. | |||
| if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Process data memcpy failed."; | |||
| } | |||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||
| // Finish the data process and notify data prefetch. | |||
| PsDataPrefetch::GetInstance().FinalizeData(channel_name_); | |||
| (void)gettimeofday(&end_time, nullptr); | |||
| uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | |||
| cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | |||
| MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ | |||
| << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ | |||
| << ", time cost:" << cost / 1000 << "ms)."; | |||
| } | |||
| void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||
| MS_EXCEPTION_IF_NULL(batch_ids); | |||
| MS_EXCEPTION_IF_NULL(hash_index); | |||
| for (size_t i = 0; i < batch_ids_len; i++) { | |||
| bool need_swap_host_to_device = true; | |||
| bool need_swap_device_to_host = true; | |||
| auto id = batch_ids[i]; | |||
| if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) { | |||
| hash_index[i] = -1; | |||
| continue; | |||
| } | |||
| hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||
| if (need_swap_host_to_device) { | |||
| ParseHostDataHostToDevice(id); | |||
| } | |||
| if (need_swap_device_to_host) { | |||
| ParseHostDataDeviceToHost(id); | |||
| } | |||
| } | |||
| // Each 1000 step prints ps cache hit rate. | |||
| if (data_step_ % 1000 == 0) { | |||
| statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; | |||
| auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; | |||
| MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; | |||
| } | |||
| } | |||
| void PsCacheManager::WaitGraphRun() { | |||
| MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { | |||
| MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||
| << ", graph running step:" << graph_running_step_ << ")."; | |||
| } | |||
| set_current_graph_step(); | |||
| } | |||
| int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { | |||
| MS_EXCEPTION_IF_NULL(need_swap_device_to_host); | |||
| MS_EXCEPTION_IF_NULL(need_swap_host_to_device); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||
| int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||
| int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_ids); | |||
| auto device_hash_map = embedding_device_cache_->device_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(device_hash_map); | |||
| int index = 0; | |||
| auto iter = device_hash_map->id_iter(id); | |||
| if (device_hash_map->IsIdExist(iter)) { | |||
| *need_swap_device_to_host = false; | |||
| *need_swap_host_to_device = false; | |||
| index = iter->second; | |||
| if (device_hash_map->hash_step(index) != data_step_) { | |||
| statistics_info_.hash_hit_count_++; | |||
| device_hash_map->set_hash_step(index, data_step_); | |||
| } | |||
| } else { | |||
| auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; | |||
| while (true) { | |||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | |||
| &(statistics_info_.device_to_host_size_)); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| continue; | |||
| } | |||
| host_to_device_index[statistics_info_.host_to_device_size_] = index; | |||
| host_to_device_ids[statistics_info_.host_to_device_size_] = id; | |||
| statistics_info_.host_to_device_size_++; | |||
| *need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size; | |||
| break; | |||
| } | |||
| } | |||
| return index; | |||
| } | |||
| void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||
| int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||
| int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||
| MS_EXCEPTION_IF_NULL(server_to_host_index); | |||
| MS_EXCEPTION_IF_NULL(server_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||
| auto iter = host_hash_map->id_iter(id); | |||
| if (host_hash_map->IsIdExist(iter)) { | |||
| auto index = iter->second; | |||
| if (host_hash_map->hash_step(index) != data_step_) { | |||
| host_hash_map->set_hash_step(index, data_step_); | |||
| } | |||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | |||
| } else { | |||
| while (true) { | |||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| continue; | |||
| } | |||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | |||
| server_to_host_index[statistics_info_.server_to_host_size_] = index; | |||
| server_to_host_ids[statistics_info_.server_to_host_size_++] = id; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||
| int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; | |||
| auto iter = host_hash_map->id_iter(swap_device_to_host_id); | |||
| if (host_hash_map->IsIdExist(iter)) { | |||
| auto index = iter->second; | |||
| if (host_hash_map->hash_step(index) != data_step_) { | |||
| host_hash_map->set_hash_step(index, data_step_); | |||
| } | |||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | |||
| } else { | |||
| while (true) { | |||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| continue; | |||
| } | |||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, | |||
| const float *input_addr, const int *indices_addr, float *output_addr) { | |||
| auto type_size = sizeof(float); | |||
| size_t lens = outer_dim_size * type_size; | |||
| for (size_t i = 0; i < indices_lens; ++i) { | |||
| int index = indices_addr[i]; | |||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | |||
| size_t pos = index * outer_dim_size; | |||
| auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||
| } | |||
| } else { | |||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | |||
| } | |||
| } | |||
| output_addr += outer_dim_size; | |||
| } | |||
| } | |||
| void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t outer_dim_size = embedding_size; | |||
| size_t thread_num = indices_lens / 10000 + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| std::thread threads[kMaxThreadNum]; | |||
| size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num; | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
| MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens; | |||
| for (i = 0; i < thread_num; i++) { | |||
| if (task_offset >= indices_lens) { | |||
| break; | |||
| } | |||
| MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; | |||
| threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size, | |||
| hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size); | |||
| task_offset += task_proc_lens; | |||
| if (task_offset + task_proc_lens > indices_lens) { | |||
| task_proc_lens = indices_lens - task_offset; | |||
| } | |||
| } | |||
| for (size_t j = 0; j < i; j++) { | |||
| threads[j].join(); | |||
| } | |||
| } | |||
| void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||
| float *insert_data, float *hash_table_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t thread_num = insert_indices_size / 10000 + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| std::thread threads[kMaxThreadNum]; | |||
| size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num; | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
| auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||
| auto type_size = sizeof(float); | |||
| size_t lens = outer_dim_size * type_size; | |||
| for (size_t i = 0; i < insert_indices_size; ++i) { | |||
| int index = insert_indices[i]; | |||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | |||
| auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| for (i = 0; i < thread_num; i++) { | |||
| if (task_offset >= insert_indices_size) { | |||
| break; | |||
| } | |||
| MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; | |||
| threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size, | |||
| insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr); | |||
| task_offset += task_proc_lens; | |||
| if (task_offset + task_proc_lens > insert_indices_size) { | |||
| task_proc_lens = insert_indices_size - task_offset; | |||
| } | |||
| } | |||
| for (size_t j = 0; j < i; j++) { | |||
| threads[j].join(); | |||
| } | |||
| } | |||
| void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||
| auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||
| auto swap_indices_size = statistics_info_.host_to_device_size_; | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| } | |||
| auto embedding_size = hash_info.embedding_size; | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, | |||
| swap_out_data.get()); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_out_data.get(), | |||
| swap_indices_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_indices_size); | |||
| } | |||
| void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| auto swap_indices_size = statistics_info_.device_to_host_size_; | |||
| auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||
| auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto embedding_size = hash_info.embedding_size; | |||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_indices_size); | |||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), | |||
| embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_indices_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||
| swap_out_data.get(), host_hash_table_addr); | |||
| } | |||
| void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| auto swap_indices_size = statistics_info_.host_to_server_size_; | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||
| ::ps::SArray<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data.resize(swap_indices_size * embedding_size); | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||
| swap_out_data.data()); | |||
| auto copy_len = swap_indices_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| } | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| } | |||
| void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| auto swap_indices_size = statistics_info_.server_to_host_size_; | |||
| auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||
| auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| } | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto embedding_size = hash_info.embedding_size; | |||
| ::ps::SArray<int> lengths{swap_indices_size}; | |||
| ::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0); | |||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||
| auto copy_len = swap_indices_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), | |||
| host_hash_table_addr); | |||
| } | |||
| void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||
| const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||
| MS_EXCEPTION_IF_NULL(swap_out_data); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| auto swap_out_index_size = statistics_info_.device_to_host_size_; | |||
| if (swap_out_index_size == 0) { | |||
| return; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data->resize(swap_out_index_size * embedding_size); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, | |||
| swap_out_index_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_out_index_size); | |||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), | |||
| embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_out_index_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->RecordEvent(); | |||
| } | |||
| void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||
| size_t key) { | |||
| MS_EXCEPTION_IF_NULL(swap_in_ids); | |||
| MS_EXCEPTION_IF_NULL(swap_in_index); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| auto swap_in_ids_size = statistics_info_.host_to_device_size_; | |||
| if (swap_in_ids_size == 0) { | |||
| return; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). | |||
| ::ps::SArray<int> lengths{swap_in_ids_size}; | |||
| ::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0); | |||
| ::ps::SArray<int> lookup_ids(swap_in_ids_size, 0); | |||
| auto copy_len = swap_in_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| // Hash swap-in in device. | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||
| lookup_result.data(), | |||
| swap_in_ids_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, | |||
| swap_in_ids_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_in_ids_size); | |||
| } | |||
| void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||
| auto swap_out_ids_size = statistics_info_.device_to_host_size_; | |||
| if (swap_out_ids_size == 0) { | |||
| return; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | |||
| auto copy_len = swap_out_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| } | |||
| // Need synchronize event to ensure that the swap-out in device is completed. | |||
| embedding_device_cache_->cache_->SynchronizeEvent(); | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| } | |||
| void PsCacheManager::DumpHashTables() const { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t cache_vocab_size = item.second.cache_vocab_size; | |||
| size_t embedding_size = item.second.embedding_size; | |||
| size_t vocab_size = item.second.vocab_size; | |||
| MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size | |||
| << " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr) | |||
| << " || " << reinterpret_cast<void *>(item.second.host_address.get()); | |||
| float *output = new float[item.second.device_address.size / 4]; | |||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, | |||
| item.second.device_address.size); | |||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||
| for (size_t i = 0; i < cache_vocab_size; i++) { | |||
| for (size_t j = 0; j < embedding_size; j++) { | |||
| std::cout << output[i * embedding_size + j] << " "; | |||
| } | |||
| std::cout << std::endl; | |||
| } | |||
| std::cout << std::endl; | |||
| delete[] output; | |||
| } | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -49,6 +49,7 @@ struct HashTableInfo { | |||
| size_t vocab_size{0}; | |||
| Address device_address{nullptr, 0}; | |||
| std::shared_ptr<int[]> host_address{nullptr}; | |||
| ParamInitInfo param_init_info_; | |||
| }; | |||
| struct EmbeddingDeviceCache { | |||
| @@ -158,6 +159,8 @@ class PsCacheManager { | |||
| void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool CheckFinishInsertInitInfo() const; | |||
| void AddEmbeddingTable() const; | |||
| bool initialized_ps_cache_{false}; | |||
| std::string channel_name_; | |||
| @@ -167,6 +170,7 @@ class PsCacheManager { | |||
| size_t data_step_{0}; | |||
| std::mutex data_mutex_; | |||
| std::condition_variable data_prase_; | |||
| std::condition_variable insert_init_info_; | |||
| std::map<std::string, HashTableInfo> hash_tables_; | |||
| std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | |||
| @@ -178,6 +182,8 @@ class PsCacheManager { | |||
| size_t batch_elements_{0}; | |||
| PsCacheStatisticsInfo statistics_info_; | |||
| std::pair<size_t, size_t> range_bound_; | |||
| std::atomic_bool finish_insert_init_info_{false}; | |||
| std::atomic_bool finish_init_parameter_server_{false}; | |||
| }; | |||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | |||
| @@ -26,7 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| class PsDataPrefetch { | |||
| class EXPORT PsDataPrefetch { | |||
| public: | |||
| EXPORT static PsDataPrefetch &GetInstance() { | |||
| static PsDataPrefetch instance; | |||
| @@ -17,6 +17,11 @@ | |||
| #include "ps/ps_context.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -80,5 +85,43 @@ bool PSContext::is_role_sched() const { return is_sched_; } | |||
| void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } | |||
| int PSContext::ps_rank_id() const { return rank_id_; } | |||
| void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||
| size_t vocab_size) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); | |||
| #endif | |||
| } | |||
| void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||
| size_t cache_vocab_size, size_t embedding_size) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); | |||
| #endif | |||
| } | |||
| void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); | |||
| #endif | |||
| } | |||
| void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); | |||
| #endif | |||
| } | |||
| void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); | |||
| #endif | |||
| } | |||
| void PSContext::set_cache_enable(bool cache_enable) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | |||
| #endif | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -44,6 +44,14 @@ class PSContext { | |||
| bool is_role_sched() const; | |||
| void SetPSRankId(int rank_id); | |||
| int ps_rank_id() const; | |||
| void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||
| size_t vocab_size) const; | |||
| void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||
| size_t cache_vocab_size, size_t embedding_size) const; | |||
| void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const; | |||
| void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; | |||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; | |||
| void set_cache_enable(bool cache_enable) const; | |||
| private: | |||
| PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/random_normal/random_normal.h" | |||
| #include <iostream> | |||
| #include <thread> | |||
| #include <memory> | |||
| #include "utils/convert_utils_base.h" | |||
| #include "pybind_api/random_normal/random_cpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed, | |||
| float *output_data) { | |||
| if (out_shape.size() == 0) { | |||
| std::cout << "output data shape is error" << std::endl; | |||
| } | |||
| int64_t total_count = 1; | |||
| for (uint32_t i = 0; i < out_shape.size(); i++) { | |||
| total_count *= SizeToLong(out_shape[i]); | |||
| } | |||
| uint32_t thread_num = 16; | |||
| if (total_count <= thread_num) { | |||
| thread_num = 1; | |||
| } | |||
| float *start_ptr = output_data; | |||
| if (start_ptr == nullptr) { | |||
| std::cout << "start_ptr is nullptr" << std::endl; | |||
| return false; | |||
| } | |||
| int64_t batchSize = total_count / thread_num; | |||
| std::vector<std::thread> threads(thread_num); | |||
| int64_t seed = SizeToLong(global_seed); | |||
| int64_t seed2 = SizeToLong(op_seed); | |||
| seed = (seed == 0 && seed2 == 0) ? clock() : seed; | |||
| PhiloxGenerator generator = PhiloxGenerator(seed, seed2); | |||
| if (thread_num != 1) { | |||
| for (uint32_t i = 0; i < thread_num - 1; i++) { | |||
| float *offset_ptr = start_ptr + batchSize * i; | |||
| threads[i] = | |||
| std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, offset_ptr, batchSize, i); | |||
| } | |||
| float *offset_ptr = start_ptr + batchSize * (thread_num - 1); | |||
| threads[thread_num - 1] = std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, | |||
| offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1); | |||
| } else { | |||
| threads[0] = | |||
| std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, start_ptr, total_count, 0); | |||
| } | |||
| for (uint32_t i = 0; i < thread_num; i++) { | |||
| threads[i].join(); | |||
| } | |||
| for (int64_t i = 0; i < total_count; i++) { | |||
| output_data[i] *= stddev; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,27 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||
| #define MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||
| #include <vector> | |||
| namespace mindspore { | |||
| namespace ps { | |||
| bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed, | |||
| float *output_data); | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||
| @@ -26,6 +26,15 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| enum ParamType { kUnKnown = 0, kWeight = 1, kAccumulation = 2 }; | |||
| struct ParamInitInfo { | |||
| ParamType param_type_{kUnKnown}; | |||
| size_t global_seed_{0}; | |||
| size_t op_seed_{0}; | |||
| float init_val_{0}; | |||
| }; | |||
| class Util { | |||
| public: | |||
| static bool IsParamServerMode(); | |||
| @@ -32,6 +32,7 @@ | |||
| #include "ps/common.h" | |||
| #include "ps/worker_proxy.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -47,15 +48,19 @@ class Worker { | |||
| void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); | |||
| void Pull(const size_t key, void *dev_addr, const size_t size); | |||
| size_t SetParamKey(const std::string ¶m_name); | |||
| size_t GetParamKey(const std::string ¶m_name); | |||
| void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); | |||
| bool GetParamInitInServer(const std::string ¶m_name); | |||
| void SetKeyOptimId(size_t key, const std::string &optimizer_name); | |||
| void SetOptimInputShapes(size_t key, const ShapeVector &shape); | |||
| void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | |||
| void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes); | |||
| void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes); | |||
| void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); | |||
| void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd); | |||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals); | |||
| bool running() { return running_; } | |||
| void Finalize(); | |||
| private: | |||
| @@ -65,7 +70,6 @@ class Worker { | |||
| Worker &operator=(const Worker &) = delete; | |||
| bool IsKeyInit(const size_t key); | |||
| size_t GetParamKey(const std::string ¶m_name); | |||
| void InitPSOptimId(const size_t param_key); | |||
| void InitPSOptimInputShapes(const size_t key); | |||
| void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | |||
| @@ -187,6 +191,12 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const : | |||
| kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals) { | |||
| kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::Finalize() { | |||
| if (running_) { | |||
| @@ -286,7 +296,7 @@ size_t Worker<T>::GetParamKey(const std::string ¶m_name) { | |||
| size_t key = kInvalidKey; | |||
| if (param_to_key_.find(param_name) != param_to_key_.end()) { | |||
| key = param_to_key_[param_name]; | |||
| MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key; | |||
| MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key; | |||
| } | |||
| return key; | |||
| } | |||
| @@ -310,8 +320,7 @@ void Worker<T>::InitPSOptimId(const size_t param_key) { | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, | |||
| const ShapeVector &sizes) { | |||
| void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) { | |||
| bool has_init = IsKeyInit(keys[0]); | |||
| if (has_init) { | |||
| MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; | |||
| @@ -319,7 +328,7 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto | |||
| } | |||
| ::ps::SArray<T> shapes_val; | |||
| for (auto dim : shapes) { | |||
| shapes_val.push_back(static_cast<T>(dim)); | |||
| shapes_val.push_back(dim); | |||
| } | |||
| std::vector<int> sizes_int; | |||
| (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), | |||
| @@ -337,9 +346,6 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: | |||
| const std::string ¶m_name = pk_node->fullname_with_scope(); | |||
| void *param_data = tensor->data_c(); | |||
| size_t param_size = LongToSize(tensor->data().nbytes()); | |||
| if (param_size > INT_MAX) { | |||
| MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size; | |||
| } | |||
| size_t param_key = GetParamKey(param_name); | |||
| if (param_key == kInvalidKey) { | |||
| @@ -357,11 +363,17 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: | |||
| MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name | |||
| << ", whether init in server: " << init_in_server; | |||
| kv_worker_->AddKeyToServerId(param_key); | |||
| if (!init_in_server) { | |||
| InitPSParamData({param_key}, param_data, param_size); | |||
| if (!PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (!init_in_server) { | |||
| if (param_size > INT_MAX) { | |||
| MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " | |||
| << param_size; | |||
| } | |||
| InitPSParamData({param_key}, param_data, param_size); | |||
| } | |||
| InitPSOptimId(param_key); | |||
| InitPSOptimInputShapes(param_key); | |||
| } | |||
| InitPSOptimId(param_key); | |||
| InitPSOptimInputShapes(param_key); | |||
| } | |||
| } | |||
| @@ -45,6 +45,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id) | |||
| : Worker(app_id, customer_id) { | |||
| server_num_ = ::ps::NumServers(); | |||
| MS_LOG(INFO) << "Server num:" << server_num_; | |||
| PSContext::instance()->SetPSRankId(::ps::MyRank()); | |||
| using std::placeholders::_1; | |||
| using std::placeholders::_2; | |||
| @@ -60,6 +61,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5); | |||
| round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5); | |||
| worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); | |||
| update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5); | |||
| } | |||
| ~WorkerProxy() override = default; | |||
| @@ -70,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| const Callback &cb = nullptr, int64_t priority = 0); | |||
| int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0); | |||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0); | |||
| bool IsReadyForPush(const Key &key); | |||
| bool IsReadyForPull(const Key &key); | |||
| void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, | |||
| @@ -98,6 +102,9 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void ProcessLookupResult(const ::ps::Message &msg); | |||
| void ProcessResponse(const ::ps::Message &msg); | |||
| void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs, | |||
| @@ -122,6 +129,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||
| Slicer broadcast_slicer_; | |||
| Slicer round_robin_slicer_; | |||
| Slicer worker_init_embedding_slicer_; | |||
| Slicer update_embedding_slicer_; | |||
| std::unordered_map<int64_t, Callback> lookup_callbacks_; | |||
| std::unordered_map<int64_t, Callback> general_callbacks_; | |||
| std::unordered_map<int64_t, int64_t> expected_result_count_; | |||
| @@ -195,6 +203,24 @@ int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, | |||
| return ts; | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) { | |||
| int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = keys; | |||
| kvs.lens = lookup_ids; | |||
| kvs.vals = vals; | |||
| kvs.priority = priority; | |||
| expected_result_count_[ts] = 0; | |||
| Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_); | |||
| if (expected_result_count_[ts] < server_num_) { | |||
| general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]); | |||
| } | |||
| general_customer_->WaitRequest(ts); | |||
| expected_result_count_.erase(ts); | |||
| } | |||
| template <typename T> | |||
| bool WorkerProxy<T>::IsReadyForPush(const Key &key) { | |||
| ::ps::SArray<T> result(1, 0); | |||
| @@ -724,6 +750,47 @@ void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KV | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, | |||
| const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||
| const std::map<int64_t, int64_t> &attrs) { | |||
| MS_EXCEPTION_IF_NULL(sliced); | |||
| T *embedding_vals = send.vals.data(); | |||
| int *lookup_ids = send.lens.data(); | |||
| size_t val_size = send.vals.size(); | |||
| size_t id_size = send.lens.size(); | |||
| size_t embedding_dim = val_size / id_size; | |||
| const Key &key = send.keys[0]; | |||
| const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); | |||
| sliced->resize(ranges.size()); | |||
| for (size_t i = 0; i < ranges.size(); i++) { | |||
| const ::ps::Range &range = ranges[i]; | |||
| const auto &begin = range.begin(); | |||
| const auto &end = range.end(); | |||
| auto &kvs = sliced->at(i).second; | |||
| kvs.keys.push_back(key); | |||
| for (size_t j = 0; j < id_size; j++) { | |||
| auto lookup_id = static_cast<uint64_t>(lookup_ids[j]); | |||
| if (lookup_id >= begin && lookup_id <= end) { | |||
| kvs.keys.push_back(lookup_id); | |||
| for (size_t k = 0; k < embedding_dim; k++) { | |||
| kvs.vals.push_back(embedding_vals[j * embedding_dim + k]); | |||
| } | |||
| } | |||
| } | |||
| if (kvs.keys.size() <= 1) { | |||
| sliced->at(i).first = false; | |||
| } else { | |||
| sliced->at(i).first = true; | |||
| expected_result_count_[timestamp] += 1; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) { | |||
| int64_t ts = msg.meta.timestamp; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <thread> | |||
| #include <memory> | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "ir/tensor.h" | |||
| namespace mindspore { | |||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, | |||
| @@ -19,8 +19,7 @@ | |||
| #include "pybind_api/random_normal/philox_generator.h" | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind_api/api_register.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #include "utils/log_adapter.h" | |||
| namespace py = pybind11; | |||
| @@ -55,6 +55,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| bool SyncStream() override; | |||
| void SetContext() override; | |||
| void CreateContext() override; | |||
| void *context() const override { return rt_context_; } | |||
| protected: | |||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | |||
| @@ -18,6 +18,7 @@ | |||
| #include "runtime/device/gpu/gpu_memory_allocator.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/convert_utils.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| namespace mindspore { | |||
| namespace device { | |||
| namespace gpu { | |||
| @@ -38,6 +39,9 @@ void GPUMemoryManager::MallocDeviceMemory() { | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| // If use the dynamic memory pool, then alloc the first memory block to init. | |||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { | |||
| if (ps::ps_cache_instance.initialized_ps_cache()) { | |||
| return; | |||
| } | |||
| auto device_addr = MallocMemFromMemPool(1); | |||
| if (!device_addr) { | |||
| MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; | |||
| @@ -30,6 +30,10 @@ | |||
| #include "utils/ms_utils.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/utils.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| using mindspore::kernel::Address; | |||
| using mindspore::kernel::AddressPtr; | |||
| @@ -331,15 +335,27 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | |||
| continue; | |||
| } | |||
| DeviceAddressPtr device_address = nullptr; | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| const std::string ¶m_name = item->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); | |||
| MS_EXCEPTION_IF_NULL(address.addr); | |||
| device_address = | |||
| CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||
| continue; | |||
| } | |||
| #endif | |||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | |||
| auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||
| MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | |||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { | |||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address) == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | |||
| } | |||
| MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope() | |||
| << " index: " << index << " size: " << tensor_size; | |||
| AnfAlgo::SetOutputAddr(address, index, item.get()); | |||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "AssignStaticMemoryInput end"; | |||
| @@ -78,6 +78,7 @@ class KernelRuntime { | |||
| virtual void ClearGlobalIdleMem() {} | |||
| virtual void CreateContext() {} | |||
| virtual void SetContext() {} | |||
| virtual void *context() const { return nullptr; } | |||
| uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | |||
| return mem_manager_->MallocMem(type, size, address); | |||
| } | |||
| @@ -15,6 +15,7 @@ | |||
| """Parameter for cell.""" | |||
| from copy import copy | |||
| import numbers | |||
| from .._c_expression import ParamInfo | |||
| from .._c_expression import MetaTensor as MetaTensor_ | |||
| from . import dtype as mstype | |||
| @@ -23,7 +24,10 @@ from .tensor import Tensor, MetaTensor | |||
| from .._checkparam import Validator | |||
| from ..parallel._tensor import _get_slice_index | |||
| from ..parallel._auto_parallel_context import auto_parallel_context | |||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched | |||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table | |||
| from ..parallel._ps_context import _reinsert_hash_table_size | |||
| from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info | |||
| from .seed import _get_global_and_op_seed | |||
| __all__ = ['Parameter', 'ParameterTuple'] | |||
| @@ -35,6 +39,18 @@ def _is_in_parallel_mode(): | |||
| """Get parallel mode.""" | |||
| return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] | |||
| def init_to_value(init): | |||
| """Get value of initializer.""" | |||
| if isinstance(init, str): | |||
| if init == 'zeros': | |||
| return 0.0 | |||
| if init == 'ones': | |||
| return 1.0 | |||
| raise ValueError("init should be one of values in 'zeros', 'ones'.") | |||
| if isinstance(init, numbers.Number): | |||
| return float(init) | |||
| raise ValueError("init should be number or string") | |||
| class Parameter(MetaTensor_): | |||
| """ | |||
| @@ -118,6 +134,8 @@ class Parameter(MetaTensor_): | |||
| def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False): | |||
| self._param_info = ParamInfo() | |||
| self.init_in_server = False | |||
| self.cache_enable = False | |||
| self.name = name | |||
| self.requires_grad = requires_grad | |||
| self.layerwise_parallel = layerwise_parallel | |||
| @@ -129,7 +147,6 @@ class Parameter(MetaTensor_): | |||
| self._sliced = False | |||
| self.is_param_ps = False | |||
| self._cast_type = None | |||
| self.init_in_server = False | |||
| self._unique = False | |||
| self.is_in_parallel = _is_in_parallel_mode() | |||
| if isinstance(default_input, (MetaTensor, Tensor)): | |||
| @@ -155,7 +172,7 @@ class Parameter(MetaTensor_): | |||
| if isinstance(data, bool): | |||
| raise ValueError('Parameter data can not be `bool`') | |||
| if isinstance(data, MetaTensor): | |||
| if _is_in_parallel_mode() or _is_role_worker(): | |||
| if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched(): | |||
| # do not init data while in auto parallel. | |||
| return (MetaTensor_, data.dtype, data.shape) | |||
| data = data.to_tensor() | |||
| @@ -189,18 +206,18 @@ class Parameter(MetaTensor_): | |||
| init_in_server (bool): Whether trainable parameter updated by parameter server is | |||
| initialized on server. Default: False. | |||
| """ | |||
| if _is_role_worker() or _is_role_pserver() or _is_role_sched(): | |||
| if init_in_server and (not self.name.endswith("embedding_table")): | |||
| raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " | |||
| "sparse operator support initialization in server.".format(self.name)) | |||
| self.is_param_ps = True | |||
| self.init_in_server = init_in_server | |||
| self._param_info.init_in_server = init_in_server | |||
| else: | |||
| if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()): | |||
| raise RuntimeError("Must complete following two steps before calling set_param_ps: \ | |||
| 1. set_ps_context(enable_ps=True) \ | |||
| 2. export MS_ROLE environment variable.") | |||
| if init_in_server and (not self.name.endswith("embedding_table")): | |||
| raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " | |||
| "sparse operator support initialization in server.".format(self.name)) | |||
| self.is_param_ps = True | |||
| self.init_in_server = init_in_server | |||
| self._param_info.init_in_server = init_in_server | |||
| @property | |||
| def inited_param(self): | |||
| @@ -238,6 +255,13 @@ class Parameter(MetaTensor_): | |||
| format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | |||
| else: | |||
| raise ValueError("The type of the name should be `str` or `None`.") | |||
| if _is_role_worker() and self.cache_enable: | |||
| if len(self.shape) != 2: | |||
| raise RuntimeError("The dims of parameter '{}' must be 2, but got {}." | |||
| .format(self.name, len(self.shape))) | |||
| _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1]) | |||
| self._param_info.name = name_ | |||
| @property | |||
| @@ -297,6 +321,7 @@ class Parameter(MetaTensor_): | |||
| x.is_init = False | |||
| x.is_param_ps = self.is_param_ps | |||
| x.init_in_server = self.init_in_server | |||
| x.cache_enable = self.cache_enable | |||
| if init != 'same': | |||
| shape = self.shape | |||
| dtype = self.dtype | |||
| @@ -431,15 +456,18 @@ class Parameter(MetaTensor_): | |||
| raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) | |||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | |||
| if _is_role_worker(): | |||
| if _is_role_worker() or _is_role_sched(): | |||
| data = self.init_mode.to_tensor(0, [1]) | |||
| else: | |||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | |||
| else: | |||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | |||
| else: | |||
| if _is_role_worker() and self.cache_enable: | |||
| global_seed, op_seed = _get_global_and_op_seed() | |||
| _insert_weight_init_info(self.name, global_seed, op_seed) | |||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | |||
| if _is_role_worker(): | |||
| if _is_role_worker() or _is_role_sched(): | |||
| data = self.init_mode.to_tensor(0, [1]) | |||
| else: | |||
| data = self.init_mode.to_tensor() | |||
| @@ -502,6 +530,16 @@ class ParameterTuple(tuple): | |||
| x1 = x.clone(init) | |||
| x1.name = prefix + "." + x1.name | |||
| new.append(x1) | |||
| if not x1.cache_enable: | |||
| continue | |||
| if not x1.name.endswith("embedding_table"): | |||
| raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of " | |||
| "sparse operator support enable cache.".format(x1.name)) | |||
| if _is_role_worker(): | |||
| _clone_hash_table(x.name, x1.name) | |||
| _insert_accumu_init_info(x1.name, init_to_value(init)) | |||
| return ParameterTuple(new) | |||
| def __parameter_tuple__(self): | |||
| @@ -195,6 +195,20 @@ def _get_op_seed(op_seed, kernel_name): | |||
| return _KERNEL_SEED[(kernel_name, op_seed)] | |||
| def _get_global_and_op_seed(): | |||
| """Get global_seed and op_seed.""" | |||
| global_seed = get_seed() | |||
| op_seed = get_seed() | |||
| if global_seed == 0: | |||
| global_seed = DEFAULT_GRAPH_SEED | |||
| elif global_seed is None: | |||
| global_seed = 0 | |||
| Validator.check_non_negative_int(op_seed, "seed", "init") | |||
| temp_seed = _get_op_seed(op_seed, "init") | |||
| seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) | |||
| return seeds | |||
| def _get_graph_seed(op_seed, kernel_name): | |||
| """ | |||
| Get the graph-level seed. | |||
| @@ -73,6 +73,8 @@ inline size_t FloatToSize(float u) { | |||
| } | |||
| inline float IntToFloat(int32_t v) { return static_cast<float>(v); } | |||
| inline float SizeToFloat(size_t v) { return static_cast<float>(v); } | |||
| inline double LongToDouble(int64_t v) { return static_cast<double>(v); } | |||
| inline double FloatToDouble(float v) { return static_cast<double>(v); } | |||
| @@ -22,6 +22,7 @@ from mindspore.common.initializer import initializer | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.parallel._utils import _get_parallel_mode | |||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker | |||
| from mindspore._checkparam import Rel | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.ops.primitive import constexpr | |||
| @@ -156,6 +157,7 @@ class EmbeddingLookup(Cell): | |||
| max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | |||
| or None. Default: None | |||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | |||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. | |||
| Inputs: | |||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | |||
| @@ -185,7 +187,7 @@ class EmbeddingLookup(Cell): | |||
| def __init__(self, vocab_size, embedding_size, param_init='normal', | |||
| target='CPU', slice_mode='batch_slice', manual_shapes=None, | |||
| max_norm=None, sparse=True): | |||
| max_norm=None, sparse=True, vocab_cache_size=0): | |||
| super(EmbeddingLookup, self).__init__() | |||
| self.target = target | |||
| if target not in ('CPU', 'DEVICE'): | |||
| @@ -199,11 +201,23 @@ class EmbeddingLookup(Cell): | |||
| self.gatherv2 = P.GatherV2() | |||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | |||
| self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) | |||
| self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name) | |||
| self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) | |||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||
| name='embedding_table') | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.cache_enable = self.vocab_cache_size > 0 | |||
| if self.cache_enable: | |||
| if is_auto_parallel: | |||
| self.vocab_cache_size = self.vocab_cache_size * get_group_size() | |||
| self.vocab_size = self.vocab_cache_size | |||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||
| name='embedding_table') | |||
| if self.cache_enable: | |||
| self.embedding_table.cache_enable = True | |||
| _set_cache_enable(True) | |||
| if _is_role_worker(): | |||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||
| self.forward_unique = False | |||
| self.gather_revert = P.GatherV2() | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| @@ -222,7 +236,7 @@ class EmbeddingLookup(Cell): | |||
| self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | |||
| if target == 'DEVICE': | |||
| if target == 'DEVICE' and not self.cache_enable: | |||
| indices_shape_size = 1 | |||
| self.gather_revert.shard(((1, 1), (get_group_size(),))) | |||
| self.forward_unique = True | |||
| @@ -88,14 +88,14 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter): | |||
| beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): | |||
| """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | |||
| success = True | |||
| indices = gradient.indices | |||
| values = gradient.values | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| shapes = (op_shape(param), op_shape(m), op_shape(v), | |||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | |||
| @@ -158,12 +158,13 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, | |||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||
| beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter): | |||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, | |||
| beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, | |||
| moment1, moment2, ps_parameter, cache_enable): | |||
| """Apply adam optimizer to the weight parameter using Tensor.""" | |||
| success = True | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), | |||
| (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) | |||
| @@ -338,12 +339,12 @@ class Adam(Optimizer): | |||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps), | |||
| lr, gradients, params, moment1, moment2, self.ps_parameters) | |||
| lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) | |||
| else: | |||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.use_locking, self.use_nesterov, self._is_device, | |||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), | |||
| gradients, params, moment1, moment2, self.ps_parameters) | |||
| gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) | |||
| return success | |||
| @Optimizer.target.setter | |||
| @@ -24,14 +24,14 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") | |||
| @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | |||
| "RowTensor", "Tensor", "Tensor", "Bool") | |||
| "RowTensor", "Tensor", "Tensor", "Bool", "Bool") | |||
| def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | |||
| gradient, weight, moment, ps_parameter): | |||
| gradient, weight, moment, ps_parameter, cache_enable): | |||
| """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" | |||
| success = True | |||
| indices = gradient.indices | |||
| values = gradient.values | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) | |||
| success = F.depend(success, pull(push((values, indices), shapes), weight)) | |||
| @@ -41,12 +41,12 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, le | |||
| @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | |||
| "Tensor", "Tensor", "Tensor", "Bool") | |||
| "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||
| def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | |||
| gradient, weight, moment, ps_parameter): | |||
| gradient, weight, moment, ps_parameter, cache_enable): | |||
| """Apply ftrl optimizer to the weight parameter.""" | |||
| success = True | |||
| if ps_parameter: | |||
| if ps_parameter and not cache_enable: | |||
| op_shape = P.Shape() | |||
| success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), | |||
| (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) | |||
| @@ -185,7 +185,7 @@ class FTRL(Optimizer): | |||
| success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | |||
| self.l1, self.l2, self.lr_power, lr), | |||
| linear, grads, params, moments, self.ps_parameters) | |||
| linear, grads, params, moments, self.ps_parameters, self.cache_enable) | |||
| return success | |||
| @Optimizer.target.setter | |||
| @@ -156,6 +156,8 @@ class Optimizer(Cell): | |||
| break | |||
| ps_filter = lambda x: x.is_param_ps | |||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | |||
| ps_cache_filter = lambda x: x.cache_enable | |||
| self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters) | |||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | |||
| self.need_scale = loss_scale != 1.0 | |||
| self.global_step_increase_tensor = Tensor(1, mstype.int32) | |||
| @@ -117,3 +117,21 @@ def _is_role_pserver(): | |||
| def _is_role_sched(): | |||
| return ps_context().is_role_sched() | |||
| def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size): | |||
| ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size) | |||
| def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size): | |||
| ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size) | |||
| def _insert_weight_init_info(name, global_seed, op_seed): | |||
| ps_context().insert_weight_init_info(name, global_seed, op_seed) | |||
| def _insert_accumu_init_info(name, init_val): | |||
| ps_context().insert_accumu_init_info(name, init_val) | |||
| def _clone_hash_table(dest_param_name, src_param_name): | |||
| ps_context().clone_hash_table(dest_param_name, src_param_name) | |||
| def _set_cache_enable(cache_enable): | |||
| ps_context().set_cache_enable(cache_enable) | |||
| @@ -92,6 +92,10 @@ def connect_network_with_dataset(network, dataset_helper): | |||
| if isinstance(dataset_iter, _DatasetIterNormal): | |||
| raise RuntimeError("Dataset should be connected with network only in sink mode.") | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| return network | |||
| if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | |||
| and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | |||
| and context.get_context("device_target") == "Ascend" \ | |||
| @@ -166,14 +170,14 @@ class DatasetHelper: | |||
| iterclass = _DatasetIterGE | |||
| else: | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| if context.get_context("device_target") == "Ascend": | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| iterclass = _DatasetIterPSServer | |||
| elif ms_role == "MS_WORKER": | |||
| iterclass = _DatasetIterPSWork | |||
| elif (context.get_context("device_target") == "Ascend") or \ | |||
| (context.get_context("device_target") == "GPU"): | |||
| iterclass = _DatasetIterMSLoopSink | |||
| elif context.get_context("device_target") == "GPU": | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| iterclass = _DatasetIterPSLite | |||
| else: | |||
| iterclass = _DatasetIterMSLoopSink | |||
| elif context.get_context("device_target") == "CPU": | |||
| raise RuntimeError( | |||
| "Currently dataset sink mode is not supported when the device target is CPU.") | |||
| @@ -218,7 +222,10 @@ class _DatasetIter: | |||
| if not hasattr(dataset, '__transfer_dataset__'): | |||
| if hasattr(dataset, '__loop_size__'): | |||
| self.sink_size = dataset.__loop_size__ | |||
| ms_role = os.getenv("MS_ROLE") | |||
| # PS mode does not support loop sink and need get the real sink size. | |||
| if ms_role != "MS_WORKER": | |||
| self.sink_size = dataset.__loop_size__ | |||
| create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context( | |||
| "device_target") == "Ascend") | |||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, | |||
| @@ -260,8 +267,12 @@ class _DatasetIter: | |||
| def get_sink_size(self): | |||
| """get sink_size to device""" | |||
| sink_size = 1 | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if hasattr(self.dataset, '__loop_size__'): | |||
| sink_size = self.dataset.__loop_size__ | |||
| elif ms_role == "MS_WORKER": | |||
| # PS mode does not support loop sink. | |||
| sink_size = 1 | |||
| else: | |||
| if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \ | |||
| or context.get_context("device_target") == "GPU": | |||
| @@ -311,9 +322,6 @@ class _DatasetIterMSLoopSink(_DatasetIter): | |||
| def __init__(self, dataset, sink_size, epoch_num): | |||
| super().__init__(dataset, sink_size, epoch_num) | |||
| self.sink_count = self.get_sink_count(dataset) | |||
| ms_role = os.getenv("MS_ROLE") | |||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||
| self.sink_count = 1 | |||
| # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, | |||
| # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for | |||
| # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. | |||
| @@ -341,8 +349,8 @@ class _DatasetIterMS(_DatasetIter): | |||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | |||
| class _DatasetIterPSLite(_DatasetIter): | |||
| """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" | |||
| class _DatasetIterPSServer(_DatasetIter): | |||
| """Iter for context on MS_PSERVER or MS_SCHED""" | |||
| def __init__(self, dataset, sink_size, epoch_num): | |||
| super().__init__(dataset, sink_size, epoch_num) | |||
| @@ -355,6 +363,20 @@ class _DatasetIterPSLite(_DatasetIter): | |||
| self.op = op | |||
| class _DatasetIterPSWork(_DatasetIter): | |||
| """Iter for context on MS_WORKER""" | |||
| def __init__(self, dataset, sink_size, epoch_num): | |||
| super().__init__(dataset, sink_size, epoch_num) | |||
| if sink_size > 0: | |||
| self.sink_count = sink_size | |||
| else: | |||
| self.sink_count = dataset.get_dataset_size() | |||
| def op(): | |||
| return tuple() | |||
| self.op = op | |||
| class _DatasetIterNormal: | |||
| """Iter for normal(non sink) mode, feed the data from host.""" | |||
| @@ -30,6 +30,7 @@ def argparse_init(): | |||
| parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") | |||
| parser.add_argument("--field_size", type=int, default=39, help="The number of features.") | |||
| parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.") | |||
| parser.add_argument("--vocab_cache_size", type=int, default=0, help="The total features of hash table.") | |||
| parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.") | |||
| parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128], | |||
| help="The dimension of all deep layers.") | |||
| @@ -66,6 +67,7 @@ class WideDeepConfig(): | |||
| self.eval_batch_size = 16000 | |||
| self.field_size = 39 | |||
| self.vocab_size = 200000 | |||
| self.vocab_cache_size = 100000 | |||
| self.emb_dim = 80 | |||
| self.deep_layer_dim = [1024, 512, 256, 128] | |||
| self.deep_layer_act = 'relu' | |||
| @@ -103,6 +105,7 @@ class WideDeepConfig(): | |||
| self.eval_batch_size = args.eval_batch_size | |||
| self.field_size = args.field_size | |||
| self.vocab_size = args.vocab_size | |||
| self.vocab_cache_size = args.vocab_cache_size | |||
| self.emb_dim = args.emb_dim | |||
| self.deep_layer_dim = args.deep_layer_dim | |||
| self.deep_layer_act = args.deep_layer_act | |||
| @@ -147,6 +147,7 @@ class WideDeepModel(nn.Cell): | |||
| sparse = config.sparse | |||
| self.field_size = config.field_size | |||
| self.vocab_size = config.vocab_size | |||
| self.vocab_cache_size = config.vocab_cache_size | |||
| self.emb_dim = config.emb_dim | |||
| self.deep_layer_dims_list = config.deep_layer_dim | |||
| self.deep_layer_act = config.deep_layer_act | |||
| @@ -237,8 +238,20 @@ class WideDeepModel(nn.Cell): | |||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| elif parameter_server: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) | |||
| cache_enable = self.vocab_cache_size > 0 | |||
| target = 'DEVICE' if cache_enable else 'CPU' | |||
| if is_auto_parallel and config.full_batch and cache_enable: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, | |||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, | |||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, | |||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||
| else: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, sparse=sparse, | |||
| vocab_cache_size=self.vocab_cache_size) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| self.deep_embeddinglookup.embedding_table.set_param_ps() | |||
| self.wide_embeddinglookup.embedding_table.set_param_ps() | |||
| @@ -344,7 +357,7 @@ class TrainStepWrap(nn.Cell): | |||
| """ | |||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, | |||
| sparse=False): | |||
| sparse=False, cache_enable=False): | |||
| super(TrainStepWrap, self).__init__() | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| @@ -361,7 +374,7 @@ class TrainStepWrap(nn.Cell): | |||
| self.weights_w = ParameterTuple(weights_w) | |||
| self.weights_d = ParameterTuple(weights_d) | |||
| if (sparse and is_auto_parallel) or parameter_server: | |||
| if (sparse and is_auto_parallel) or (parameter_server and not cache_enable): | |||
| self.optimizer_d = LazyAdam( | |||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | |||
| self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | |||
| @@ -417,10 +430,17 @@ class TrainStepWrap(nn.Cell): | |||
| class PredictWithSigmoid(nn.Cell): | |||
| """ | |||
| Predict definition | |||
| """ | |||
| def __init__(self, network): | |||
| super(PredictWithSigmoid, self).__init__() | |||
| self.network = network | |||
| self.sigmoid = P.Sigmoid() | |||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| if is_auto_parallel: | |||
| self.sigmoid.shard(((1, 1),)) | |||
| def construct(self, batch_ids, batch_wts, labels): | |||
| logits, _, = self.network(batch_ids, batch_wts) | |||
| @@ -39,7 +39,8 @@ def get_WideDeep_net(config): | |||
| """ | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server)) | |||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), | |||
| cache_enable=bool(config.vocab_cache_size > 0)) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| @@ -81,6 +82,7 @@ def train_and_eval(config): | |||
| else: | |||
| dataset_type = DataType.H5 | |||
| parameter_server = bool(config.parameter_server) | |||
| cache_enable = bool(config.vocab_cache_size > 0) | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=batch_size, rank_id=get_rank(), | |||
| @@ -111,7 +113,7 @@ def train_and_eval(config): | |||
| callback_list.append(ckpoint_cb) | |||
| model.train(epochs, ds_train, | |||
| callbacks=callback_list, | |||
| dataset_sink_mode=(not parameter_server)) | |||
| dataset_sink_mode=(parameter_server and cache_enable)) | |||
| if __name__ == "__main__": | |||