| @@ -194,6 +194,14 @@ if (ENABLE_GPU) | |||||
| ) | ) | ||||
| endif () | endif () | ||||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||||
| install( | |||||
| TARGETS ps_cache | |||||
| DESTINATION ${INSTALL_LIB_DIR} | |||||
| COMPONENT mindspore | |||||
| ) | |||||
| endif() | |||||
| if (ENABLE_SERVING OR ENABLE_TESTCASES) | if (ENABLE_SERVING OR ENABLE_TESTCASES) | ||||
| file(GLOB_RECURSE LIBEVENT_LIB_LIST | file(GLOB_RECURSE LIBEVENT_LIB_LIST | ||||
| ${libevent_LIBPATH}/libevent* | ${libevent_LIBPATH}/libevent* | ||||
| @@ -308,7 +308,7 @@ elseif (CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||||
| else () | else () | ||||
| if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | if (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | ||||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | ||||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core) | |||||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||||
| if (${ENABLE_IBVERBS} STREQUAL "ON") | if (${ENABLE_IBVERBS} STREQUAL "ON") | ||||
| target_link_libraries(mindspore ibverbs rdmacm) | target_link_libraries(mindspore ibverbs rdmacm) | ||||
| endif() | endif() | ||||
| @@ -75,6 +75,9 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs | |||||
| if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | ||||
| node_so_ = CUST_AICPU_OPS_SO_NAME; | node_so_ = CUST_AICPU_OPS_SO_NAME; | ||||
| node_name_ = kCustRunApi; | node_name_ = kCustRunApi; | ||||
| } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { | |||||
| node_so_ = AICPU_OPS_SO_NAME; | |||||
| node_name_ = kCustRunApi; | |||||
| } else { | } else { | ||||
| node_so_ = AICPU_OPS_SO_NAME; | node_so_ = AICPU_OPS_SO_NAME; | ||||
| } | } | ||||
| @@ -161,6 +164,9 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr> | |||||
| if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | if (kCustAiCpuKernelOps.find(node_name_) != kCustAiCpuKernelOps.end()) { | ||||
| node_so_ = CUST_AICPU_OPS_SO_NAME; | node_so_ = CUST_AICPU_OPS_SO_NAME; | ||||
| node_name_ = kCustRunApi; | node_name_ = kCustRunApi; | ||||
| } else if (kCacheKernelOps.find(node_name_) != kCacheKernelOps.end()) { | |||||
| node_so_ = AICPU_OPS_SO_NAME; | |||||
| node_name_ = kCustRunApi; | |||||
| } else { | } else { | ||||
| node_so_ = AICPU_OPS_SO_NAME; | node_so_ = AICPU_OPS_SO_NAME; | ||||
| } | } | ||||
| @@ -49,6 +49,7 @@ constexpr auto kIdentity = "Identity"; | |||||
| constexpr auto kUpdateCache = "UpdateCache"; | constexpr auto kUpdateCache = "UpdateCache"; | ||||
| constexpr auto kCustRunApi = "RunCpuKernel"; | constexpr auto kCustRunApi = "RunCpuKernel"; | ||||
| const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | const std::set<std::string> kCustAiCpuKernelOps{kEditDistance, kIdentity}; | ||||
| const std::set<std::string> kCacheKernelOps{kUpdateCache}; | |||||
| struct AicpuParamHead { | struct AicpuParamHead { | ||||
| uint32_t length; // Total length: include custom message | uint32_t length; // Total length: include custom message | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" | #include "backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include <algorithm> | |||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -38,10 +39,13 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey); | key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey); | ||||
| } | } | ||||
| std::vector<size_t> keys{key_, key_, key_}; | std::vector<size_t> keys{key_, key_, key_}; | ||||
| std::vector<size_t> values; | |||||
| values.insert(values.end(), input_shape.begin(), input_shape.end()); | |||||
| values.insert(values.end(), indices_shape.begin(), indices_shape.end()); | |||||
| values.insert(values.end(), output_shape.begin(), output_shape.end()); | |||||
| std::vector<float> values; | |||||
| std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(values), | |||||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||||
| std::transform(indices_shape.begin(), indices_shape.end(), std::back_inserter(values), | |||||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||||
| std::transform(output_shape.begin(), output_shape.end(), std::back_inserter(values), | |||||
| [](size_t dim) -> float { return SizeToFloat(dim); }); | |||||
| MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape | MS_LOG(INFO) << "Init embedding lookup proxy kernel, input shape:" << input_shape | ||||
| << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | ||||
| std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), | std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), | ||||
| @@ -72,6 +72,23 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, con | |||||
| return Launch(inputs, workspace, outputs); | return Launch(inputs, workspace, outputs); | ||||
| } | } | ||||
| void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, | |||||
| const float *update_vals, size_t ids_size) { | |||||
| size_t copy_lens = outer_dim_size_ * sizeof(float); | |||||
| for (size_t i = 0; i < ids_size; ++i) { | |||||
| int index = lookup_ids[i] - offset_; | |||||
| if (index >= 0 && index < SizeToInt(first_dim_size_)) { | |||||
| auto ret = | |||||
| memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||||
| } | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "UpdateEmbeddings index invalid."; | |||||
| } | |||||
| } | |||||
| } | |||||
| const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } | const std::vector<size_t> &EmbeddingLookUpPSKernel::input_sizes() const { return input_shape_; } | ||||
| const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } | const std::vector<size_t> &EmbeddingLookUpPSKernel::output_sizes() const { return GetOutputSizeList(); } | ||||
| @@ -35,7 +35,8 @@ class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerK | |||||
| bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, | |||||
| size_t ids_size) override; | |||||
| const std::vector<size_t> &input_sizes() const override; | const std::vector<size_t> &input_sizes() const override; | ||||
| const std::vector<size_t> &output_sizes() const override; | const std::vector<size_t> &output_sizes() const override; | ||||
| const std::vector<size_t> &workspace_sizes() const override; | const std::vector<size_t> &workspace_sizes() const override; | ||||
| @@ -38,7 +38,8 @@ class PServerKernel { | |||||
| virtual void ReInit(const std::vector<std::vector<size_t>> &) {} | virtual void ReInit(const std::vector<std::vector<size_t>> &) {} | ||||
| virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) = 0; | const std::vector<AddressPtr> &outputs) = 0; | ||||
| virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, | |||||
| size_t ids_size) {} | |||||
| virtual const std::vector<size_t> &input_sizes() const = 0; | virtual const std::vector<size_t> &input_sizes() const = 0; | ||||
| virtual const std::vector<size_t> &output_sizes() const = 0; | virtual const std::vector<size_t> &output_sizes() const = 0; | ||||
| virtual const std::vector<size_t> &workspace_sizes() const = 0; | virtual const std::vector<size_t> &workspace_sizes() const = 0; | ||||
| @@ -56,6 +56,7 @@ | |||||
| #include "toolchain/adx_datadump_server.h" | #include "toolchain/adx_datadump_server.h" | ||||
| #if ENABLE_CPU && ENABLE_D | #if ENABLE_CPU && ENABLE_D | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -487,11 +488,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) { | |||||
| // adjust kernel | // adjust kernel | ||||
| AdjustKernel(root_graph); | AdjustKernel(root_graph); | ||||
| #if ENABLE_CPU && ENABLE_D | #if ENABLE_CPU && ENABLE_D | ||||
| if (ps::Util::IsParamServerMode()) { | |||||
| CheckPSModeConsistence(root_graph); | |||||
| // Assign parameter keys. | |||||
| AssignParamKey(root_graph); | |||||
| } | |||||
| InitPsWorker(root_graph); | |||||
| #endif | #endif | ||||
| // assign stream | // assign stream | ||||
| AssignStream(NOT_NULL(root_graph)); | AssignStream(NOT_NULL(root_graph)); | ||||
| @@ -568,6 +565,9 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) { | |||||
| } | } | ||||
| // adjust execution order because merge child graph and other special operations | // adjust execution order because merge child graph and other special operations | ||||
| AdjustKernel(graph); | AdjustKernel(graph); | ||||
| #if ENABLE_CPU && ENABLE_D | |||||
| InitPsWorker(graph); | |||||
| #endif | |||||
| // Reorder optimizer order | // Reorder optimizer order | ||||
| auto execution_order = graph->execution_order(); | auto execution_order = graph->execution_order(); | ||||
| Reorder(&execution_order); | Reorder(&execution_order); | ||||
| @@ -644,6 +644,10 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tens | |||||
| #if ENABLE_CPU && ENABLE_D | #if ENABLE_CPU && ENABLE_D | ||||
| // Initialize parameter server | // Initialize parameter server | ||||
| InitPSParamAndOptim(kernel_graph, inputs); | InitPSParamAndOptim(kernel_graph, inputs); | ||||
| std::string channel_name; | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) { | |||||
| ps::ps_cache_instance.IncreaseGraphStep(channel_name); | |||||
| } | |||||
| #endif | #endif | ||||
| { | { | ||||
| // run task on device | // run task on device | ||||
| @@ -21,6 +21,9 @@ | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/scoped_long_running.h" | #include "utils/scoped_long_running.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace session { | namespace session { | ||||
| @@ -67,6 +67,7 @@ | |||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #if ENABLE_CPU && ENABLE_GPU | #if ENABLE_CPU && ENABLE_GPU | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -243,6 +244,12 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | |||||
| auto input_node = input_nodes[i]; | auto input_node = input_nodes[i]; | ||||
| MS_EXCEPTION_IF_NULL(input_node); | MS_EXCEPTION_IF_NULL(input_node); | ||||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { | if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { | ||||
| #if ENABLE_CPU && ENABLE_GPU | |||||
| const std::string &param_name = input_node->fullname_with_scope(); | |||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| continue; | |||||
| } | |||||
| #endif | |||||
| auto pk_node = input_node->cast<ParameterPtr>(); | auto pk_node = input_node->cast<ParameterPtr>(); | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | ||||
| auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()); | ||||
| @@ -306,16 +313,11 @@ GraphId GPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||||
| HardwareOptimize(graph); | HardwareOptimize(graph); | ||||
| // Graph kernel fusion optimization | // Graph kernel fusion optimization | ||||
| GraphKernelOptimize(graph); | GraphKernelOptimize(graph); | ||||
| #if ENABLE_CPU && ENABLE_GPU | |||||
| if (ps::Util::IsParamServerMode()) { | |||||
| CheckPSModeConsistence(graph); | |||||
| // Assign parameter keys. | |||||
| AssignParamKey(graph); | |||||
| } | |||||
| #endif | |||||
| // Start gpu kernel runtime | // Start gpu kernel runtime | ||||
| StartKernelRT(); | StartKernelRT(); | ||||
| #if ENABLE_CPU && ENABLE_GPU | |||||
| InitPsWorker(graph); | |||||
| #endif | |||||
| // Assign CUDA streams | // Assign CUDA streams | ||||
| AssignStream(graph); | AssignStream(graph); | ||||
| // Dump .pb graph before remove nop nodes | // Dump .pb graph before remove nop nodes | ||||
| @@ -380,6 +382,12 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor: | |||||
| int kernel_num = kernel_graph->execution_order().size(); | int kernel_num = kernel_graph->execution_order().size(); | ||||
| int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1; | int64_t loopsize = (kernel_num > 1) ? ConfigManager::GetInstance().gpu_loopsink_size() : 1; | ||||
| for (int64_t i = 0; i < loopsize; i++) { | for (int64_t i = 0; i < loopsize; i++) { | ||||
| #if ENABLE_CPU && ENABLE_GPU | |||||
| std::string channel_name; | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable() && IsGetNextGraph(graph_id, &channel_name)) { | |||||
| ps::ps_cache_instance.IncreaseGraphStep(channel_name); | |||||
| } | |||||
| #endif | |||||
| Execute(kernel_graph); | Execute(kernel_graph); | ||||
| } | } | ||||
| // In pynative mode, device addresses of tensors in value nodes need be clean. | // In pynative mode, device addresses of tensors in value nodes need be clean. | ||||
| @@ -41,8 +41,10 @@ | |||||
| #include "utils/trace_base.h" | #include "utils/trace_base.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| #include "ps/worker.h" | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #include "ps/common.h" | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "abstract/abstract_value.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -1125,6 +1127,12 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); | ||||
| } | } | ||||
| if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| const std::string &param_name = input_node->fullname_with_scope(); | |||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| continue; | |||||
| } | |||||
| #endif | |||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); | ||||
| MS_EXCEPTION_IF_NULL(device_address); | MS_EXCEPTION_IF_NULL(device_address); | ||||
| if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size, | if (size != 0 && !device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), size, | ||||
| @@ -1715,8 +1723,64 @@ void SessionBasic::CleanUselessTensorsImpl(const std::shared_ptr<std::vector<ten | |||||
| } | } | ||||
| } | } | ||||
| bool SessionBasic::IsGetNextGraph(const GraphId &graph_id, std::string *channel_name) { | |||||
| auto kernel_graph = graphs_[graph_id]; | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| for (const auto &kernel_node : kernel_graph->execution_order()) { | |||||
| auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); | |||||
| if (kernel_name == kGetNextOpName) { | |||||
| auto prim = AnfAlgo::GetCNodePrimitive(kernel_node); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| *channel_name = GetValue<std::string>(prim->GetAttr("shared_name")); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | |||||
| if (!ps::Util::IsRoleOfWorker()) { | |||||
| return; | |||||
| } | |||||
| CheckPSModeConsistence(kernel_graph); | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| if (!ps::ps_cache_instance.initialized_ps_cache()) { | |||||
| auto context_ptr = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||||
| auto devcie_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET); | |||||
| auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(devcie_target, device_id_); | |||||
| MS_EXCEPTION_IF_NULL(runtime_instance); | |||||
| auto context = runtime_instance->context(); | |||||
| const auto &kernels = kernel_graph->execution_order(); | |||||
| if (kernels.size() > 0 && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") { | |||||
| GetBatchElements(kernels[0]); | |||||
| ps::ps_cache_instance.Initialize(); | |||||
| } | |||||
| ps::ps_cache_instance.DoProcessData(device_id_, context); | |||||
| } | |||||
| } else { | |||||
| // Assign parameter keys. | |||||
| AssignParamKey(kernel_graph); | |||||
| } | |||||
| } | |||||
| void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const { | |||||
| auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes"); | |||||
| auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types"); | |||||
| if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) { | |||||
| MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size " | |||||
| << types; | |||||
| } | |||||
| size_t batch_elements = 1; | |||||
| const auto &shape = shapes[0]; | |||||
| for (size_t i = 0; i < shape.size(); ++i) { | |||||
| batch_elements *= shape[i]; | |||||
| } | |||||
| ps::ps_cache_instance.set_batch_elements(batch_elements); | |||||
| } | |||||
| void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const { | |||||
| auto input_nodes = kernel_graph->inputs(); | auto input_nodes = kernel_graph->inputs(); | ||||
| for (const auto &input_node : input_nodes) { | for (const auto &input_node : input_nodes) { | ||||
| if (!input_node->isa<Parameter>()) { | if (!input_node->isa<Parameter>()) { | ||||
| @@ -1725,8 +1789,9 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||||
| auto pk_node = input_node->cast<ParameterPtr>(); | auto pk_node = input_node->cast<ParameterPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(pk_node); | MS_EXCEPTION_IF_NULL(pk_node); | ||||
| auto param_info_ptr = pk_node->param_info(); | auto param_info_ptr = pk_node->param_info(); | ||||
| if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) { | |||||
| const std::string &param_name = pk_node->fullname_with_scope(); | |||||
| const std::string &param_name = pk_node->fullname_with_scope(); | |||||
| if (param_info_ptr != nullptr && param_info_ptr->init_in_server() && | |||||
| !ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name | MS_LOG(EXCEPTION) << "Can not initialize the parameter[" << param_name | ||||
| << "] in server, this parameter is used by kernel which executes in device"; | << "] in server, this parameter is used by kernel which executes in device"; | ||||
| } | } | ||||
| @@ -1734,10 +1799,6 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) { | |||||
| } | } | ||||
| void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | ||||
| if (!ps::Util::IsRoleOfWorker()) { | |||||
| MS_LOG(INFO) << "Not parameter server mode."; | |||||
| return; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| @@ -1775,16 +1836,8 @@ void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, | |||||
| return; | return; | ||||
| } | } | ||||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | std::vector<tensor::TensorPtr> inputs(inputs_const); | ||||
| size_t input_ctrl_size = 1; | |||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| if (kernel_graph->input_ctrl_tensors()) { | |||||
| input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); | |||||
| } | |||||
| auto input_nodes = kernel_graph->inputs(); | auto input_nodes = kernel_graph->inputs(); | ||||
| if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) { | |||||
| MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | |||||
| << ", input_ctrl_size:" << input_ctrl_size; | |||||
| } | |||||
| auto ms_context = MsContext::GetInstance(); | auto ms_context = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(ms_context); | MS_EXCEPTION_IF_NULL(ms_context); | ||||
| for (size_t i = 0; i < inputs.size(); ++i) { | for (size_t i = 0; i < inputs.size(); ++i) { | ||||
| @@ -99,9 +99,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| // get graph id in child graphs by ME front anf node pointer | // get graph id in child graphs by ME front anf node pointer | ||||
| virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const; | virtual GraphId GetGraphIdByNode(const AnfNodePtr &) const; | ||||
| virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } | ||||
| void CheckPSModeConsistence(const KernelGraphPtr &Kernel_graph); | |||||
| void AssignParamKey(const KernelGraphPtr &kernel_graph); | void AssignParamKey(const KernelGraphPtr &kernel_graph); | ||||
| void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const); | void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const); | ||||
| bool IsGetNextGraph(const GraphId &graph_id, std::string *channel_name); | |||||
| virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, | ||||
| std::string *error_msg) const { | std::string *error_msg) const { | ||||
| return true; | return true; | ||||
| @@ -195,6 +195,11 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); | AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list); | ||||
| void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph); | void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph); | ||||
| void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs); | void UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| void CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const; | |||||
| void GetBatchElements(const AnfNodePtr &kernel_node) const; | |||||
| void InitPsWorker(const KernelGraphPtr &kernel_graph); | |||||
| #endif | |||||
| std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_; | ||||
| std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; | ||||
| @@ -207,6 +212,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> { | |||||
| #if !defined(_WIN32) && !defined(_WIN64) | #if !defined(_WIN32) && !defined(_WIN64) | ||||
| std::shared_ptr<Debugger> debugger_; | std::shared_ptr<Debugger> debugger_; | ||||
| #endif | #endif | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| bool initialized_ps_cache_{false}; | |||||
| #endif | |||||
| }; | }; | ||||
| using SessionPtr = std::shared_ptr<session::SessionBasic>; | using SessionPtr = std::shared_ptr<session::SessionBasic>; | ||||
| @@ -24,6 +24,9 @@ | |||||
| #include "frontend/parallel/device_matrix.h" | #include "frontend/parallel/device_matrix.h" | ||||
| #include "frontend/parallel/graph_util/generate_graph.h" | #include "frontend/parallel/graph_util/generate_graph.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| @@ -514,6 +517,12 @@ Status GatherV2PInfo::InferBias() { | |||||
| if (repeated_calc_num_ > 1) { | if (repeated_calc_num_ > 1) { | ||||
| rank = rank / repeated_calc_num_; | rank = rank / repeated_calc_num_; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| bias_ = 0; | |||||
| return SUCCESS; | |||||
| } | |||||
| #endif | |||||
| bias_ = rank / params_strategy.at(1) * slice_size_; | bias_ = rank / params_strategy.at(1) * slice_size_; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -46,10 +46,18 @@ | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/tensor.h" | #include "ir/tensor.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/util.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace parallel { | namespace parallel { | ||||
| bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | |||||
| return false; | |||||
| } | |||||
| #endif | |||||
| MS_EXCEPTION_IF_NULL(root); | MS_EXCEPTION_IF_NULL(root); | ||||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | ||||
| std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | std::string parallel_mode = ParallelContext::GetInstance()->parallel_mode(); | ||||
| @@ -44,6 +44,9 @@ | |||||
| #include "utils/comm_manager.h" | #include "utils/comm_manager.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/util.h" | |||||
| #endif | |||||
| using mindspore::tensor::Tensor; | using mindspore::tensor::Tensor; | ||||
| @@ -3036,6 +3039,11 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) { | |||||
| } | } | ||||
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | |||||
| return false; | |||||
| } | |||||
| #endif | |||||
| MS_EXCEPTION_IF_NULL(root); | MS_EXCEPTION_IF_NULL(root); | ||||
| MS_EXCEPTION_IF_NULL(optimizer); | MS_EXCEPTION_IF_NULL(optimizer); | ||||
| MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); | ||||
| @@ -202,6 +202,7 @@ else () | |||||
| if (${ENABLE_IBVERBS} STREQUAL "ON") | if (${ENABLE_IBVERBS} STREQUAL "ON") | ||||
| target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | ||||
| endif () | endif () | ||||
| target_link_libraries(_c_dataengine PRIVATE ps_cache) | |||||
| endif () | endif () | ||||
| endif () | endif () | ||||
| @@ -322,6 +322,7 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||||
| bool profiling, int32_t *push_time) { | bool profiling, int32_t *push_time) { | ||||
| std::vector<device::DataItemGpu> items; | std::vector<device::DataItemGpu> items; | ||||
| double start_time; | double start_time; | ||||
| bool ps_data_prefetch = false; | |||||
| for (int i = 0; i < data_size.size(); i++) { | for (int i = 0; i < data_size.size(); i++) { | ||||
| device::DataItemGpu data_item; | device::DataItemGpu data_item; | ||||
| data_item.data_len_ = data_size[i]; | data_item.data_len_ = data_size[i]; | ||||
| @@ -334,6 +335,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con | |||||
| if (profiling) { | if (profiling) { | ||||
| start_time = ProfilingTime::GetCurMilliSecond(); | start_time = ProfilingTime::GetCurMilliSecond(); | ||||
| } | } | ||||
| // Data prefetch only when PS mode enables cache. | |||||
| if ((!ps_data_prefetch) && (items.size() > 0)) { | |||||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); | |||||
| ps_data_prefetch = true; | |||||
| } | |||||
| BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | ||||
| if (profiling) { | if (profiling) { | ||||
| double end_time = ProfilingTime::GetCurMilliSecond(); | double end_time = ProfilingTime::GetCurMilliSecond(); | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "minddata/dataset/engine/datasetops/pipeline_op.h" | #include "minddata/dataset/engine/datasetops/pipeline_op.h" | ||||
| #include "minddata/dataset/engine/datasetops/repeat_op.h" | #include "minddata/dataset/engine/datasetops/repeat_op.h" | ||||
| #include "minddata/dataset/util/status.h" | #include "minddata/dataset/util/status.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #ifdef ENABLE_TDTQUE | #ifdef ENABLE_TDTQUE | ||||
| #include "minddata/dataset/util/queue.h" | #include "minddata/dataset/util/queue.h" | ||||
| @@ -17,6 +17,8 @@ | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "minddata/dataset/engine/perf/profiling.h" | #include "minddata/dataset/engine/perf/profiling.h" | ||||
| #include "minddata/dataset/util/log_adapter.h" | #include "minddata/dataset/util/log_adapter.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace dataset { | namespace dataset { | ||||
| static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr; | static std::shared_ptr<TdtPlugin> instance_ptr_ = nullptr; | ||||
| @@ -48,6 +50,10 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe | |||||
| if (profiling) { | if (profiling) { | ||||
| start_time = ProfilingTime::GetCurMilliSecond(); | start_time = ProfilingTime::GetCurMilliSecond(); | ||||
| } | } | ||||
| // Data prefetch only when PS mode enables cache. | |||||
| if (items.size() > 0) { | |||||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); | |||||
| } | |||||
| if (tdt::TdtHostPushData(channel_name, items) != 0) { | if (tdt::TdtHostPushData(channel_name, items) != 0) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -308,7 +308,14 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.") | .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.") | ||||
| .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.") | .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.") | ||||
| .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.") | .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.") | ||||
| .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id."); | |||||
| .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.") | |||||
| .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.") | |||||
| .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize, | |||||
| "Insert hash table size with new parameter name.") | |||||
| .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.") | |||||
| .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.") | |||||
| .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.") | |||||
| .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not."); | |||||
| (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy") | ||||
| .def(py::init()) | .def(py::init()) | ||||
| @@ -52,6 +52,7 @@ | |||||
| #include "ps/common.h" | #include "ps/common.h" | ||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #endif | #endif | ||||
| #if (ENABLE_GE || ENABLE_D) | #if (ENABLE_GE || ENABLE_D) | ||||
| @@ -921,6 +922,11 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba | |||||
| bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, | bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, | ||||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | ||||
| const std::vector<int64_t> &input_indexes, bool need_run) { | const std::vector<int64_t> &input_indexes, bool need_run) { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) { | |||||
| return true; | |||||
| } | |||||
| #endif | |||||
| MS_LOG(INFO) << "Start InitDataSet Entry"; | MS_LOG(INFO) << "Start InitDataSet Entry"; | ||||
| ShapeVector int_input_indexes; | ShapeVector int_input_indexes; | ||||
| (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), | (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), | ||||
| @@ -966,7 +972,17 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| backend->Link(runner.graph_id); | backend->Link(runner.graph_id); | ||||
| } | } | ||||
| ConfigManager::GetInstance().set_iter_num(size); | |||||
| // PS mode does not support loop sink. | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if (ps::Util::IsRoleOfWorker()) { | |||||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | |||||
| ConfigManager::GetInstance().set_iter_num(1); | |||||
| } else { | |||||
| #endif | |||||
| ConfigManager::GetInstance().set_iter_num(size); | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| } | |||||
| #endif | |||||
| if (!(*runner.run)) { | if (!(*runner.run)) { | ||||
| // empty function | // empty function | ||||
| @@ -981,7 +997,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "InitDataSetVm End."; | MS_LOG(DEBUG) << "InitDataSetVm End."; | ||||
| return true; | return true; | ||||
| } | |||||
| } // namespace pipeline | |||||
| void ResetOpId() { mindspore::id_generator::reset_id(); } | void ResetOpId() { mindspore::id_generator::reset_id(); } | ||||
| @@ -14,22 +14,20 @@ if (NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/cluster_config.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | list(REMOVE_ITEM _PS_SRC_FILES "core/node_manager.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||||
| endif () | endif () | ||||
| if (NOT ENABLE_D) | if (NOT ENABLE_D) | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ascend/ascend_ps_cache.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||||
| endif() | endif() | ||||
| if (NOT ENABLE_GPU) | if (NOT ENABLE_GPU) | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_cache_manager.cc") | |||||
| endif() | endif() | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") | ||||
| list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") | ||||
| add_subdirectory(ps_cache) | add_subdirectory(ps_cache) | ||||
| set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | set_property(SOURCE ${_PS_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_PS) | ||||
| add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | add_library(_mindspore_ps_obj OBJECT ${_PS_SRC_FILES}) | ||||
| @@ -64,6 +64,7 @@ constexpr int64_t kInitWeightToOptimIdCmd = 11; | |||||
| constexpr int64_t kInitOptimInputsShapeCmd = 12; | constexpr int64_t kInitOptimInputsShapeCmd = 12; | ||||
| constexpr int64_t kInitKeyToPushNodeIdCmd = 13; | constexpr int64_t kInitKeyToPushNodeIdCmd = 13; | ||||
| constexpr int64_t kInitEmbeddingsCmd = 20; | constexpr int64_t kInitEmbeddingsCmd = 20; | ||||
| constexpr int64_t kUpdateEmbeddingsCmd = 21; | |||||
| constexpr int64_t kCheckReadyForPushCmd = 25; | constexpr int64_t kCheckReadyForPushCmd = 25; | ||||
| constexpr int64_t kCheckReadyForPullCmd = 26; | constexpr int64_t kCheckReadyForPullCmd = 26; | ||||
| constexpr int64_t kEmbeddingLookupCmd = 30; | constexpr int64_t kEmbeddingLookupCmd = 30; | ||||
| @@ -51,6 +51,8 @@ | |||||
| #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" | #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" | ||||
| #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" | #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" | ||||
| #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" | #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #include "ps/random_normal/random_normal.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -100,6 +102,7 @@ class ParameterServer { | |||||
| void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | ||||
| void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | ||||
| void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | ||||
| void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | |||||
| void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res); | ||||
| ParameterServer *ps_; | ParameterServer *ps_; | ||||
| @@ -118,13 +121,15 @@ class ParameterServer { | |||||
| void InitWeight(const Key &key, const WeightPtr &weight); | void InitWeight(const Key &key, const WeightPtr &weight); | ||||
| void InitGrad(const Key &key, const GradPtr &grad); | void InitGrad(const Key &key, const GradPtr &grad); | ||||
| void InitEmbeddingTable(const Key &key, | void InitEmbeddingTable(const Key &key, | ||||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes); | |||||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, | |||||
| const ParamInitInfo ¶m_init_info); | |||||
| bool HasWeight(const Key &key); | bool HasWeight(const Key &key); | ||||
| void Finalize(); | void Finalize(); | ||||
| void UpdateWeights(); | void UpdateWeights(); | ||||
| void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); | void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); | ||||
| WeightPtr weight(const Key &key); | WeightPtr weight(const Key &key); | ||||
| void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res); | void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res); | ||||
| void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); | |||||
| bool ReadyForUpdateWeights(); | bool ReadyForUpdateWeights(); | ||||
| bool ReadyForPush(const Key &key); | bool ReadyForPush(const Key &key); | ||||
| bool ReadyForPull(const Key &key); | bool ReadyForPull(const Key &key); | ||||
| @@ -193,6 +198,7 @@ void ParameterServer<T>::ServerHandler::Init() { | |||||
| handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; | handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush; | ||||
| handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; | handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull; | ||||
| handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; | handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup; | ||||
| handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings; | |||||
| handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; | handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize; | ||||
| } | } | ||||
| @@ -302,7 +308,17 @@ void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta | |||||
| for (int64_t k = 0; k < lens[2]; k++) { | for (int64_t k = 0; k < lens[2]; k++) { | ||||
| output_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | output_shape->push_back(static_cast<size_t>(req_data.vals[index++])); | ||||
| } | } | ||||
| ps_->InitEmbeddingTable(key, shapes); | |||||
| ParamInitInfo param_init_info; | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| param_init_info.param_type_ = static_cast<ParamType>(lens[3]); | |||||
| if (param_init_info.param_type_ == kWeight) { | |||||
| param_init_info.global_seed_ = static_cast<size_t>(lens[4]); | |||||
| param_init_info.op_seed_ = static_cast<size_t>(lens[5]); | |||||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||||
| param_init_info.init_val_ = req_data.vals[index]; | |||||
| } | |||||
| } | |||||
| ps_->InitEmbeddingTable(key, shapes, param_init_info); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| @@ -338,6 +354,18 @@ void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta | |||||
| ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); | ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); | ||||
| } | } | ||||
| template <typename T> | |||||
| void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, | |||||
| const ::ps::KVPairs<T> &req_data, | |||||
| ::ps::KVPairs<T> *res) { | |||||
| std::unique_lock<std::mutex> lock(ps_->mutex()); | |||||
| MS_EXCEPTION_IF_NULL(res); | |||||
| const Key &key = req_data.keys[0]; | |||||
| const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size()); | |||||
| const Values &update_vals = req_data.vals; | |||||
| ps_->UpdateEmbeddings(key, lookup_ids, update_vals); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, | void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, | ||||
| ::ps::KVPairs<T> *res) { | ::ps::KVPairs<T> *res) { | ||||
| @@ -476,7 +504,8 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) { | |||||
| template <typename T> | template <typename T> | ||||
| void ParameterServer<T>::InitEmbeddingTable( | void ParameterServer<T>::InitEmbeddingTable( | ||||
| const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) { | |||||
| const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, | |||||
| const ParamInitInfo ¶m_init_info) { | |||||
| MS_EXCEPTION_IF_NULL(shapes); | MS_EXCEPTION_IF_NULL(shapes); | ||||
| if (weights_.count(key) == 0) { | if (weights_.count(key) == 0) { | ||||
| std::shared_ptr<PServerKernel> lookup = | std::shared_ptr<PServerKernel> lookup = | ||||
| @@ -493,8 +522,18 @@ void ParameterServer<T>::InitEmbeddingTable( | |||||
| T *embedding_data = embedding->data(); | T *embedding_data = embedding->data(); | ||||
| std::default_random_engine engine; | std::default_random_engine engine; | ||||
| std::normal_distribution<float> random(0, 0.01); | std::normal_distribution<float> random(0, 0.01); | ||||
| for (size_t i = 0; i < total_dims; i++) { | |||||
| embedding_data[i] = random(engine); | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| if (param_init_info.param_type_ == kWeight) { | |||||
| InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data); | |||||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||||
| for (size_t i = 0; i < total_dims; i++) { | |||||
| embedding_data[i] = param_init_info.init_val_; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| for (size_t i = 0; i < total_dims; i++) { | |||||
| embedding_data[i] = random(engine); | |||||
| } | |||||
| } | } | ||||
| weights_[key] = embedding; | weights_[key] = embedding; | ||||
| tokens_[key] = 0; | tokens_[key] = 0; | ||||
| @@ -673,6 +712,23 @@ void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, | |||||
| res->lens.push_back(res->vals.size()); | res->lens.push_back(res->vals.size()); | ||||
| } | } | ||||
| template <typename T> | |||||
| void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) { | |||||
| if (weights_.count(key) == 0) { | |||||
| MS_LOG(ERROR) << "Invalid embedding table key " << key; | |||||
| return; | |||||
| } | |||||
| if (embedding_lookup_ops_.count(key) == 0) { | |||||
| MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; | |||||
| return; | |||||
| } | |||||
| WeightPtr table_ptr = weights_[key]; | |||||
| MS_EXCEPTION_IF_NULL(table_ptr); | |||||
| std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key]; | |||||
| MS_EXCEPTION_IF_NULL(table_lookup_op); | |||||
| table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size()); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| inline bool ParameterServer<T>::ReadyForUpdateWeights() { | inline bool ParameterServer<T>::ReadyForUpdateWeights() { | ||||
| return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); | return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); | ||||
| @@ -70,9 +70,16 @@ void PsCacheManager::InsertWeightInitInfo(const std::string ¶m_name, size_t | |||||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | ||||
| } | } | ||||
| auto &hash_table_info = iter->second; | auto &hash_table_info = iter->second; | ||||
| if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { | |||||
| return; | |||||
| } | |||||
| hash_table_info.param_init_info_.param_type_ = kWeight; | hash_table_info.param_init_info_.param_type_ = kWeight; | ||||
| hash_table_info.param_init_info_.global_seed_ = global_seed; | hash_table_info.param_init_info_.global_seed_ = global_seed; | ||||
| hash_table_info.param_init_info_.op_seed_ = op_seed; | hash_table_info.param_init_info_.op_seed_ = op_seed; | ||||
| if (CheckFinishInsertInitInfo()) { | |||||
| finish_insert_init_info_ = true; | |||||
| insert_init_info_.notify_one(); | |||||
| } | |||||
| } | } | ||||
| void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) { | void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) { | ||||
| @@ -81,8 +88,26 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string ¶m_name, float i | |||||
| MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | MS_LOG(EXCEPTION) << "Can not find parameter[" << param_name << "] in hash table."; | ||||
| } | } | ||||
| auto &hash_table_info = iter->second; | auto &hash_table_info = iter->second; | ||||
| if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { | |||||
| return; | |||||
| } | |||||
| hash_table_info.param_init_info_.param_type_ = kAccumulation; | hash_table_info.param_init_info_.param_type_ = kAccumulation; | ||||
| hash_table_info.param_init_info_.init_val_ = init_val; | hash_table_info.param_init_info_.init_val_ = init_val; | ||||
| if (CheckFinishInsertInitInfo()) { | |||||
| finish_insert_init_info_ = true; | |||||
| insert_init_info_.notify_one(); | |||||
| } | |||||
| } | |||||
| bool PsCacheManager::CheckFinishInsertInitInfo() const { | |||||
| for (const auto &item : hash_tables_) { | |||||
| const auto &hash_table_info = item.second; | |||||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | |||||
| if (param_init_info.param_type_ == kUnKnown) { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) { | void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) { | ||||
| @@ -113,35 +138,49 @@ void PsCacheManager::Initialize() { | |||||
| } | } | ||||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_); | embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, cache_vocab_size_); | ||||
| embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_); | embedding_host_cache_ = std::make_shared<EmbeddingHostCache>(batch_elements_, host_cache_vocab_size_); | ||||
| InitParameterServer(); | |||||
| AddEmbeddingTable(); | |||||
| AllocMemForHashTable(); | AllocMemForHashTable(); | ||||
| SetLocalIdRank(); | SetLocalIdRank(); | ||||
| initialized_ps_cache_ = true; | initialized_ps_cache_ = true; | ||||
| } | } | ||||
| void PsCacheManager::InitParameterServer() { | |||||
| void PsCacheManager::AddEmbeddingTable() const { | |||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| const auto ¶m_name = item.first; | const auto ¶m_name = item.first; | ||||
| size_t key = worker.SetParamKey(param_name); | size_t key = worker.SetParamKey(param_name); | ||||
| size_t row_count = item.second.vocab_size; | size_t row_count = item.second.vocab_size; | ||||
| std::vector<size_t> keys{key, key, key, key}; | |||||
| // if worker role | |||||
| worker.AddEmbeddingTable(key, row_count); | |||||
| } | |||||
| } | |||||
| void PsCacheManager::InitParameterServer() { | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); | |||||
| for (const auto &item : hash_tables_) { | |||||
| const auto ¶m_name = item.first; | |||||
| size_t key = worker.SetParamKey(param_name); | |||||
| std::vector<size_t> keys{key, key, key, key, key, key}; | |||||
| std::vector<float> values{ | std::vector<float> values{ | ||||
| SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; | SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; | ||||
| std::vector<int64_t> lens{2, 2, 3}; | std::vector<int64_t> lens{2, 2, 3}; | ||||
| const auto &hash_table_info = item.second; | const auto &hash_table_info = item.second; | ||||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | const auto ¶m_init_info = hash_table_info.param_init_info_; | ||||
| if (param_init_info.param_type_ == kWeight) { | if (param_init_info.param_type_ == kWeight) { | ||||
| lens.push_back(0); | |||||
| values.push_back(SizeToFloat(param_init_info.global_seed_)); | |||||
| values.push_back(SizeToFloat(param_init_info.op_seed_)); | |||||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||||
| lens.push_back(1); | lens.push_back(1); | ||||
| values.push_back(param_init_info.init_val_); | |||||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||||
| lens.push_back(2); | |||||
| } | } | ||||
| values.push_back(param_init_info.init_val_); | |||||
| lens.push_back(param_init_info.global_seed_); | |||||
| lens.push_back(param_init_info.op_seed_); | |||||
| // if worker role | // if worker role | ||||
| worker.AddEmbeddingTable(key, row_count); | |||||
| worker.InitPSEmbeddingTable(keys, values, lens); | worker.InitPSEmbeddingTable(keys, values, lens); | ||||
| } | } | ||||
| finish_init_parameter_server_ = true; | |||||
| data_prase_.notify_one(); | |||||
| } | } | ||||
| void PsCacheManager::AllocMemForHashTable() { | void PsCacheManager::AllocMemForHashTable() { | ||||
| @@ -208,10 +247,538 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||||
| if (graph_step_ >= UINT64_MAX) { | if (graph_step_ >= UINT64_MAX) { | ||||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | ||||
| } | } | ||||
| if (graph_step_ == 0) { | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); | |||||
| } | |||||
| graph_step_++; | graph_step_++; | ||||
| set_channel_name(channel_name); | set_channel_name(channel_name); | ||||
| PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | ||||
| data_prase_.notify_one(); | data_prase_.notify_one(); | ||||
| } | } | ||||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||||
| if (!initialized_ps_cache_) { | |||||
| MS_LOG(EXCEPTION) << "PS cache does not init."; | |||||
| } | |||||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||||
| process_data_thread.detach(); | |||||
| } | |||||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | |||||
| InitParameterServer(); | |||||
| while (true) { | |||||
| ProcessData(); | |||||
| } | |||||
| } | |||||
| void PsCacheManager::ProcessData() { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| struct timeval start_time, end_time; | |||||
| const uint64_t kUSecondInSecond = 1000000; | |||||
| (void)gettimeofday(&start_time, nullptr); | |||||
| auto channel = channel_name(); | |||||
| if (channel.empty()) { | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); | |||||
| } | |||||
| auto data = PsDataPrefetch::GetInstance().data(channel_name_); | |||||
| if (data == nullptr) { | |||||
| MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); | |||||
| return; | |||||
| } | |||||
| IncreaseStep(); | |||||
| auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); | |||||
| auto batch_ids = reinterpret_cast<int *>(data); | |||||
| auto batch_ids_len = data_size / sizeof(int); | |||||
| std::unique_ptr<int[]> hash_index(new int[batch_ids_len]); | |||||
| if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { | |||||
| MS_LOG(EXCEPTION) << "Process data memset failed."; | |||||
| } | |||||
| // Get hash swap in/out index and ids. | |||||
| ParseData(batch_ids, batch_ids_len, hash_index.get()); | |||||
| for (const auto &item : hash_tables_) { | |||||
| auto key = worker.GetParamKey(item.first); | |||||
| auto hash_info = item.second; | |||||
| HashSwapHostToServer(key, hash_info); | |||||
| HashSwapDeviceToHost(hash_info); | |||||
| HashSwapServerToHost(key, hash_info); | |||||
| HashSwapHostToDevice(hash_info); | |||||
| } | |||||
| // Replace the batch_ids by hash index for getNext-op getting hash index as input. | |||||
| if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Process data memcpy failed."; | |||||
| } | |||||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||||
| // Finish the data process and notify data prefetch. | |||||
| PsDataPrefetch::GetInstance().FinalizeData(channel_name_); | |||||
| (void)gettimeofday(&end_time, nullptr); | |||||
| uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | |||||
| cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | |||||
| MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ | |||||
| << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ | |||||
| << ", time cost:" << cost / 1000 << "ms)."; | |||||
| } | |||||
| void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||||
| MS_EXCEPTION_IF_NULL(batch_ids); | |||||
| MS_EXCEPTION_IF_NULL(hash_index); | |||||
| for (size_t i = 0; i < batch_ids_len; i++) { | |||||
| bool need_swap_host_to_device = true; | |||||
| bool need_swap_device_to_host = true; | |||||
| auto id = batch_ids[i]; | |||||
| if ((id < SizeToInt(range_bound_.first)) || (id >= SizeToInt(range_bound_.second))) { | |||||
| hash_index[i] = -1; | |||||
| continue; | |||||
| } | |||||
| hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||||
| if (need_swap_host_to_device) { | |||||
| ParseHostDataHostToDevice(id); | |||||
| } | |||||
| if (need_swap_device_to_host) { | |||||
| ParseHostDataDeviceToHost(id); | |||||
| } | |||||
| } | |||||
| // Each 1000 step prints ps cache hit rate. | |||||
| if (data_step_ % 1000 == 0) { | |||||
| statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; | |||||
| auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; | |||||
| MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; | |||||
| } | |||||
| } | |||||
| void PsCacheManager::WaitGraphRun() { | |||||
| MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; | |||||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||||
| if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { | |||||
| MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||||
| << ", graph running step:" << graph_running_step_ << ")."; | |||||
| } | |||||
| set_current_graph_step(); | |||||
| } | |||||
| int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { | |||||
| MS_EXCEPTION_IF_NULL(need_swap_device_to_host); | |||||
| MS_EXCEPTION_IF_NULL(need_swap_host_to_device); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||||
| int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||||
| int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_ids); | |||||
| auto device_hash_map = embedding_device_cache_->device_hash_map_; | |||||
| MS_EXCEPTION_IF_NULL(device_hash_map); | |||||
| int index = 0; | |||||
| auto iter = device_hash_map->id_iter(id); | |||||
| if (device_hash_map->IsIdExist(iter)) { | |||||
| *need_swap_device_to_host = false; | |||||
| *need_swap_host_to_device = false; | |||||
| index = iter->second; | |||||
| if (device_hash_map->hash_step(index) != data_step_) { | |||||
| statistics_info_.hash_hit_count_++; | |||||
| device_hash_map->set_hash_step(index, data_step_); | |||||
| } | |||||
| } else { | |||||
| auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; | |||||
| while (true) { | |||||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | |||||
| &(statistics_info_.device_to_host_size_)); | |||||
| if (index == INVALID_INDEX_VALUE) { | |||||
| WaitGraphRun(); | |||||
| continue; | |||||
| } | |||||
| host_to_device_index[statistics_info_.host_to_device_size_] = index; | |||||
| host_to_device_ids[statistics_info_.host_to_device_size_] = id; | |||||
| statistics_info_.host_to_device_size_++; | |||||
| *need_swap_device_to_host = statistics_info_.device_to_host_size_ > tmp_device_to_host_size; | |||||
| break; | |||||
| } | |||||
| } | |||||
| return index; | |||||
| } | |||||
| void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||||
| int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||||
| int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||||
| int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||||
| MS_EXCEPTION_IF_NULL(server_to_host_index); | |||||
| MS_EXCEPTION_IF_NULL(server_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||||
| auto iter = host_hash_map->id_iter(id); | |||||
| if (host_hash_map->IsIdExist(iter)) { | |||||
| auto index = iter->second; | |||||
| if (host_hash_map->hash_step(index) != data_step_) { | |||||
| host_hash_map->set_hash_step(index, data_step_); | |||||
| } | |||||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | |||||
| } else { | |||||
| while (true) { | |||||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||||
| if (index == INVALID_INDEX_VALUE) { | |||||
| WaitGraphRun(); | |||||
| continue; | |||||
| } | |||||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | |||||
| server_to_host_index[statistics_info_.server_to_host_size_] = index; | |||||
| server_to_host_ids[statistics_info_.server_to_host_size_++] = id; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||||
| int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||||
| int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; | |||||
| auto iter = host_hash_map->id_iter(swap_device_to_host_id); | |||||
| if (host_hash_map->IsIdExist(iter)) { | |||||
| auto index = iter->second; | |||||
| if (host_hash_map->hash_step(index) != data_step_) { | |||||
| host_hash_map->set_hash_step(index, data_step_); | |||||
| } | |||||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | |||||
| } else { | |||||
| while (true) { | |||||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||||
| if (index == INVALID_INDEX_VALUE) { | |||||
| WaitGraphRun(); | |||||
| continue; | |||||
| } | |||||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, | |||||
| const float *input_addr, const int *indices_addr, float *output_addr) { | |||||
| auto type_size = sizeof(float); | |||||
| size_t lens = outer_dim_size * type_size; | |||||
| for (size_t i = 0; i < indices_lens; ++i) { | |||||
| int index = indices_addr[i]; | |||||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | |||||
| size_t pos = index * outer_dim_size; | |||||
| auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||||
| } | |||||
| } else { | |||||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | |||||
| } | |||||
| } | |||||
| output_addr += outer_dim_size; | |||||
| } | |||||
| } | |||||
| void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||||
| const int *indices_addr, float *output_addr) { | |||||
| size_t first_dim_size = host_cache_vocab_size_; | |||||
| size_t outer_dim_size = embedding_size; | |||||
| size_t thread_num = indices_lens / 10000 + 1; | |||||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||||
| std::thread threads[kMaxThreadNum]; | |||||
| size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num; | |||||
| size_t i; | |||||
| size_t task_offset = 0; | |||||
| MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens; | |||||
| for (i = 0; i < thread_num; i++) { | |||||
| if (task_offset >= indices_lens) { | |||||
| break; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; | |||||
| threads[i] = std::thread(&PsCacheManager::LookUpTableTask, this, task_proc_lens, outer_dim_size, first_dim_size, | |||||
| hash_table_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size); | |||||
| task_offset += task_proc_lens; | |||||
| if (task_offset + task_proc_lens > indices_lens) { | |||||
| task_proc_lens = indices_lens - task_offset; | |||||
| } | |||||
| } | |||||
| for (size_t j = 0; j < i; j++) { | |||||
| threads[j].join(); | |||||
| } | |||||
| } | |||||
| void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||||
| float *insert_data, float *hash_table_addr) { | |||||
| size_t first_dim_size = host_cache_vocab_size_; | |||||
| size_t thread_num = insert_indices_size / 10000 + 1; | |||||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||||
| std::thread threads[kMaxThreadNum]; | |||||
| size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num; | |||||
| size_t i; | |||||
| size_t task_offset = 0; | |||||
| auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||||
| auto type_size = sizeof(float); | |||||
| size_t lens = outer_dim_size * type_size; | |||||
| for (size_t i = 0; i < insert_indices_size; ++i) { | |||||
| int index = insert_indices[i]; | |||||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | |||||
| auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; | |||||
| } | |||||
| } | |||||
| } | |||||
| }; | |||||
| for (i = 0; i < thread_num; i++) { | |||||
| if (task_offset >= insert_indices_size) { | |||||
| break; | |||||
| } | |||||
| MS_LOG(DEBUG) << "Task offset: " << task_offset << ", task process lens:" << task_proc_lens; | |||||
| threads[i] = std::thread(insert_hash_table_task, task_proc_lens, embedding_size, first_dim_size, | |||||
| insert_indices + task_offset, insert_data + task_offset * embedding_size, hash_table_addr); | |||||
| task_offset += task_proc_lens; | |||||
| if (task_offset + task_proc_lens > insert_indices_size) { | |||||
| task_proc_lens = insert_indices_size - task_offset; | |||||
| } | |||||
| } | |||||
| for (size_t j = 0; j < i; j++) { | |||||
| threads[j].join(); | |||||
| } | |||||
| } | |||||
| void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||||
| auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||||
| auto swap_indices_size = statistics_info_.host_to_device_size_; | |||||
| if (swap_indices_size == 0) { | |||||
| return; | |||||
| } | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||||
| auto hash_table_size = hash_info.device_address.size; | |||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, | |||||
| swap_out_data.get()); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_out_data.get(), | |||||
| swap_indices_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_indices_size); | |||||
| } | |||||
| void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| auto swap_indices_size = statistics_info_.device_to_host_size_; | |||||
| auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||||
| auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||||
| if (swap_indices_size == 0) { | |||||
| return; | |||||
| } | |||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||||
| auto hash_table_size = hash_info.device_address.size; | |||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_indices_size); | |||||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), | |||||
| embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_indices_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||||
| swap_out_data.get(), host_hash_table_addr); | |||||
| } | |||||
| void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||||
| auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||||
| auto swap_indices_size = statistics_info_.host_to_server_size_; | |||||
| if (swap_indices_size == 0) { | |||||
| return; | |||||
| } | |||||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||||
| ::ps::SArray<float> swap_out_data; | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| swap_out_data.resize(swap_indices_size * embedding_size); | |||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||||
| swap_out_data.data()); | |||||
| auto copy_len = swap_indices_size * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| } | |||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||||
| } | |||||
| void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| auto swap_indices_size = statistics_info_.server_to_host_size_; | |||||
| auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||||
| auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||||
| if (swap_indices_size == 0) { | |||||
| return; | |||||
| } | |||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| ::ps::SArray<int> lengths{swap_indices_size}; | |||||
| ::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0); | |||||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||||
| auto copy_len = swap_indices_size * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| } | |||||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), | |||||
| host_hash_table_addr); | |||||
| } | |||||
| void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||||
| const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_data); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| auto swap_out_index_size = statistics_info_.device_to_host_size_; | |||||
| if (swap_out_index_size == 0) { | |||||
| return; | |||||
| } | |||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||||
| auto hash_table_size = hash_info.device_address.size; | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| swap_out_data->resize(swap_out_index_size * embedding_size); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, | |||||
| swap_out_index_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_out_index_size); | |||||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), | |||||
| embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_out_index_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->RecordEvent(); | |||||
| } | |||||
| void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||||
| size_t key) { | |||||
| MS_EXCEPTION_IF_NULL(swap_in_ids); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_index); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| auto swap_in_ids_size = statistics_info_.host_to_device_size_; | |||||
| if (swap_in_ids_size == 0) { | |||||
| return; | |||||
| } | |||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||||
| auto hash_table_size = hash_info.device_address.size; | |||||
| auto embedding_size = hash_info.embedding_size; | |||||
| // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). | |||||
| ::ps::SArray<int> lengths{swap_in_ids_size}; | |||||
| ::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0); | |||||
| ::ps::SArray<int> lookup_ids(swap_in_ids_size, 0); | |||||
| auto copy_len = swap_in_ids_size * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| } | |||||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||||
| // Hash swap-in in device. | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||||
| lookup_result.data(), | |||||
| swap_in_ids_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, | |||||
| swap_in_ids_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_in_ids_size); | |||||
| } | |||||
| void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||||
| auto swap_out_ids_size = statistics_info_.device_to_host_size_; | |||||
| if (swap_out_ids_size == 0) { | |||||
| return; | |||||
| } | |||||
| ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | |||||
| auto copy_len = swap_out_ids_size * sizeof(int); | |||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | |||||
| if (ret != EOK) { | |||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| } | |||||
| // Need synchronize event to ensure that the swap-out in device is completed. | |||||
| embedding_device_cache_->cache_->SynchronizeEvent(); | |||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||||
| } | |||||
| void PsCacheManager::DumpHashTables() const { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| for (const auto &item : hash_tables_) { | |||||
| const auto ¶m_name = item.first; | |||||
| size_t cache_vocab_size = item.second.cache_vocab_size; | |||||
| size_t embedding_size = item.second.embedding_size; | |||||
| size_t vocab_size = item.second.vocab_size; | |||||
| MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size | |||||
| << " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr) | |||||
| << " || " << reinterpret_cast<void *>(item.second.host_address.get()); | |||||
| float *output = new float[item.second.device_address.size / 4]; | |||||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, | |||||
| item.second.device_address.size); | |||||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||||
| for (size_t i = 0; i < cache_vocab_size; i++) { | |||||
| for (size_t j = 0; j < embedding_size; j++) { | |||||
| std::cout << output[i * embedding_size + j] << " "; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| } | |||||
| std::cout << std::endl; | |||||
| delete[] output; | |||||
| } | |||||
| } | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -49,6 +49,7 @@ struct HashTableInfo { | |||||
| size_t vocab_size{0}; | size_t vocab_size{0}; | ||||
| Address device_address{nullptr, 0}; | Address device_address{nullptr, 0}; | ||||
| std::shared_ptr<int[]> host_address{nullptr}; | std::shared_ptr<int[]> host_address{nullptr}; | ||||
| ParamInitInfo param_init_info_; | |||||
| }; | }; | ||||
| struct EmbeddingDeviceCache { | struct EmbeddingDeviceCache { | ||||
| @@ -158,6 +159,8 @@ class PsCacheManager { | |||||
| void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | ||||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | ||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| bool CheckFinishInsertInitInfo() const; | |||||
| void AddEmbeddingTable() const; | |||||
| bool initialized_ps_cache_{false}; | bool initialized_ps_cache_{false}; | ||||
| std::string channel_name_; | std::string channel_name_; | ||||
| @@ -167,6 +170,7 @@ class PsCacheManager { | |||||
| size_t data_step_{0}; | size_t data_step_{0}; | ||||
| std::mutex data_mutex_; | std::mutex data_mutex_; | ||||
| std::condition_variable data_prase_; | std::condition_variable data_prase_; | ||||
| std::condition_variable insert_init_info_; | |||||
| std::map<std::string, HashTableInfo> hash_tables_; | std::map<std::string, HashTableInfo> hash_tables_; | ||||
| std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | ||||
| @@ -178,6 +182,8 @@ class PsCacheManager { | |||||
| size_t batch_elements_{0}; | size_t batch_elements_{0}; | ||||
| PsCacheStatisticsInfo statistics_info_; | PsCacheStatisticsInfo statistics_info_; | ||||
| std::pair<size_t, size_t> range_bound_; | std::pair<size_t, size_t> range_bound_; | ||||
| std::atomic_bool finish_insert_init_info_{false}; | |||||
| std::atomic_bool finish_init_parameter_server_{false}; | |||||
| }; | }; | ||||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | ||||
| @@ -26,7 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| class PsDataPrefetch { | |||||
| class EXPORT PsDataPrefetch { | |||||
| public: | public: | ||||
| EXPORT static PsDataPrefetch &GetInstance() { | EXPORT static PsDataPrefetch &GetInstance() { | ||||
| static PsDataPrefetch instance; | static PsDataPrefetch instance; | ||||
| @@ -17,6 +17,11 @@ | |||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #include "backend/kernel_compiler/kernel.h" | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -80,5 +85,43 @@ bool PSContext::is_role_sched() const { return is_sched_; } | |||||
// Stores `rank_id` in rank_id_.
| void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } | void PSContext::SetPSRankId(int rank_id) { rank_id_ = rank_id; } | ||||
// Returns the value currently held in rank_id_.
| int PSContext::ps_rank_id() const { return rank_id_; } | int PSContext::ps_rank_id() const { return rank_id_; } | ||||
| void PSContext::InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||||
| size_t vocab_size) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size); | |||||
| #endif | |||||
| } | |||||
| void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||||
| size_t cache_vocab_size, size_t embedding_size) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size); | |||||
| #endif | |||||
| } | |||||
| void PSContext::InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed); | |||||
| #endif | |||||
| } | |||||
| void PSContext::InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps_cache_instance.InsertAccumuInitInfo(param_name, init_val); | |||||
| #endif | |||||
| } | |||||
| void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps_cache_instance.CloneHashTable(dest_param_name, src_param_name); | |||||
| #endif | |||||
| } | |||||
| void PSContext::set_cache_enable(bool cache_enable) const { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| PsDataPrefetch::GetInstance().set_cache_enable(cache_enable); | |||||
| #endif | |||||
| } | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -44,6 +44,14 @@ class PSContext { | |||||
| bool is_role_sched() const; | bool is_role_sched() const; | ||||
| void SetPSRankId(int rank_id); | void SetPSRankId(int rank_id); | ||||
| int ps_rank_id() const; | int ps_rank_id() const; | ||||
| void InsertHashTableSize(const std::string ¶m_name, size_t cache_vocab_size, size_t embedding_size, | |||||
| size_t vocab_size) const; | |||||
| void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name, | |||||
| size_t cache_vocab_size, size_t embedding_size) const; | |||||
| void InsertWeightInitInfo(const std::string ¶m_name, size_t global_seed, size_t op_seed) const; | |||||
| void InsertAccumuInitInfo(const std::string ¶m_name, float init_val) const; | |||||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const; | |||||
| void set_cache_enable(bool cache_enable) const; | |||||
| private: | private: | ||||
| PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} | PSContext() : ps_enabled_(false), is_worker_(false), is_pserver_(false), is_sched_(false), rank_id_(-1) {} | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ps/random_normal/random_normal.h" | |||||
| #include <iostream> | |||||
| #include <thread> | |||||
| #include <memory> | |||||
| #include "utils/convert_utils_base.h" | |||||
| #include "pybind_api/random_normal/random_cpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed, | |||||
| float *output_data) { | |||||
| if (out_shape.size() == 0) { | |||||
| std::cout << "output data shape is error" << std::endl; | |||||
| } | |||||
| int64_t total_count = 1; | |||||
| for (uint32_t i = 0; i < out_shape.size(); i++) { | |||||
| total_count *= SizeToLong(out_shape[i]); | |||||
| } | |||||
| uint32_t thread_num = 16; | |||||
| if (total_count <= thread_num) { | |||||
| thread_num = 1; | |||||
| } | |||||
| float *start_ptr = output_data; | |||||
| if (start_ptr == nullptr) { | |||||
| std::cout << "start_ptr is nullptr" << std::endl; | |||||
| return false; | |||||
| } | |||||
| int64_t batchSize = total_count / thread_num; | |||||
| std::vector<std::thread> threads(thread_num); | |||||
| int64_t seed = SizeToLong(global_seed); | |||||
| int64_t seed2 = SizeToLong(op_seed); | |||||
| seed = (seed == 0 && seed2 == 0) ? clock() : seed; | |||||
| PhiloxGenerator generator = PhiloxGenerator(seed, seed2); | |||||
| if (thread_num != 1) { | |||||
| for (uint32_t i = 0; i < thread_num - 1; i++) { | |||||
| float *offset_ptr = start_ptr + batchSize * i; | |||||
| threads[i] = | |||||
| std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, offset_ptr, batchSize, i); | |||||
| } | |||||
| float *offset_ptr = start_ptr + batchSize * (thread_num - 1); | |||||
| threads[thread_num - 1] = std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, | |||||
| offset_ptr, total_count - (thread_num - 1) * batchSize, thread_num - 1); | |||||
| } else { | |||||
| threads[0] = | |||||
| std::thread(FillRandoms<NormalDistribution<PhiloxGenerator, float>>, generator, start_ptr, total_count, 0); | |||||
| } | |||||
| for (uint32_t i = 0; i < thread_num; i++) { | |||||
| threads[i].join(); | |||||
| } | |||||
| for (int64_t i = 0; i < total_count; i++) { | |||||
| output_data[i] *= stddev; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||||
| #define MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||||
| #include <vector> | |||||
| namespace mindspore { | |||||
| namespace ps { | |||||
| bool InitRandomNormal(float mean, float stddev, std::vector<size_t> out_shape, size_t global_seed, size_t op_seed, | |||||
| float *output_data); | |||||
| } // namespace ps | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_PS_RANDOM_NORMAL_RANDOM_NORMAL_H_ | |||||
| @@ -26,6 +26,15 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
// Category of a parameter whose initialization is described by ParamInitInfo.
enum ParamType {
  kUnKnown = 0,
  kWeight = 1,
  kAccumulation = 2
};

// Initialization description of a single parameter. Every field has a neutral default
// (unknown type, zero seeds, zero initial value).
struct ParamInitInfo {
  ParamType param_type_ = kUnKnown;
  // Seed pair; presumably consumed by the random initializer -- confirm against the users of this struct.
  size_t global_seed_ = 0;
  size_t op_seed_ = 0;
  // Constant initial value.
  float init_val_ = 0;
};
| class Util { | class Util { | ||||
| public: | public: | ||||
| static bool IsParamServerMode(); | static bool IsParamServerMode(); | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include "ps/common.h" | #include "ps/common.h" | ||||
| #include "ps/worker_proxy.h" | #include "ps/worker_proxy.h" | ||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| @@ -47,15 +48,19 @@ class Worker { | |||||
| void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); | void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); | ||||
| void Pull(const size_t key, void *dev_addr, const size_t size); | void Pull(const size_t key, void *dev_addr, const size_t size); | ||||
| size_t SetParamKey(const std::string ¶m_name); | size_t SetParamKey(const std::string ¶m_name); | ||||
| size_t GetParamKey(const std::string ¶m_name); | |||||
| void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); | void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); | ||||
| bool GetParamInitInServer(const std::string ¶m_name); | bool GetParamInitInServer(const std::string ¶m_name); | ||||
| void SetKeyOptimId(size_t key, const std::string &optimizer_name); | void SetKeyOptimId(size_t key, const std::string &optimizer_name); | ||||
| void SetOptimInputShapes(size_t key, const ShapeVector &shape); | void SetOptimInputShapes(size_t key, const ShapeVector &shape); | ||||
| void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | ||||
| void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, const ShapeVector &sizes); | |||||
| void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes); | |||||
| void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); | void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); | ||||
| void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | ||||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd); | const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd); | ||||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||||
| const ::ps::SArray<T> &vals); | |||||
| bool running() { return running_; } | |||||
| void Finalize(); | void Finalize(); | ||||
| private: | private: | ||||
| @@ -65,7 +70,6 @@ class Worker { | |||||
| Worker &operator=(const Worker &) = delete; | Worker &operator=(const Worker &) = delete; | ||||
| bool IsKeyInit(const size_t key); | bool IsKeyInit(const size_t key); | ||||
| size_t GetParamKey(const std::string ¶m_name); | |||||
| void InitPSOptimId(const size_t param_key); | void InitPSOptimId(const size_t param_key); | ||||
| void InitPSOptimInputShapes(const size_t key); | void InitPSOptimInputShapes(const size_t key); | ||||
| void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | ||||
| @@ -187,6 +191,12 @@ void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const : | |||||
| kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); | kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); | ||||
| } | } | ||||
| template <typename T> | |||||
| void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||||
| const ::ps::SArray<T> &vals) { | |||||
| kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void Worker<T>::Finalize() { | void Worker<T>::Finalize() { | ||||
| if (running_) { | if (running_) { | ||||
| @@ -286,7 +296,7 @@ size_t Worker<T>::GetParamKey(const std::string ¶m_name) { | |||||
| size_t key = kInvalidKey; | size_t key = kInvalidKey; | ||||
| if (param_to_key_.find(param_name) != param_to_key_.end()) { | if (param_to_key_.find(param_name) != param_to_key_.end()) { | ||||
| key = param_to_key_[param_name]; | key = param_to_key_[param_name]; | ||||
| MS_LOG(INFO) << "Get key of parameter " << param_name << " key is " << key; | |||||
| MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key; | |||||
| } | } | ||||
| return key; | return key; | ||||
| } | } | ||||
| @@ -310,8 +320,7 @@ void Worker<T>::InitPSOptimId(const size_t param_key) { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<size_t> shapes, | |||||
| const ShapeVector &sizes) { | |||||
| void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) { | |||||
| bool has_init = IsKeyInit(keys[0]); | bool has_init = IsKeyInit(keys[0]); | ||||
| if (has_init) { | if (has_init) { | ||||
| MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; | MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; | ||||
| @@ -319,7 +328,7 @@ void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vecto | |||||
| } | } | ||||
| ::ps::SArray<T> shapes_val; | ::ps::SArray<T> shapes_val; | ||||
| for (auto dim : shapes) { | for (auto dim : shapes) { | ||||
| shapes_val.push_back(static_cast<T>(dim)); | |||||
| shapes_val.push_back(dim); | |||||
| } | } | ||||
| std::vector<int> sizes_int; | std::vector<int> sizes_int; | ||||
| (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), | (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), | ||||
| @@ -337,9 +346,6 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: | |||||
| const std::string ¶m_name = pk_node->fullname_with_scope(); | const std::string ¶m_name = pk_node->fullname_with_scope(); | ||||
| void *param_data = tensor->data_c(); | void *param_data = tensor->data_c(); | ||||
| size_t param_size = LongToSize(tensor->data().nbytes()); | size_t param_size = LongToSize(tensor->data().nbytes()); | ||||
| if (param_size > INT_MAX) { | |||||
| MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size; | |||||
| } | |||||
| size_t param_key = GetParamKey(param_name); | size_t param_key = GetParamKey(param_name); | ||||
| if (param_key == kInvalidKey) { | if (param_key == kInvalidKey) { | ||||
| @@ -357,11 +363,17 @@ void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor:: | |||||
| MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name | MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name | ||||
| << ", whether init in server: " << init_in_server; | << ", whether init in server: " << init_in_server; | ||||
| kv_worker_->AddKeyToServerId(param_key); | kv_worker_->AddKeyToServerId(param_key); | ||||
| if (!init_in_server) { | |||||
| InitPSParamData({param_key}, param_data, param_size); | |||||
| if (!PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| if (!init_in_server) { | |||||
| if (param_size > INT_MAX) { | |||||
| MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " | |||||
| << param_size; | |||||
| } | |||||
| InitPSParamData({param_key}, param_data, param_size); | |||||
| } | |||||
| InitPSOptimId(param_key); | |||||
| InitPSOptimInputShapes(param_key); | |||||
| } | } | ||||
| InitPSOptimId(param_key); | |||||
| InitPSOptimInputShapes(param_key); | |||||
| } | } | ||||
| } | } | ||||
| @@ -45,6 +45,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||||
| explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id) | explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id) | ||||
| : Worker(app_id, customer_id) { | : Worker(app_id, customer_id) { | ||||
| server_num_ = ::ps::NumServers(); | server_num_ = ::ps::NumServers(); | ||||
| MS_LOG(INFO) << "Server num:" << server_num_; | |||||
| PSContext::instance()->SetPSRankId(::ps::MyRank()); | PSContext::instance()->SetPSRankId(::ps::MyRank()); | ||||
| using std::placeholders::_1; | using std::placeholders::_1; | ||||
| using std::placeholders::_2; | using std::placeholders::_2; | ||||
| @@ -60,6 +61,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||||
| broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5); | broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5); | ||||
| round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5); | round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5); | ||||
| worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); | worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); | ||||
| update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5); | |||||
| } | } | ||||
| ~WorkerProxy() override = default; | ~WorkerProxy() override = default; | ||||
| @@ -70,6 +72,8 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||||
| const Callback &cb = nullptr, int64_t priority = 0); | const Callback &cb = nullptr, int64_t priority = 0); | ||||
| int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | ||||
| const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0); | const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0); | ||||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||||
| const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0); | |||||
| bool IsReadyForPush(const Key &key); | bool IsReadyForPush(const Key &key); | ||||
| bool IsReadyForPull(const Key &key); | bool IsReadyForPull(const Key &key); | ||||
| void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, | void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, | ||||
| @@ -98,6 +102,9 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||||
| void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | ||||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | ||||
| const std::map<int64_t, int64_t> &attrs); | const std::map<int64_t, int64_t> &attrs); | ||||
| void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||||
| const std::map<int64_t, int64_t> &attrs); | |||||
| void ProcessLookupResult(const ::ps::Message &msg); | void ProcessLookupResult(const ::ps::Message &msg); | ||||
| void ProcessResponse(const ::ps::Message &msg); | void ProcessResponse(const ::ps::Message &msg); | ||||
| void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs, | void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs, | ||||
| @@ -122,6 +129,7 @@ class WorkerProxy : public ::ps::KVWorker<T> { | |||||
| Slicer broadcast_slicer_; | Slicer broadcast_slicer_; | ||||
| Slicer round_robin_slicer_; | Slicer round_robin_slicer_; | ||||
| Slicer worker_init_embedding_slicer_; | Slicer worker_init_embedding_slicer_; | ||||
| Slicer update_embedding_slicer_; | |||||
| std::unordered_map<int64_t, Callback> lookup_callbacks_; | std::unordered_map<int64_t, Callback> lookup_callbacks_; | ||||
| std::unordered_map<int64_t, Callback> general_callbacks_; | std::unordered_map<int64_t, Callback> general_callbacks_; | ||||
| std::unordered_map<int64_t, int64_t> expected_result_count_; | std::unordered_map<int64_t, int64_t> expected_result_count_; | ||||
| @@ -195,6 +203,24 @@ int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, | |||||
| return ts; | return ts; | ||||
| } | } | ||||
| template <typename T> | |||||
| void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||||
| const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) { | |||||
| int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr); | |||||
| ::ps::KVPairs<T> kvs; | |||||
| kvs.keys = keys; | |||||
| kvs.lens = lookup_ids; | |||||
| kvs.vals = vals; | |||||
| kvs.priority = priority; | |||||
| expected_result_count_[ts] = 0; | |||||
| Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_); | |||||
| if (expected_result_count_[ts] < server_num_) { | |||||
| general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]); | |||||
| } | |||||
| general_customer_->WaitRequest(ts); | |||||
| expected_result_count_.erase(ts); | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| bool WorkerProxy<T>::IsReadyForPush(const Key &key) { | bool WorkerProxy<T>::IsReadyForPush(const Key &key) { | ||||
| ::ps::SArray<T> result(1, 0); | ::ps::SArray<T> result(1, 0); | ||||
| @@ -724,6 +750,47 @@ void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KV | |||||
| } | } | ||||
| } | } | ||||
| template <typename T> | |||||
| void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, | |||||
| const std::vector<::ps::Range> &, | |||||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||||
| const std::map<int64_t, int64_t> &attrs) { | |||||
| MS_EXCEPTION_IF_NULL(sliced); | |||||
| T *embedding_vals = send.vals.data(); | |||||
| int *lookup_ids = send.lens.data(); | |||||
| size_t val_size = send.vals.size(); | |||||
| size_t id_size = send.lens.size(); | |||||
| size_t embedding_dim = val_size / id_size; | |||||
| const Key &key = send.keys[0]; | |||||
| const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); | |||||
| sliced->resize(ranges.size()); | |||||
| for (size_t i = 0; i < ranges.size(); i++) { | |||||
| const ::ps::Range &range = ranges[i]; | |||||
| const auto &begin = range.begin(); | |||||
| const auto &end = range.end(); | |||||
| auto &kvs = sliced->at(i).second; | |||||
| kvs.keys.push_back(key); | |||||
| for (size_t j = 0; j < id_size; j++) { | |||||
| auto lookup_id = static_cast<uint64_t>(lookup_ids[j]); | |||||
| if (lookup_id >= begin && lookup_id <= end) { | |||||
| kvs.keys.push_back(lookup_id); | |||||
| for (size_t k = 0; k < embedding_dim; k++) { | |||||
| kvs.vals.push_back(embedding_vals[j * embedding_dim + k]); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (kvs.keys.size() <= 1) { | |||||
| sliced->at(i).first = false; | |||||
| } else { | |||||
| sliced->at(i).first = true; | |||||
| expected_result_count_[timestamp] += 1; | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) { | void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) { | ||||
| int64_t ts = msg.meta.timestamp; | int64_t ts = msg.meta.timestamp; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <thread> | #include <thread> | ||||
| #include <memory> | #include <memory> | ||||
| #include "runtime/device/cpu/cpu_device_address.h" | #include "runtime/device/cpu/cpu_device_address.h" | ||||
| #include "ir/tensor.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, | bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, | ||||
| @@ -19,8 +19,7 @@ | |||||
| #include "pybind_api/random_normal/philox_generator.h" | #include "pybind_api/random_normal/philox_generator.h" | ||||
| #include "pybind11/pybind11.h" | #include "pybind11/pybind11.h" | ||||
| #include "pybind_api/api_register.h" | #include "pybind_api/api_register.h" | ||||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||||
| #include "utils/log_adapter.h" | |||||
| namespace py = pybind11; | namespace py = pybind11; | ||||
| @@ -55,6 +55,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| bool SyncStream() override; | bool SyncStream() override; | ||||
| void SetContext() override; | void SetContext() override; | ||||
| void CreateContext() override; | void CreateContext() override; | ||||
| void *context() const override { return rt_context_; } | |||||
| protected: | protected: | ||||
| DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "runtime/device/gpu/gpu_memory_allocator.h" | #include "runtime/device/gpu/gpu_memory_allocator.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| namespace gpu { | namespace gpu { | ||||
| @@ -38,6 +39,9 @@ void GPUMemoryManager::MallocDeviceMemory() { | |||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| // If use the dynamic memory pool, then alloc the first memory block to init. | // If use the dynamic memory pool, then alloc the first memory block to init. | ||||
| if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { | if (context_ptr->get_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL)) { | ||||
| if (ps::ps_cache_instance.initialized_ps_cache()) { | |||||
| return; | |||||
| } | |||||
| auto device_addr = MallocMemFromMemPool(1); | auto device_addr = MallocMemFromMemPool(1); | ||||
| if (!device_addr) { | if (!device_addr) { | ||||
| MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; | MS_LOG(EXCEPTION) << "Dynamic memory pool init error."; | ||||
| @@ -30,6 +30,10 @@ | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "utils/shape_utils.h" | #include "utils/shape_utils.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | |||||
| using mindspore::kernel::Address; | using mindspore::kernel::Address; | ||||
| using mindspore::kernel::AddressPtr; | using mindspore::kernel::AddressPtr; | ||||
| @@ -331,15 +335,27 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; | ||||
| continue; | continue; | ||||
| } | } | ||||
| DeviceAddressPtr device_address = nullptr; | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| const std::string ¶m_name = item->fullname_with_scope(); | |||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); | |||||
| MS_EXCEPTION_IF_NULL(address.addr); | |||||
| device_address = | |||||
| CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||||
| continue; | |||||
| } | |||||
| #endif | |||||
| auto tensor_size = CountNodeDeviceMemorySize(item, index); | auto tensor_size = CountNodeDeviceMemorySize(item, index); | ||||
| auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||||
| device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); | |||||
| MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); | ||||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { | |||||
| if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address) == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; | ||||
| } | } | ||||
| MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope() | MS_LOG(INFO) << "Malloc Input for graph " << graph->graph_id() << ", node: " << item->fullname_with_scope() | ||||
| << " index: " << index << " size: " << tensor_size; | << " index: " << index << " size: " << tensor_size; | ||||
| AnfAlgo::SetOutputAddr(address, index, item.get()); | |||||
| AnfAlgo::SetOutputAddr(device_address, index, item.get()); | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "AssignStaticMemoryInput end"; | MS_LOG(INFO) << "AssignStaticMemoryInput end"; | ||||
| @@ -78,6 +78,7 @@ class KernelRuntime { | |||||
| virtual void ClearGlobalIdleMem() {} | virtual void ClearGlobalIdleMem() {} | ||||
| virtual void CreateContext() {} | virtual void CreateContext() {} | ||||
| virtual void SetContext() {} | virtual void SetContext() {} | ||||
| virtual void *context() const { return nullptr; } | |||||
| uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | ||||
| return mem_manager_->MallocMem(type, size, address); | return mem_manager_->MallocMem(type, size, address); | ||||
| } | } | ||||
| @@ -15,6 +15,7 @@ | |||||
| """Parameter for cell.""" | """Parameter for cell.""" | ||||
| from copy import copy | from copy import copy | ||||
| import numbers | |||||
| from .._c_expression import ParamInfo | from .._c_expression import ParamInfo | ||||
| from .._c_expression import MetaTensor as MetaTensor_ | from .._c_expression import MetaTensor as MetaTensor_ | ||||
| from . import dtype as mstype | from . import dtype as mstype | ||||
| @@ -23,7 +24,10 @@ from .tensor import Tensor, MetaTensor | |||||
| from .._checkparam import Validator | from .._checkparam import Validator | ||||
| from ..parallel._tensor import _get_slice_index | from ..parallel._tensor import _get_slice_index | ||||
| from ..parallel._auto_parallel_context import auto_parallel_context | from ..parallel._auto_parallel_context import auto_parallel_context | ||||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched | |||||
| from ..parallel._ps_context import _is_role_worker, _is_role_pserver, _is_role_sched, _clone_hash_table | |||||
| from ..parallel._ps_context import _reinsert_hash_table_size | |||||
| from ..parallel._ps_context import _insert_weight_init_info, _insert_accumu_init_info | |||||
| from .seed import _get_global_and_op_seed | |||||
| __all__ = ['Parameter', 'ParameterTuple'] | __all__ = ['Parameter', 'ParameterTuple'] | ||||
| @@ -35,6 +39,18 @@ def _is_in_parallel_mode(): | |||||
| """Get parallel mode.""" | """Get parallel mode.""" | ||||
| return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] | return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"] | ||||
def init_to_value(init):
    """
    Translate an initializer specification into its float value.

    Args:
        init (Union[str, numbers.Number]): Either the string 'zeros' or 'ones', or a number.

    Returns:
        float, 0.0 for 'zeros', 1.0 for 'ones', otherwise `float(init)`.

    Raises:
        ValueError: If `init` is a string other than 'zeros'/'ones', or is neither a string nor a number.
    """
    named_values = {'zeros': 0.0, 'ones': 1.0}
    if isinstance(init, str):
        if init not in named_values:
            raise ValueError("init should be one of values in 'zeros', 'ones'.")
        return named_values[init]
    if not isinstance(init, numbers.Number):
        raise ValueError("init should be number or string")
    return float(init)
| class Parameter(MetaTensor_): | class Parameter(MetaTensor_): | ||||
| """ | """ | ||||
| @@ -118,6 +134,8 @@ class Parameter(MetaTensor_): | |||||
| def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False): | def __init__(self, default_input, name=None, requires_grad=True, layerwise_parallel=False): | ||||
| self._param_info = ParamInfo() | self._param_info = ParamInfo() | ||||
| self.init_in_server = False | |||||
| self.cache_enable = False | |||||
| self.name = name | self.name = name | ||||
| self.requires_grad = requires_grad | self.requires_grad = requires_grad | ||||
| self.layerwise_parallel = layerwise_parallel | self.layerwise_parallel = layerwise_parallel | ||||
| @@ -129,7 +147,6 @@ class Parameter(MetaTensor_): | |||||
| self._sliced = False | self._sliced = False | ||||
| self.is_param_ps = False | self.is_param_ps = False | ||||
| self._cast_type = None | self._cast_type = None | ||||
| self.init_in_server = False | |||||
| self._unique = False | self._unique = False | ||||
| self.is_in_parallel = _is_in_parallel_mode() | self.is_in_parallel = _is_in_parallel_mode() | ||||
| if isinstance(default_input, (MetaTensor, Tensor)): | if isinstance(default_input, (MetaTensor, Tensor)): | ||||
| @@ -155,7 +172,7 @@ class Parameter(MetaTensor_): | |||||
| if isinstance(data, bool): | if isinstance(data, bool): | ||||
| raise ValueError('Parameter data can not be `bool`') | raise ValueError('Parameter data can not be `bool`') | ||||
| if isinstance(data, MetaTensor): | if isinstance(data, MetaTensor): | ||||
| if _is_in_parallel_mode() or _is_role_worker(): | |||||
| if _is_in_parallel_mode() or _is_role_worker() or _is_role_sched(): | |||||
| # do not init data while in auto parallel. | # do not init data while in auto parallel. | ||||
| return (MetaTensor_, data.dtype, data.shape) | return (MetaTensor_, data.dtype, data.shape) | ||||
| data = data.to_tensor() | data = data.to_tensor() | ||||
| @@ -189,18 +206,18 @@ class Parameter(MetaTensor_): | |||||
| init_in_server (bool): Whether trainable parameter updated by parameter server is | init_in_server (bool): Whether trainable parameter updated by parameter server is | ||||
| initialized on server. Default: False. | initialized on server. Default: False. | ||||
| """ | """ | ||||
| if _is_role_worker() or _is_role_pserver() or _is_role_sched(): | |||||
| if init_in_server and (not self.name.endswith("embedding_table")): | |||||
| raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " | |||||
| "sparse operator support initialization in server.".format(self.name)) | |||||
| self.is_param_ps = True | |||||
| self.init_in_server = init_in_server | |||||
| self._param_info.init_in_server = init_in_server | |||||
| else: | |||||
| if not(_is_role_worker() or _is_role_pserver() or _is_role_sched()): | |||||
| raise RuntimeError("Must complete following two steps before calling set_param_ps: \ | raise RuntimeError("Must complete following two steps before calling set_param_ps: \ | ||||
| 1. set_ps_context(enable_ps=True) \ | 1. set_ps_context(enable_ps=True) \ | ||||
| 2. export MS_ROLE environment variable.") | 2. export MS_ROLE environment variable.") | ||||
| if init_in_server and (not self.name.endswith("embedding_table")): | |||||
| raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of " | |||||
| "sparse operator support initialization in server.".format(self.name)) | |||||
| self.is_param_ps = True | |||||
| self.init_in_server = init_in_server | |||||
| self._param_info.init_in_server = init_in_server | |||||
| @property | @property | ||||
| def inited_param(self): | def inited_param(self): | ||||
| @@ -238,6 +255,13 @@ class Parameter(MetaTensor_): | |||||
| format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | format(name_, PARAMETER_NAME_PREFIX_MAX_LEN)) | ||||
| else: | else: | ||||
| raise ValueError("The type of the name should be `str` or `None`.") | raise ValueError("The type of the name should be `str` or `None`.") | ||||
| if _is_role_worker() and self.cache_enable: | |||||
| if len(self.shape) != 2: | |||||
| raise RuntimeError("The dims of parameter '{}' must be 2, but got {}." | |||||
| .format(self.name, len(self.shape))) | |||||
| _reinsert_hash_table_size(name_, self._param_info.name, self.shape[0], self.shape[1]) | |||||
| self._param_info.name = name_ | self._param_info.name = name_ | ||||
| @property | @property | ||||
| @@ -297,6 +321,7 @@ class Parameter(MetaTensor_): | |||||
| x.is_init = False | x.is_init = False | ||||
| x.is_param_ps = self.is_param_ps | x.is_param_ps = self.is_param_ps | ||||
| x.init_in_server = self.init_in_server | x.init_in_server = self.init_in_server | ||||
| x.cache_enable = self.cache_enable | |||||
| if init != 'same': | if init != 'same': | ||||
| shape = self.shape | shape = self.shape | ||||
| dtype = self.dtype | dtype = self.dtype | ||||
| @@ -431,15 +456,18 @@ class Parameter(MetaTensor_): | |||||
| raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) | raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout)) | ||||
| slice_index = int(_get_slice_index(layout[0], layout[1])) | slice_index = int(_get_slice_index(layout[0], layout[1])) | ||||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | ||||
| if _is_role_worker(): | |||||
| if _is_role_worker() or _is_role_sched(): | |||||
| data = self.init_mode.to_tensor(0, [1]) | data = self.init_mode.to_tensor(0, [1]) | ||||
| else: | else: | ||||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | ||||
| else: | else: | ||||
| data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | data = self.init_mode.to_tensor(slice_index, layout[2], layout[5]) | ||||
| else: | else: | ||||
| if _is_role_worker() and self.cache_enable: | |||||
| global_seed, op_seed = _get_global_and_op_seed() | |||||
| _insert_weight_init_info(self.name, global_seed, op_seed) | |||||
| if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)): | ||||
| if _is_role_worker(): | |||||
| if _is_role_worker() or _is_role_sched(): | |||||
| data = self.init_mode.to_tensor(0, [1]) | data = self.init_mode.to_tensor(0, [1]) | ||||
| else: | else: | ||||
| data = self.init_mode.to_tensor() | data = self.init_mode.to_tensor() | ||||
| @@ -502,6 +530,16 @@ class ParameterTuple(tuple): | |||||
| x1 = x.clone(init) | x1 = x.clone(init) | ||||
| x1.name = prefix + "." + x1.name | x1.name = prefix + "." + x1.name | ||||
| new.append(x1) | new.append(x1) | ||||
| if not x1.cache_enable: | |||||
| continue | |||||
| if not x1.name.endswith("embedding_table"): | |||||
| raise RuntimeError("Can not enable cache for parameter '{}', Only parameters of " | |||||
| "sparse operator support enable cache.".format(x1.name)) | |||||
| if _is_role_worker(): | |||||
| _clone_hash_table(x.name, x1.name) | |||||
| _insert_accumu_init_info(x1.name, init_to_value(init)) | |||||
| return ParameterTuple(new) | return ParameterTuple(new) | ||||
| def __parameter_tuple__(self): | def __parameter_tuple__(self): | ||||
| @@ -195,6 +195,20 @@ def _get_op_seed(op_seed, kernel_name): | |||||
| return _KERNEL_SEED[(kernel_name, op_seed)] | return _KERNEL_SEED[(kernel_name, op_seed)] | ||||
| def _get_global_and_op_seed(): | |||||
| """Get global_seed and op_seed.""" | |||||
| global_seed = get_seed() | |||||
| op_seed = get_seed() | |||||
| if global_seed == 0: | |||||
| global_seed = DEFAULT_GRAPH_SEED | |||||
| elif global_seed is None: | |||||
| global_seed = 0 | |||||
| Validator.check_non_negative_int(op_seed, "seed", "init") | |||||
| temp_seed = _get_op_seed(op_seed, "init") | |||||
| seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed) | |||||
| return seeds | |||||
| def _get_graph_seed(op_seed, kernel_name): | def _get_graph_seed(op_seed, kernel_name): | ||||
| """ | """ | ||||
| Get the graph-level seed. | Get the graph-level seed. | ||||
| @@ -73,6 +73,8 @@ inline size_t FloatToSize(float u) { | |||||
| } | } | ||||
| inline float IntToFloat(int32_t v) { return static_cast<float>(v); } | inline float IntToFloat(int32_t v) { return static_cast<float>(v); } | ||||
| inline float SizeToFloat(size_t v) { return static_cast<float>(v); } | |||||
| inline double LongToDouble(int64_t v) { return static_cast<double>(v); } | inline double LongToDouble(int64_t v) { return static_cast<double>(v); } | ||||
| inline double FloatToDouble(float v) { return static_cast<double>(v); } | inline double FloatToDouble(float v) { return static_cast<double>(v); } | ||||
| @@ -22,6 +22,7 @@ from mindspore.common.initializer import initializer | |||||
| from mindspore.communication.management import get_group_size | from mindspore.communication.management import get_group_size | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.parallel._utils import _get_parallel_mode | from mindspore.parallel._utils import _get_parallel_mode | ||||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| @@ -156,6 +157,7 @@ class EmbeddingLookup(Cell): | |||||
| max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | ||||
| or None. Default: None | or None. Default: None | ||||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | ||||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. | |||||
| Inputs: | Inputs: | ||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | ||||
| @@ -185,7 +187,7 @@ class EmbeddingLookup(Cell): | |||||
| def __init__(self, vocab_size, embedding_size, param_init='normal', | def __init__(self, vocab_size, embedding_size, param_init='normal', | ||||
| target='CPU', slice_mode='batch_slice', manual_shapes=None, | target='CPU', slice_mode='batch_slice', manual_shapes=None, | ||||
| max_norm=None, sparse=True): | |||||
| max_norm=None, sparse=True, vocab_cache_size=0): | |||||
| super(EmbeddingLookup, self).__init__() | super(EmbeddingLookup, self).__init__() | ||||
| self.target = target | self.target = target | ||||
| if target not in ('CPU', 'DEVICE'): | if target not in ('CPU', 'DEVICE'): | ||||
| @@ -199,11 +201,23 @@ class EmbeddingLookup(Cell): | |||||
| self.gatherv2 = P.GatherV2() | self.gatherv2 = P.GatherV2() | ||||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | ||||
| self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) | self.vocab_size = validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) | ||||
| self.vocab_cache_size = validator.check_value_type('vocab_cache_size', vocab_cache_size, [int], self.cls_name) | |||||
| self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) | self.embedding_size = validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) | ||||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||||
| name='embedding_table') | |||||
| parallel_mode = _get_parallel_mode() | parallel_mode = _get_parallel_mode() | ||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | ||||
| self.cache_enable = self.vocab_cache_size > 0 | |||||
| if self.cache_enable: | |||||
| if is_auto_parallel: | |||||
| self.vocab_cache_size = self.vocab_cache_size * get_group_size() | |||||
| self.vocab_size = self.vocab_cache_size | |||||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||||
| name='embedding_table') | |||||
| if self.cache_enable: | |||||
| self.embedding_table.cache_enable = True | |||||
| _set_cache_enable(True) | |||||
| if _is_role_worker(): | |||||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||||
| self.forward_unique = False | self.forward_unique = False | ||||
| self.gather_revert = P.GatherV2() | self.gather_revert = P.GatherV2() | ||||
| self.unique = P.Unique().shard(((1,),)) | self.unique = P.Unique().shard(((1,),)) | ||||
| @@ -222,7 +236,7 @@ class EmbeddingLookup(Cell): | |||||
| self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | ||||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | ||||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | elif slice_mode == "table_row_slice" and is_auto_parallel: | ||||
| if target == 'DEVICE': | |||||
| if target == 'DEVICE' and not self.cache_enable: | |||||
| indices_shape_size = 1 | indices_shape_size = 1 | ||||
| self.gather_revert.shard(((1, 1), (get_group_size(),))) | self.gather_revert.shard(((1, 1), (get_group_size(),))) | ||||
| self.forward_unique = True | self.forward_unique = True | ||||
| @@ -88,14 +88,14 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay, param, m, v, gradient, d | |||||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | ||||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool") | |||||
| "Tensor", "Tensor", "Tensor", "Tensor", "RowTensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | ||||
| beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter): | |||||
| beta2_power, beta1, beta2, eps, lr, gradient, param, m, v, ps_parameter, cache_enable): | |||||
| """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | """Apply sparse adam optimizer to the weight parameter when the gradient is sparse.""" | ||||
| success = True | success = True | ||||
| indices = gradient.indices | indices = gradient.indices | ||||
| values = gradient.values | values = gradient.values | ||||
| if ps_parameter: | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| shapes = (op_shape(param), op_shape(m), op_shape(v), | shapes = (op_shape(param), op_shape(m), op_shape(v), | ||||
| op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | op_shape(beta1_power), op_shape(beta2_power), op_shape(lr), op_shape(beta1), | ||||
| @@ -158,12 +158,13 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, | |||||
| @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | @_adam_opt.register("Function", "Function", "Function", "Function", "Bool", "Bool", "Bool", "Tensor", "Tensor", | ||||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool") | |||||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, beta1_power, | |||||
| beta2_power, beta1, beta2, eps, lr, gradient, param, moment1, moment2, ps_parameter): | |||||
| "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _run_opt_with_one_number(opt, sparse_opt, push, pull, use_locking, use_nesterov, target, | |||||
| beta1_power, beta2_power, beta1, beta2, eps, lr, gradient, param, | |||||
| moment1, moment2, ps_parameter, cache_enable): | |||||
| """Apply adam optimizer to the weight parameter using Tensor.""" | """Apply adam optimizer to the weight parameter using Tensor.""" | ||||
| success = True | success = True | ||||
| if ps_parameter: | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), | success = F.depend(success, pull(push((beta1_power, beta2_power, lr, beta1, beta2, eps, gradient), | ||||
| (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) | (op_shape(param), op_shape(moment1), op_shape(moment2))), param)) | ||||
| @@ -338,12 +339,12 @@ class Adam(Optimizer): | |||||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | ||||
| self.use_locking, self.use_nesterov, self._is_device, | self.use_locking, self.use_nesterov, self._is_device, | ||||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps), | beta1_power, beta2_power, self.beta1, self.beta2, self.eps), | ||||
| lr, gradients, params, moment1, moment2, self.ps_parameters) | |||||
| lr, gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) | |||||
| else: | else: | ||||
| success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | success = self.map_(F.partial(_adam_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | ||||
| self.use_locking, self.use_nesterov, self._is_device, | self.use_locking, self.use_nesterov, self._is_device, | ||||
| beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), | beta1_power, beta2_power, self.beta1, self.beta2, self.eps, lr), | ||||
| gradients, params, moment1, moment2, self.ps_parameters) | |||||
| gradients, params, moment1, moment2, self.ps_parameters, self.cache_enable) | |||||
| return success | return success | ||||
| @Optimizer.target.setter | @Optimizer.target.setter | ||||
| @@ -24,14 +24,14 @@ _ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") | |||||
| @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | ||||
| "RowTensor", "Tensor", "Tensor", "Bool") | |||||
| "RowTensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | ||||
| gradient, weight, moment, ps_parameter): | |||||
| gradient, weight, moment, ps_parameter, cache_enable): | |||||
| """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" | """Apply sparse ftrl optimizer to the weight parameter when the gradient is sparse.""" | ||||
| success = True | success = True | ||||
| indices = gradient.indices | indices = gradient.indices | ||||
| values = gradient.values | values = gradient.values | ||||
| if ps_parameter: | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) | shapes = (op_shape(weight), op_shape(moment), op_shape(linear), op_shape(values), op_shape(indices)) | ||||
| success = F.depend(success, pull(push((values, indices), shapes), weight)) | success = F.depend(success, pull(push((values, indices), shapes), weight)) | ||||
| @@ -41,12 +41,12 @@ def _tensor_run_opt_with_sparse(opt, spars_opt, push, pull, l1, l2, lr_power, le | |||||
| @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | @_ftrl_opt.register("Function", "Function", "Function", "Function", "Number", "Number", "Number", "Tensor", "Tensor", | ||||
| "Tensor", "Tensor", "Tensor", "Bool") | |||||
| "Tensor", "Tensor", "Tensor", "Bool", "Bool") | |||||
| def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | def _tensor_run_opt(opt, spars_opt, push, pull, l1, l2, lr_power, learning_rate, linear, | ||||
| gradient, weight, moment, ps_parameter): | |||||
| gradient, weight, moment, ps_parameter, cache_enable): | |||||
| """Apply ftrl optimizer to the weight parameter.""" | """Apply ftrl optimizer to the weight parameter.""" | ||||
| success = True | success = True | ||||
| if ps_parameter: | |||||
| if ps_parameter and not cache_enable: | |||||
| op_shape = P.Shape() | op_shape = P.Shape() | ||||
| success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), | success = F.depend(success, pull(push((gradient, learning_rate, l1, l2, lr_power), | ||||
| (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) | (op_shape(weight), op_shape(moment), op_shape(linear))), weight)) | ||||
| @@ -185,7 +185,7 @@ class FTRL(Optimizer): | |||||
| success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull, | ||||
| self.l1, self.l2, self.lr_power, lr), | self.l1, self.l2, self.lr_power, lr), | ||||
| linear, grads, params, moments, self.ps_parameters) | |||||
| linear, grads, params, moments, self.ps_parameters, self.cache_enable) | |||||
| return success | return success | ||||
| @Optimizer.target.setter | @Optimizer.target.setter | ||||
| @@ -156,6 +156,8 @@ class Optimizer(Cell): | |||||
| break | break | ||||
| ps_filter = lambda x: x.is_param_ps | ps_filter = lambda x: x.is_param_ps | ||||
| self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) | ||||
| ps_cache_filter = lambda x: x.cache_enable | |||||
| self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters) | |||||
| self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32) | ||||
| self.need_scale = loss_scale != 1.0 | self.need_scale = loss_scale != 1.0 | ||||
| self.global_step_increase_tensor = Tensor(1, mstype.int32) | self.global_step_increase_tensor = Tensor(1, mstype.int32) | ||||
| @@ -117,3 +117,21 @@ def _is_role_pserver(): | |||||
| def _is_role_sched(): | def _is_role_sched(): | ||||
| return ps_context().is_role_sched() | return ps_context().is_role_sched() | ||||
| def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size): | |||||
| ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size) | |||||
| def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size): | |||||
| ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size) | |||||
| def _insert_weight_init_info(name, global_seed, op_seed): | |||||
| ps_context().insert_weight_init_info(name, global_seed, op_seed) | |||||
| def _insert_accumu_init_info(name, init_val): | |||||
| ps_context().insert_accumu_init_info(name, init_val) | |||||
| def _clone_hash_table(dest_param_name, src_param_name): | |||||
| ps_context().clone_hash_table(dest_param_name, src_param_name) | |||||
| def _set_cache_enable(cache_enable): | |||||
| ps_context().set_cache_enable(cache_enable) | |||||
| @@ -92,6 +92,10 @@ def connect_network_with_dataset(network, dataset_helper): | |||||
| if isinstance(dataset_iter, _DatasetIterNormal): | if isinstance(dataset_iter, _DatasetIterNormal): | ||||
| raise RuntimeError("Dataset should be connected with network only in sink mode.") | raise RuntimeError("Dataset should be connected with network only in sink mode.") | ||||
| ms_role = os.getenv("MS_ROLE") | |||||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||||
| return network | |||||
| if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | if (hasattr(dataset_iter, "sink_size") and dataset_iter.sink_size == 1) \ | ||||
| and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | and (hasattr(dataset_iter, "sink_count") and dataset_iter.sink_count == 1) \ | ||||
| and context.get_context("device_target") == "Ascend" \ | and context.get_context("device_target") == "Ascend" \ | ||||
| @@ -166,14 +170,14 @@ class DatasetHelper: | |||||
| iterclass = _DatasetIterGE | iterclass = _DatasetIterGE | ||||
| else: | else: | ||||
| if context.get_context("mode") == context.GRAPH_MODE: | if context.get_context("mode") == context.GRAPH_MODE: | ||||
| if context.get_context("device_target") == "Ascend": | |||||
| ms_role = os.getenv("MS_ROLE") | |||||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||||
| iterclass = _DatasetIterPSServer | |||||
| elif ms_role == "MS_WORKER": | |||||
| iterclass = _DatasetIterPSWork | |||||
| elif (context.get_context("device_target") == "Ascend") or \ | |||||
| (context.get_context("device_target") == "GPU"): | |||||
| iterclass = _DatasetIterMSLoopSink | iterclass = _DatasetIterMSLoopSink | ||||
| elif context.get_context("device_target") == "GPU": | |||||
| ms_role = os.getenv("MS_ROLE") | |||||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||||
| iterclass = _DatasetIterPSLite | |||||
| else: | |||||
| iterclass = _DatasetIterMSLoopSink | |||||
| elif context.get_context("device_target") == "CPU": | elif context.get_context("device_target") == "CPU": | ||||
| raise RuntimeError( | raise RuntimeError( | ||||
| "Currently dataset sink mode is not supported when the device target is CPU.") | "Currently dataset sink mode is not supported when the device target is CPU.") | ||||
| @@ -218,7 +222,10 @@ class _DatasetIter: | |||||
| if not hasattr(dataset, '__transfer_dataset__'): | if not hasattr(dataset, '__transfer_dataset__'): | ||||
| if hasattr(dataset, '__loop_size__'): | if hasattr(dataset, '__loop_size__'): | ||||
| self.sink_size = dataset.__loop_size__ | |||||
| ms_role = os.getenv("MS_ROLE") | |||||
| # PS mode does not support loop sink and need get the real sink size. | |||||
| if ms_role != "MS_WORKER": | |||||
| self.sink_size = dataset.__loop_size__ | |||||
| create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context( | create_data_info_queue = (sink_size == 1 and self.sink_count == 1 and context.get_context( | ||||
| "device_target") == "Ascend") | "device_target") == "Ascend") | ||||
| dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, | dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size, | ||||
| @@ -260,8 +267,12 @@ class _DatasetIter: | |||||
| def get_sink_size(self): | def get_sink_size(self): | ||||
| """get sink_size to device""" | """get sink_size to device""" | ||||
| sink_size = 1 | sink_size = 1 | ||||
| ms_role = os.getenv("MS_ROLE") | |||||
| if hasattr(self.dataset, '__loop_size__'): | if hasattr(self.dataset, '__loop_size__'): | ||||
| sink_size = self.dataset.__loop_size__ | sink_size = self.dataset.__loop_size__ | ||||
| elif ms_role == "MS_WORKER": | |||||
| # PS mode does not support loop sink. | |||||
| sink_size = 1 | |||||
| else: | else: | ||||
| if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \ | if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend" \ | ||||
| or context.get_context("device_target") == "GPU": | or context.get_context("device_target") == "GPU": | ||||
| @@ -311,9 +322,6 @@ class _DatasetIterMSLoopSink(_DatasetIter): | |||||
| def __init__(self, dataset, sink_size, epoch_num): | def __init__(self, dataset, sink_size, epoch_num): | ||||
| super().__init__(dataset, sink_size, epoch_num) | super().__init__(dataset, sink_size, epoch_num) | ||||
| self.sink_count = self.get_sink_count(dataset) | self.sink_count = self.get_sink_count(dataset) | ||||
| ms_role = os.getenv("MS_ROLE") | |||||
| if ms_role in ("MS_PSERVER", "MS_SCHED"): | |||||
| self.sink_count = 1 | |||||
| # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, | # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch, | ||||
| # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for | # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for | ||||
| # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. | # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink. | ||||
| @@ -341,8 +349,8 @@ class _DatasetIterMS(_DatasetIter): | |||||
| self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name) | ||||
| class _DatasetIterPSLite(_DatasetIter): | |||||
| """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED""" | |||||
| class _DatasetIterPSServer(_DatasetIter): | |||||
| """Iter for context on MS_PSERVER or MS_SCHED""" | |||||
| def __init__(self, dataset, sink_size, epoch_num): | def __init__(self, dataset, sink_size, epoch_num): | ||||
| super().__init__(dataset, sink_size, epoch_num) | super().__init__(dataset, sink_size, epoch_num) | ||||
| @@ -355,6 +363,20 @@ class _DatasetIterPSLite(_DatasetIter): | |||||
| self.op = op | self.op = op | ||||
| class _DatasetIterPSWork(_DatasetIter): | |||||
| """Iter for context on MS_WORKER""" | |||||
| def __init__(self, dataset, sink_size, epoch_num): | |||||
| super().__init__(dataset, sink_size, epoch_num) | |||||
| if sink_size > 0: | |||||
| self.sink_count = sink_size | |||||
| else: | |||||
| self.sink_count = dataset.get_dataset_size() | |||||
| def op(): | |||||
| return tuple() | |||||
| self.op = op | |||||
| class _DatasetIterNormal: | class _DatasetIterNormal: | ||||
| """Iter for normal(non sink) mode, feed the data from host.""" | """Iter for normal(non sink) mode, feed the data from host.""" | ||||
| @@ -30,6 +30,7 @@ def argparse_init(): | |||||
| parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") | parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") | ||||
| parser.add_argument("--field_size", type=int, default=39, help="The number of features.") | parser.add_argument("--field_size", type=int, default=39, help="The number of features.") | ||||
| parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.") | parser.add_argument("--vocab_size", type=int, default=200000, help="The total features of dataset.") | ||||
| parser.add_argument("--vocab_cache_size", type=int, default=0, help="The total features of hash table.") | |||||
| parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.") | parser.add_argument("--emb_dim", type=int, default=80, help="The dense embedding dimension of sparse feature.") | ||||
| parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128], | parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128], | ||||
| help="The dimension of all deep layers.") | help="The dimension of all deep layers.") | ||||
| @@ -66,6 +67,7 @@ class WideDeepConfig(): | |||||
| self.eval_batch_size = 16000 | self.eval_batch_size = 16000 | ||||
| self.field_size = 39 | self.field_size = 39 | ||||
| self.vocab_size = 200000 | self.vocab_size = 200000 | ||||
| self.vocab_cache_size = 100000 | |||||
| self.emb_dim = 80 | self.emb_dim = 80 | ||||
| self.deep_layer_dim = [1024, 512, 256, 128] | self.deep_layer_dim = [1024, 512, 256, 128] | ||||
| self.deep_layer_act = 'relu' | self.deep_layer_act = 'relu' | ||||
| @@ -103,6 +105,7 @@ class WideDeepConfig(): | |||||
| self.eval_batch_size = args.eval_batch_size | self.eval_batch_size = args.eval_batch_size | ||||
| self.field_size = args.field_size | self.field_size = args.field_size | ||||
| self.vocab_size = args.vocab_size | self.vocab_size = args.vocab_size | ||||
| self.vocab_cache_size = args.vocab_cache_size | |||||
| self.emb_dim = args.emb_dim | self.emb_dim = args.emb_dim | ||||
| self.deep_layer_dim = args.deep_layer_dim | self.deep_layer_dim = args.deep_layer_dim | ||||
| self.deep_layer_act = args.deep_layer_act | self.deep_layer_act = args.deep_layer_act | ||||
| @@ -147,6 +147,7 @@ class WideDeepModel(nn.Cell): | |||||
| sparse = config.sparse | sparse = config.sparse | ||||
| self.field_size = config.field_size | self.field_size = config.field_size | ||||
| self.vocab_size = config.vocab_size | self.vocab_size = config.vocab_size | ||||
| self.vocab_cache_size = config.vocab_cache_size | |||||
| self.emb_dim = config.emb_dim | self.emb_dim = config.emb_dim | ||||
| self.deep_layer_dims_list = config.deep_layer_dim | self.deep_layer_dims_list = config.deep_layer_dim | ||||
| self.deep_layer_act = config.deep_layer_act | self.deep_layer_act = config.deep_layer_act | ||||
| @@ -237,8 +238,20 @@ class WideDeepModel(nn.Cell): | |||||
| self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) | ||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | self.embedding_table = self.deep_embeddinglookup.embedding_table | ||||
| elif parameter_server: | elif parameter_server: | ||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) | |||||
| cache_enable = self.vocab_cache_size > 0 | |||||
| target = 'DEVICE' if cache_enable else 'CPU' | |||||
| if is_auto_parallel and config.full_batch and cache_enable: | |||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, | |||||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, | |||||
| slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, | |||||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||||
| else: | |||||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, | |||||
| sparse=sparse, vocab_cache_size=self.vocab_cache_size) | |||||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, sparse=sparse, | |||||
| vocab_cache_size=self.vocab_cache_size) | |||||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | self.embedding_table = self.deep_embeddinglookup.embedding_table | ||||
| self.deep_embeddinglookup.embedding_table.set_param_ps() | self.deep_embeddinglookup.embedding_table.set_param_ps() | ||||
| self.wide_embeddinglookup.embedding_table.set_param_ps() | self.wide_embeddinglookup.embedding_table.set_param_ps() | ||||
| @@ -344,7 +357,7 @@ class TrainStepWrap(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, | def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, | ||||
| sparse=False): | |||||
| sparse=False, cache_enable=False): | |||||
| super(TrainStepWrap, self).__init__() | super(TrainStepWrap, self).__init__() | ||||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | parallel_mode = context.get_auto_parallel_context("parallel_mode") | ||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | ||||
| @@ -361,7 +374,7 @@ class TrainStepWrap(nn.Cell): | |||||
| self.weights_w = ParameterTuple(weights_w) | self.weights_w = ParameterTuple(weights_w) | ||||
| self.weights_d = ParameterTuple(weights_d) | self.weights_d = ParameterTuple(weights_d) | ||||
| if (sparse and is_auto_parallel) or parameter_server: | |||||
| if (sparse and is_auto_parallel) or (parameter_server and not cache_enable): | |||||
| self.optimizer_d = LazyAdam( | self.optimizer_d = LazyAdam( | ||||
| self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens) | ||||
| self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w, | ||||
| @@ -417,10 +430,17 @@ class TrainStepWrap(nn.Cell): | |||||
| class PredictWithSigmoid(nn.Cell): | class PredictWithSigmoid(nn.Cell): | ||||
| """ | |||||
| Predict definition | |||||
| """ | |||||
| def __init__(self, network): | def __init__(self, network): | ||||
| super(PredictWithSigmoid, self).__init__() | super(PredictWithSigmoid, self).__init__() | ||||
| self.network = network | self.network = network | ||||
| self.sigmoid = P.Sigmoid() | self.sigmoid = P.Sigmoid() | ||||
| parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||||
| if is_auto_parallel: | |||||
| self.sigmoid.shard(((1, 1),)) | |||||
| def construct(self, batch_ids, batch_wts, labels): | def construct(self, batch_ids, batch_wts, labels): | ||||
| logits, _, = self.network(batch_ids, batch_wts) | logits, _, = self.network(batch_ids, batch_wts) | ||||
| @@ -39,7 +39,8 @@ def get_WideDeep_net(config): | |||||
| """ | """ | ||||
| WideDeep_net = WideDeepModel(config) | WideDeep_net = WideDeepModel(config) | ||||
| loss_net = NetWithLossClass(WideDeep_net, config) | loss_net = NetWithLossClass(WideDeep_net, config) | ||||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server)) | |||||
| train_net = TrainStepWrap(loss_net, parameter_server=bool(config.parameter_server), | |||||
| cache_enable=bool(config.vocab_cache_size > 0)) | |||||
| eval_net = PredictWithSigmoid(WideDeep_net) | eval_net = PredictWithSigmoid(WideDeep_net) | ||||
| return train_net, eval_net | return train_net, eval_net | ||||
| @@ -81,6 +82,7 @@ def train_and_eval(config): | |||||
| else: | else: | ||||
| dataset_type = DataType.H5 | dataset_type = DataType.H5 | ||||
| parameter_server = bool(config.parameter_server) | parameter_server = bool(config.parameter_server) | ||||
| cache_enable = bool(config.vocab_cache_size > 0) | |||||
| print("epochs is {}".format(epochs)) | print("epochs is {}".format(epochs)) | ||||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | ds_train = create_dataset(data_path, train_mode=True, epochs=1, | ||||
| batch_size=batch_size, rank_id=get_rank(), | batch_size=batch_size, rank_id=get_rank(), | ||||
| @@ -111,7 +113,7 @@ def train_and_eval(config): | |||||
| callback_list.append(ckpoint_cb) | callback_list.append(ckpoint_cb) | ||||
| model.train(epochs, ds_train, | model.train(epochs, ds_train, | ||||
| callbacks=callback_list, | callbacks=callback_list, | ||||
| dataset_sink_mode=(not parameter_server)) | |||||
| dataset_sink_mode=(parameter_server and cache_enable)) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||