| @@ -1,22 +0,0 @@ | |||
| if(ENABLE_GITEE) | |||
| set(REQ_URL "https://gitee.com/mirrors/ps-lite/repository/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip") | |||
| set(MD5 "0d1543b8dcb0bc3610637e1643c94eb4") | |||
| else() | |||
| set(REQ_URL "https://github.com/dmlc/ps-lite/archive/34fd45cae457d59850fdcb2066467778d0673f21.zip") | |||
| set(MD5 "393c0e27b68bfaf96718caa3aa96f5a3") | |||
| endif() | |||
| set(pslite_USE_STATIC_LIBS ON) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| set(pslite_CXXFLAGS "USE_IBVERBS=1") | |||
| endif() | |||
| mindspore_add_pkg(pslite | |||
| LIBS ps | |||
| URL ${REQ_URL} | |||
| MD5 ${MD5} | |||
| PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/pslite/ps_lite.patch001 | |||
| ONLY_MAKE True | |||
| ONLY_MAKE_INCS include/* | |||
| ONLY_MAKE_LIBS build/*) | |||
| include_directories(${pslite_INC}) | |||
| add_library(mindspore::pslite ALIAS pslite::ps) | |||
| @@ -1,5 +0,0 @@ | |||
| mindspore_add_pkg(zeromq | |||
| VER 4.1.4 | |||
| HEAD_ONLY ./ | |||
| URL https://raw.githubusercontent.com/mli/deps/master/build/zeromq-4.1.4.tar.gz | |||
| MD5 a611ecc93fffeb6d058c0e6edf4ad4fb) | |||
| @@ -32,10 +32,6 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/flatbuffers.cmake) | |||
| if(USE_GLOG) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/glog.cmake) | |||
| endif() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/zeromq.cmake) | |||
| include(${CMAKE_SOURCE_DIR}/cmake/external_libs/pslite.cmake) | |||
| endif() | |||
| find_package(Python3) | |||
| include_directories(${Python3_INCLUDE_DIRS}) | |||
| @@ -339,8 +339,8 @@ elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin") | |||
| target_link_libraries(_c_expression PRIVATE -Wl,-force_load mindspore -Wl,-noall_load) | |||
| else() | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(mindspore mindspore::pslite proto_input mindspore::protobuf | |||
| mindspore::event mindspore::event_pthreads ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| target_link_libraries(mindspore proto_input mindspore::protobuf | |||
| mindspore::event mindspore::event_pthreads) | |||
| target_link_libraries(mindspore -Wl,--no-as-needed mindspore::event_core ps_cache) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(mindspore ibverbs rdmacm) | |||
| @@ -17,6 +17,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include "ps/worker.h" | |||
| #include "ps/util.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -35,7 +36,7 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| << input_shape << " is too large."; | |||
| } | |||
| if (mindspore::ps::Util::IsRoleOfWorker()) { | |||
| if (mindspore::ps::PSContext::instance()->is_worker()) { | |||
| key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey); | |||
| } | |||
| std::vector<size_t> keys{key_, key_, key_}; | |||
| @@ -50,9 +51,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| << ", indices_shape:" << indices_shape << ", output_shape:" << output_shape; | |||
| std::vector<int64_t> lens{SizeToLong(input_shape.size()), SizeToLong(indices_shape.size()), | |||
| SizeToLong(output_shape.size())}; | |||
| if (mindspore::ps::Util::IsRoleOfWorker()) { | |||
| if (mindspore::ps::PSContext::instance()->is_worker()) { | |||
| mindspore::ps::worker.AddEmbeddingTable(key_, input_shape[axis]); | |||
| mindspore::ps::worker.InitPSEmbeddingTable(keys, values, lens); | |||
| mindspore::ps::ParamInitInfoMessage info; | |||
| mindspore::ps::worker.InitPSEmbeddingTable(key_, input_shape, indices_shape, output_shape, info); | |||
| } | |||
| } | |||
| @@ -70,17 +72,16 @@ bool EmbeddingLookUpProxyKernel::Launch(const std::vector<kernel::AddressPtr> &i | |||
| size_t input_size = inputs[1]->size; | |||
| size_t output_size = outputs[0]->size; | |||
| size_t size = input_size / sizeof(float); | |||
| ::ps::SArray<int> lookup_ids(size, 0); | |||
| ::ps::SArray<int> lengths{size}; | |||
| ::ps::SArray<float> lookup_result(output_size / sizeof(float), 0); | |||
| size_t size = input_size / sizeof(int); | |||
| std::vector<int> lookup_ids(size, 0); | |||
| std::vector<int> lengths{SizeToInt(size)}; | |||
| std::vector<float> lookup_result(output_size / sizeof(float), 0); | |||
| auto ret = memcpy_s(lookup_ids.data(), lookup_ids.size() * sizeof(int), indices_addr, input_size); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| mindspore::ps::worker.DoPSEmbeddingLookup({key_}, lookup_ids, lengths, &lookup_result, | |||
| mindspore::ps::kEmbeddingLookupCmd); | |||
| mindspore::ps::worker.DoPSEmbeddingLookup(key_, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| auto ret2 = memcpy_s(output_addr, outputs[0]->size, lookup_result.data(), output_size); | |||
| if (ret2 != EOK) { | |||
| @@ -62,7 +62,7 @@ class PullKernel : public CPUKernel { | |||
| MS_EXCEPTION_IF_NULL(param_node); | |||
| param_name_ = param_node->fullname_with_scope(); | |||
| if (mindspore::ps::Util::IsRoleOfWorker()) { | |||
| if (mindspore::ps::PSContext::instance()->is_worker()) { | |||
| key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey); | |||
| } | |||
| InitSizeLists(); | |||
| @@ -30,6 +30,7 @@ | |||
| #include "backend/optimizer/pass/replace_node_by_proxy.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -75,9 +76,9 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr | |||
| MS_LOG(INFO) << "Set kernel info"; | |||
| SetKernelInfo(graph.get()); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsParamServerMode()) { | |||
| if (ps::PSContext::instance()->is_ps_mode()) { | |||
| AssignParamKey(graph); | |||
| if (ps::Util::IsRoleOfWorker()) { | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| Optimize(graph); | |||
| } | |||
| } | |||
| @@ -42,8 +42,9 @@ | |||
| #include "utils/trace_base.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/common.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #include "abstract/abstract_value.h" | |||
| #endif | |||
| @@ -2288,7 +2289,7 @@ void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) const { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) { | |||
| if (!ps::Util::IsRoleOfWorker()) { | |||
| if (!ps::PSContext::instance()->is_worker()) { | |||
| return; | |||
| } | |||
| CheckPSModeConsistence(kernel_graph); | |||
| @@ -2385,7 +2386,7 @@ void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||
| void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, | |||
| const std::vector<tensor::TensorPtr> &inputs_const) { | |||
| if (!ps::Util::IsRoleOfWorker()) { | |||
| if (!ps::PSContext::instance()->is_worker()) { | |||
| return; | |||
| } | |||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | |||
| @@ -48,6 +48,7 @@ | |||
| #include "mindspore/core/utils/parallel_node_check.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| using mindspore::tensor::Tensor; | |||
| @@ -3283,7 +3284,7 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) { | |||
| bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfPServer() || ps::Util::IsRoleOfScheduler()) { | |||
| if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) { | |||
| return false; | |||
| } | |||
| #endif | |||
| @@ -288,7 +288,6 @@ if(${CMAKE_SYSTEM_NAME} MATCHES "Windows") | |||
| else() | |||
| target_link_libraries(_c_dataengine PRIVATE _c_mindrecord) | |||
| if(ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU)) | |||
| target_link_libraries(_c_dataengine PRIVATE mindspore::pslite ${zeromq_DIRPATH}/zmq_install/lib/libzmq.a) | |||
| if(${ENABLE_IBVERBS} STREQUAL "ON") | |||
| target_link_libraries(_c_dataengine PRIVATE ibverbs rdmacm) | |||
| endif() | |||
| @@ -460,7 +460,7 @@ bool StartPSWorkerAction(const ResourcePtr &res) { | |||
| bool StartPSServerAction(const ResourcePtr &res) { | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| auto &ps = ps::ParameterServer<float>::GetInstance(); | |||
| auto &ps = ps::ParameterServer::GetInstance(); | |||
| ps.Run(func_graph); | |||
| return true; | |||
| } | |||
| @@ -626,7 +626,7 @@ std::vector<ActionItem> VmPipeline() { | |||
| actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfWorker()) { | |||
| if (ps::PSContext::instance()->is_worker()) { | |||
| actions.emplace_back(std::make_pair("worker", StartPSWorkerAction)); | |||
| } | |||
| #endif | |||
| @@ -43,6 +43,7 @@ | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -406,7 +407,7 @@ bool AddRecomputationPass(const ResourcePtr &res) { | |||
| bool AddCacheEmbeddingPass(const ResourcePtr &res) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsParamServerMode()) { | |||
| if (ps::PSContext::instance()->is_ps_mode()) { | |||
| return true; | |||
| } | |||
| #endif | |||
| @@ -49,7 +49,7 @@ | |||
| #include "utils/shape_utils.h" | |||
| #include "utils/info.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/common.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/util.h" | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| @@ -492,14 +492,11 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin | |||
| std::string backend = MsContext::GetInstance()->backend_policy(); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (mindspore::ps::Util::IsParamServerMode()) { | |||
| mindspore::ps::Util::SetInternalEnvVar(); | |||
| } | |||
| if (ps::Util::IsRoleOfPServer()) { | |||
| if (ps::PSContext::instance()->is_server()) { | |||
| resource->results()[kBackend] = compile::CreateBackend(); | |||
| return PServerPipeline(); | |||
| } | |||
| if (ps::Util::IsRoleOfScheduler()) { | |||
| if (ps::PSContext::instance()->is_scheduler()) { | |||
| return PSchedulerPipeline(); | |||
| } | |||
| #endif | |||
| @@ -978,7 +975,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes, | |||
| const std::vector<int64_t> &input_indexes, bool need_run) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if ((ps::Util::IsParamServerMode()) && (!ps::Util::IsRoleOfWorker())) { | |||
| if ((ps::PSContext::instance()->is_ps_mode()) && (!ps::PSContext::instance()->is_worker())) { | |||
| return true; | |||
| } | |||
| #endif | |||
| @@ -1030,7 +1027,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| // PS cache does not support loop sink. | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| } | |||
| @@ -1151,10 +1148,11 @@ void ClearResAtexit() { | |||
| pynative::ClearPyNativeSession(); | |||
| session::ClearPythonParasMap(); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { | |||
| if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->is_worker()) { | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::ps_cache_instance.Finalize(); | |||
| } | |||
| MS_LOG(INFO) << "ps::worker.Finalize"; | |||
| ps::worker.Finalize(); | |||
| } | |||
| #endif | |||
| @@ -21,8 +21,8 @@ if(NOT (ENABLE_CPU AND (ENABLE_D OR ENABLE_GPU))) | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/abstract_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/scheduler_node.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "core/http_client.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "internal/worker.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "internal/parameter_server.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "worker.cc") | |||
| list(REMOVE_ITEM _PS_SRC_FILES "parameter_server.cc") | |||
| endif() | |||
| if(NOT ENABLE_D) | |||
| @@ -1,140 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_COMMON_H_ | |||
| #define MINDSPORE_CCSRC_PS_COMMON_H_ | |||
| #include <limits.h> | |||
| #include <iostream> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <map> | |||
| #include <string> | |||
| #include "ps/ps.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| constexpr char kEnvCommType[] = "MS_COMM_TYPE"; | |||
| constexpr char kEnvInterface[] = "MS_INTERFACE"; | |||
| constexpr char kEnvPServerNum[] = "MS_SERVER_NUM"; | |||
| constexpr char kEnvWorkerNum[] = "MS_WORKER_NUM"; | |||
| constexpr char kEnvSchedulerHost[] = "MS_SCHED_HOST"; | |||
| constexpr char kEnvSchedulerPort[] = "MS_SCHED_PORT"; | |||
| constexpr char kDmlcCommType[] = "DMLC_PS_VAN_TYPE"; | |||
| constexpr char kDmlcInterface[] = "DMLC_INTERFACE"; | |||
| constexpr char kDmlcPServerNum[] = "DMLC_NUM_SERVER"; | |||
| constexpr char kDmlcWorkerNum[] = "DMLC_NUM_WORKER"; | |||
| constexpr char kDmlcRole[] = "DMLC_ROLE"; | |||
| constexpr char kDmlcSchedulerHost[] = "DMLC_PS_ROOT_URI"; | |||
| constexpr char kDmlcSchedulerPort[] = "DMLC_PS_ROOT_PORT"; | |||
| constexpr char kCommTypeOfIBVerbs[] = "ibverbs"; | |||
| constexpr char kCommTypeOfTCP[] = "zmq"; | |||
| constexpr char kRoleOfPServer[] = "server"; | |||
| constexpr char kRoleOfWorker[] = "worker"; | |||
| constexpr char kRoleOfScheduler[] = "scheduler"; | |||
| constexpr char kLearningRate[] = "learning_rate"; | |||
| constexpr char kMomentum[] = "momentum"; | |||
| constexpr char kApplyMomentum[] = "ApplyMomentum"; | |||
| constexpr char kSparseAdam[] = "Adam"; | |||
| constexpr char kSparseLazyAdam[] = "LazyAdam"; | |||
| constexpr char kSparseFtrl[] = "Ftrl"; | |||
| constexpr char kApplyMomentumOp[] = "Momentum"; | |||
| constexpr char kSparseAdamOp[] = "Adam"; | |||
| constexpr char kSparseLazyAdamOp[] = "LazyAdam"; | |||
| constexpr char kSparseFtrlOp[] = "FTRL"; | |||
| constexpr int64_t kInitWeightsCmd = 10; | |||
| constexpr int64_t kInitWeightToOptimIdCmd = 11; | |||
| constexpr int64_t kInitOptimInputsShapeCmd = 12; | |||
| constexpr int64_t kInitKeyToPushNodeIdCmd = 13; | |||
| constexpr int64_t kInitEmbeddingsCmd = 20; | |||
| constexpr int64_t kUpdateEmbeddingsCmd = 21; | |||
| constexpr int64_t kCheckReadyForPushCmd = 25; | |||
| constexpr int64_t kCheckReadyForPullCmd = 26; | |||
| constexpr int64_t kEmbeddingLookupCmd = 30; | |||
| constexpr int64_t kFinalizeCmd = 40; | |||
| constexpr size_t kInvalidKey = UINT64_MAX; | |||
| constexpr int64_t kInvalidID = -1; | |||
| using DataPtr = std::shared_ptr<unsigned char>; | |||
| using VectorPtr = std::shared_ptr<std::vector<unsigned char>>; | |||
| using Key = ::ps::Key; | |||
| using Keys = ::ps::SArray<Key>; | |||
| using Values = ::ps::SArray<float>; | |||
| using ValuesPtr = std::shared_ptr<Values>; | |||
| using Weight = ::ps::SArray<float>; | |||
| using Grad = ::ps::SArray<float>; | |||
| using LookupIds = ::ps::SArray<Key>; | |||
| using Lengths = ::ps::SArray<int>; | |||
| using WeightPtr = std::shared_ptr<Weight>; | |||
| using GradPtr = std::shared_ptr<Grad>; | |||
| using InputsShape = std::vector<std::shared_ptr<std::vector<size_t>>>; | |||
| using InputsShapePtr = std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>>; | |||
| constexpr size_t INDEX_NOT_SEND = UINT_MAX; | |||
| using OptimOriginIdx = std::map<std::string, size_t>; | |||
| using OptimPSSendIdx = std::map<std::string, size_t>; | |||
| const OptimOriginIdx kMomentumOriginIdx = {{"weight", 0}, {"accum", 1}, {"lr", 2}, {"grad", 3}, {"momentum", 4}}; | |||
| const OptimPSSendIdx kMomentumPSSendIdx = { | |||
| {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"lr", 0}, {"grad", 1}, {"momentum", 2}}; | |||
| const OptimOriginIdx kSparseAdamOriginIdx = {{"weight", 0}, {"m", 1}, {"v", 2}, {"beta1_power", 3}, | |||
| {"beta2_power", 4}, {"lr", 5}, {"beta1", 6}, {"beta2", 7}, | |||
| {"eps", 8}, {"grad", 9}, {"indices", 10}}; | |||
| const OptimPSSendIdx kSparseAdamPSSendIdx = {{"weight", INDEX_NOT_SEND}, | |||
| {"m", INDEX_NOT_SEND}, | |||
| {"v", INDEX_NOT_SEND}, | |||
| {"beta1_power", 0}, | |||
| {"beta2_power", 1}, | |||
| {"lr", 2}, | |||
| {"beta1", 3}, | |||
| {"beta2", 4}, | |||
| {"eps", 5}, | |||
| {"grad", 6}, | |||
| {"indices", 7}}; | |||
| const OptimOriginIdx kSparseFtrlOriginIdx = {{"weight", 0}, {"accum", 1}, {"linear", 2}, {"grad", 3}, {"indices", 4}}; | |||
| const OptimPSSendIdx kSparseFtrlPSSendIdx = { | |||
| {"weight", INDEX_NOT_SEND}, {"accum", INDEX_NOT_SEND}, {"linear", INDEX_NOT_SEND}, {"grad", 0}, {"indices", 1}}; | |||
| const std::map<std::string, OptimOriginIdx> kOptimToOriginIdx = {{kApplyMomentum, kMomentumOriginIdx}, | |||
| {kSparseAdam, kSparseAdamOriginIdx}, | |||
| {kSparseLazyAdam, kSparseAdamOriginIdx}, | |||
| {kSparseFtrl, kSparseFtrlOriginIdx}}; | |||
| const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum, kMomentumPSSendIdx}, | |||
| {kSparseAdam, kSparseAdamPSSendIdx}, | |||
| {kSparseLazyAdam, kSparseAdamPSSendIdx}, | |||
| {kSparseFtrl, kSparseFtrlPSSendIdx}}; | |||
| #define EXC_IF_VEC_IDX_OOB(vec, idx) \ | |||
| { \ | |||
| size_t vec_size = vec.size(); \ | |||
| if (idx >= vec_size) { \ | |||
| MS_LOG(EXCEPTION) << "Vector " << #vec << " size is " << vec_size << ". So index " << idx \ | |||
| << " is out of bound."; \ | |||
| } \ | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_COMMON_H_ | |||
| @@ -14,10 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ | |||
| #define MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ | |||
| #ifndef MINDSPORE_CCSRC_PS_CONSTANTS_H_ | |||
| #define MINDSPORE_CCSRC_PS_CONSTANTS_H_ | |||
| #include <limits.h> | |||
| #include <climits> | |||
| #include <iostream> | |||
| #include <vector> | |||
| #include <memory> | |||
| @@ -26,8 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace internal { | |||
| constexpr char kEnvCommType[] = "MS_COMM_TYPE"; | |||
| constexpr char kEnvInterface[] = "MS_INTERFACE"; | |||
| constexpr char kEnvPServerNum[] = "MS_SERVER_NUM"; | |||
| @@ -127,7 +126,6 @@ const std::map<std::string, OptimOriginIdx> kOptimToPSSendIdx = {{kApplyMomentum | |||
| << " is out of bound."; \ | |||
| } \ | |||
| } | |||
| } // namespace internal | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_INTERNAL_CONSTANTS_H_ | |||
| #endif // MINDSPORE_CCSRC_PS_CONSTANTS_H_ | |||
| @@ -39,9 +39,9 @@ void ClusterMetadata::Init(const uint32_t &worker_num, const uint32_t &server_nu | |||
| scheduler_port_ = scheduler_port; | |||
| } | |||
| uint32_t ClusterMetadata::worker_num() { return worker_num_; } | |||
| uint32_t ClusterMetadata::total_worker_num() { return worker_num_; } | |||
| uint32_t ClusterMetadata::server_num() { return server_num_; } | |||
| uint32_t ClusterMetadata::total_server_num() { return server_num_; } | |||
| uint32_t ClusterMetadata::heartbeat_interval() { return heartbeat_interval_; } | |||
| @@ -37,8 +37,8 @@ class ClusterMetadata { | |||
| void Init(const uint32_t &worker_num, const uint32_t &server_num, std::string scheduler_host, | |||
| const uint16_t &scheduler_port); | |||
| uint32_t worker_num(); | |||
| uint32_t server_num(); | |||
| uint32_t total_worker_num(); | |||
| uint32_t total_server_num(); | |||
| uint32_t heartbeat_interval(); | |||
| void set_heartbeat_interval(const uint32_t &heartbeat_interval); | |||
| std::string scheduler_host(); | |||
| @@ -122,9 +122,9 @@ std::string CommUtil::NodeRoleToString(const NodeRole &role) { | |||
| } | |||
| } | |||
| bool CommUtil::ValidateRankId(const enum NodeRole &node_role, const uint32_t &rank_id) { | |||
| if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->server_num() - 1)) { | |||
| if (node_role == NodeRole::SERVER && (rank_id > ClusterMetadata::instance()->total_server_num() - 1)) { | |||
| return false; | |||
| } else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->worker_num() - 1)) { | |||
| } else if (node_role == NodeRole::WORKER && (rank_id > ClusterMetadata::instance()->total_worker_num() - 1)) { | |||
| return false; | |||
| } | |||
| return true; | |||
| @@ -20,7 +20,7 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace core { | |||
| void NodeManager::InitNodeNum() { | |||
| total_node_num_ = ClusterMetadata::instance()->server_num() + ClusterMetadata::instance()->worker_num(); | |||
| total_node_num_ = ClusterMetadata::instance()->total_server_num() + ClusterMetadata::instance()->total_worker_num(); | |||
| } | |||
| int NodeManager::NextRankId(const RegisterMessage ®ister_message) { | |||
| @@ -1,179 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ | |||
| #define MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ | |||
| #include <unistd.h> | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include <iostream> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <mutex> | |||
| #include <condition_variable> | |||
| #include <thread> | |||
| #include <cmath> | |||
| #include <random> | |||
| #include <utility> | |||
| #include <list> | |||
| #include <map> | |||
| #include <functional> | |||
| #include "ir/func_graph.h" | |||
| #include "backend/session/session_basic.h" | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "backend/session/session_factory.h" | |||
| #include "ps/optimizer_info.h" | |||
| #include "ps/optimizer_info_builder.h" | |||
| #include "ps/ps_context.h" | |||
| #include "runtime/device/cpu/kernel_select_cpu.h" | |||
| #include "utils/ms_context.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/random_normal/random_normal.h" | |||
| #include "ps/internal/constants.h" | |||
| #include "ps/util.h" | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/core/server_node.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace internal { | |||
| class ParameterServer { | |||
| public: | |||
| static ParameterServer &GetInstance() { | |||
| static ParameterServer instance; | |||
| return instance; | |||
| } | |||
| void Run(const FuncGraphPtr &func_graph); | |||
| private: | |||
| ParameterServer() | |||
| : pserver_num_(0), | |||
| worker_num_(0), | |||
| rank_id_(0), | |||
| grad_accum_count_(0), | |||
| handler_(nullptr), | |||
| func_graph_(nullptr), | |||
| sess_(nullptr), | |||
| running_(true), | |||
| thread_(nullptr) {} | |||
| ~ParameterServer() = default; | |||
| ParameterServer(const ParameterServer &) = delete; | |||
| ParameterServer &operator=(const ParameterServer &) = delete; | |||
| class ServerHandler { | |||
| public: | |||
| explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} | |||
| ~ServerHandler() = default; | |||
| void Init(); | |||
| void operator()(std::shared_ptr<core::TcpConnection> conn, std::shared_ptr<core::MessageMeta> meta, DataPtr data, | |||
| size_t size); | |||
| void HandlePushReq(DataPtr data, size_t size, VectorPtr res); | |||
| void HandlePullReq(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleInitWeights(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res); | |||
| void HandleFinalize(DataPtr data, size_t size, VectorPtr res); | |||
| private: | |||
| ParameterServer *ps_; | |||
| typedef void (ServerHandler::*RequestHandler)(DataPtr data, size_t size, VectorPtr res); | |||
| std::unordered_map<int, RequestHandler> handlers_; | |||
| std::unordered_map<Key, bool> init_weights_; | |||
| std::unordered_map<Key, bool> init_weight_to_optim_; | |||
| std::unordered_map<Key, bool> init_optim_info_; | |||
| }; | |||
| bool Init(const FuncGraphPtr &func_graph); | |||
| void InitOptimInfoBuilders(); | |||
| void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id); | |||
| void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); | |||
| void InitWeight(const Key &key, const WeightPtr &weight); | |||
| void InitGrad(const Key &key, const GradPtr &grad); | |||
| void InitEmbeddingTable(const Key &key, | |||
| const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes, | |||
| const ParamInitInfo ¶m_init_info); | |||
| bool HasWeight(const Key &key); | |||
| void Finalize(); | |||
| void UpdateWeights(); | |||
| void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); | |||
| WeightPtr weight(const Key &key); | |||
| void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res); | |||
| void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals); | |||
| bool ReadyForUpdateWeights(); | |||
| bool ReadyForPush(const Key &key); | |||
| bool ReadyForPull(const Key &key); | |||
| void ResetGradAccumCount(); | |||
| const CNodePtr GetCNode(const std::string &name) const; | |||
| std::mutex &mutex(); | |||
| void GetEmbeddingTableParamPtr(); | |||
| void SyncEmbeddingTables(); | |||
| size_t pserver_num_; | |||
| size_t worker_num_; | |||
| size_t rank_id_; | |||
| size_t grad_accum_count_; | |||
| std::unique_ptr<ServerHandler> handler_; | |||
| FuncGraphPtr func_graph_; | |||
| std::shared_ptr<session::SessionBasic> sess_; | |||
| bool running_; | |||
| std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_; | |||
| std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_; | |||
| std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_; | |||
| std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_; | |||
| std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_; | |||
| std::unordered_map<Key, std::string> weight_key_to_optims_; | |||
| std::unordered_map<Key, std::string> weight_key_to_optim_op_; | |||
| std::unordered_map<Key, WeightPtr> weights_; | |||
| std::unordered_map<Key, bool> is_embedding_; | |||
| std::unordered_map<Key, WeightPtr> grads_; | |||
| std::unordered_map<Key, size_t> grads_accum_counter_; | |||
| std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_; | |||
| std::unordered_map<Key, uint64_t> tokens_; | |||
| std::mutex mutex_; | |||
| std::condition_variable apply_grads_cv_; | |||
| std::unique_ptr<std::thread> thread_; | |||
| core::ServerNode server_node_; | |||
| std::map<Key, ParameterPtr> embedding_tables_; | |||
| friend class ServerHandler; | |||
| }; | |||
| } // namespace internal | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_INTERNAL_PARAMETER_SERVER_H_ | |||
| @@ -1,157 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ | |||
| #define MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <numeric> | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <mutex> | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/tensor.h" | |||
| #include "ps/util.h" | |||
| #include "ps/internal/constants.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/core/worker_node.h" | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace internal { | |||
| class Worker { | |||
| public: | |||
| static Worker &GetInstance() { | |||
| static Worker instance; | |||
| return instance; | |||
| } | |||
| using Callback = std::function<void()>; | |||
| using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>; | |||
| using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>; | |||
| using EmbeddingPartitioner = std::function<void( | |||
| const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>; | |||
| using KVPartitioner = | |||
| std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>; | |||
| void Run(); | |||
| void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); | |||
| void Pull(const size_t key, void *dev_addr, const size_t size); | |||
| size_t SetParamKey(const std::string ¶m_name); | |||
| size_t GetParamKey(const std::string ¶m_name); | |||
| void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); | |||
| bool GetParamInitInServer(const std::string ¶m_name); | |||
| void SetKeyOptimId(size_t key, const std::string &optimizer_name); | |||
| void SetOptimInputShapes(size_t key, const ShapeVector &shape); | |||
| void AddEmbeddingTable(const Key &key, const size_t &row_count); | |||
| void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape, | |||
| const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape); | |||
| void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); | |||
| void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result, | |||
| int64_t cmd); | |||
| void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids, | |||
| const std::vector<float> &vals); | |||
| bool running() { return running_; } | |||
| void Finalize(); | |||
| private: | |||
| Worker() : running_(false), key_cnt_(0) {} | |||
| ~Worker() = default; | |||
| Worker(const Worker &) = delete; | |||
| Worker &operator=(const Worker &) = delete; | |||
| void Initialize(); | |||
| bool IsKeyInit(const size_t key); | |||
| void AddKeyToServerId(const Key &key); | |||
| void AddKeyByHashMod(const Key &key); | |||
| void InitPSOptimId(const size_t param_key); | |||
| void InitPSOptimInputShapes(const size_t key); | |||
| void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | |||
| bool IsReadyForPush(const Key &key); | |||
| bool IsReadyForPull(const Key &key); | |||
| void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, | |||
| const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice, | |||
| const size_t segment_size, float *gradient, int *indices); | |||
| void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index, | |||
| const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data); | |||
| void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {}, | |||
| int command = 0, int64_t priority = 0); | |||
| void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, | |||
| size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); | |||
| void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0, | |||
| int64_t priority = 0); | |||
| void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, | |||
| const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens); | |||
| int64_t server_num_; | |||
| bool running_; | |||
| std::mutex running_mutex_; | |||
| size_t key_cnt_; | |||
| std::map<std::string, size_t> param_to_key_; | |||
| std::map<size_t, bool> init_keys_; | |||
| std::map<size_t, int64_t> key_to_optimId_; | |||
| std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_; | |||
| std::map<std::string, bool> param_to_init_in_server_; | |||
| core::WorkerNode worker_node_; | |||
| EmbeddingPartitioner lookup_partitioner_; | |||
| KVPartitioner sparse_partitioner_; | |||
| KVPartitioner round_robin_partitioner_; | |||
| KVPartitioner worker_init_embedding_partitioner_; | |||
| KVPartitioner update_embedding_partitioner_; | |||
| KVPartitioner broadcast_partitioner_; | |||
| std::unordered_map<Key, int64_t> key_to_server_id_; | |||
| std::unordered_map<Key, size_t> embedding_row_cnt_; | |||
| std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_; | |||
| }; | |||
| static Worker &worker = Worker::GetInstance(); | |||
| } // namespace internal | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_INTERNAL_WORKER_H_ | |||
| @@ -84,7 +84,7 @@ void DenseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { | |||
| for (size_t i = 0; i < grad_index; i++) { | |||
| grad_offset += lengths[i]; | |||
| } | |||
| float *grad_data = values.data() + grad_offset; | |||
| float *grad_data = const_cast<float *>(values.data()) + grad_offset; | |||
| CHECK_EQ(size, static_cast<size_t>(lengths[grad_index])); | |||
| for (size_t i = 0; i < size; i++) { | |||
| @@ -121,7 +121,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { | |||
| for (size_t i = 0; i < grad_index; i++) { | |||
| grad_offset += lengths[i]; | |||
| } | |||
| float *incr_grad_data = values.data() + grad_offset; | |||
| float *incr_grad_data = const_cast<float *>(values.data()) + grad_offset; | |||
| MS_EXCEPTION_IF_NULL(incr_grad_data); | |||
| size_t incr_grad_size = lengths[grad_index] * sizeof(float); | |||
| @@ -148,7 +148,11 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { | |||
| for (size_t i = 0; i < indices_index; i++) { | |||
| indice_offset += lengths[i]; | |||
| } | |||
| int *incr_indice_data = reinterpret_cast<int *>(values.data()) + indice_offset; | |||
| void *incr_indice_data_temp = const_cast<float *>(values.data()) + indice_offset; | |||
| int *incr_indice_data = reinterpret_cast<int *>(incr_indice_data_temp); | |||
| MS_EXCEPTION_IF_NULL(incr_indice_data); | |||
| size_t incr_indice_size = lengths[indices_index]; | |||
| size_t incr_indice_data_size = incr_indice_size * sizeof(int); | |||
| @@ -259,7 +263,7 @@ MomentumOptimInfo::MomentumOptimInfo(const AddressPtr &weight, const AddressPtr | |||
| } | |||
| void MomentumOptimInfo::Update(const Values &values, const Lengths &lens) { | |||
| UpdateOptimInputValue<float>(kApplyMomentum, "lr", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens); | |||
| } | |||
| const size_t SparseOptimInfo::indice_size() const { return indices_offset_; } | |||
| @@ -303,12 +307,12 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address | |||
| } | |||
| void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "lr", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta1", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta2", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "eps", values.data(), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens); | |||
| UpdateOptimInputValue<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens); | |||
| } | |||
| const AddressPtr &SparseAdamOptimInfo::gradient() { | |||
| @@ -20,7 +20,7 @@ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "ps/common.h" | |||
| #include "ps/constants.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -129,9 +129,9 @@ OptimizerInfo *MomentumOptimInfoBuilder::BuildInputs(const WeightPtr &weight, co | |||
| return nullptr; | |||
| } | |||
| AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", values.data(), lens); | |||
| AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", values.data(), lens); | |||
| AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", values.data(), lens); | |||
| AddressPtr learning_rate = GenInputAddrPtr<float>(kApplyMomentum, "lr", const_cast<float *>(values.data()), lens); | |||
| AddressPtr gradient = GenInputAddrPtr<float>(kApplyMomentum, "grad", const_cast<float *>(values.data()), lens); | |||
| AddressPtr momentum = GenInputAddrPtr<float>(kApplyMomentum, "momentum", const_cast<float *>(values.data()), lens); | |||
| return new MomentumOptimInfo(weight_addr, accumulate, learning_rate, gradient, momentum); | |||
| } | |||
| @@ -172,14 +172,15 @@ OptimizerInfo *SparseAdamOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| return nullptr; | |||
| } | |||
| AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", values.data(), lens); | |||
| AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", values.data(), lens); | |||
| AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", values.data(), lens); | |||
| AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", values.data(), lens); | |||
| AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", values.data(), lens); | |||
| AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", values.data(), lens); | |||
| AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", values.data(), lens, inputs_shape); | |||
| AddressPtr indices = GenInputAddrPtr<float>(kSparseAdam, "indices", values.data(), lens, inputs_shape); | |||
| AddressPtr beta1_power = GenInputAddrPtr<float>(kSparseAdam, "beta1_power", const_cast<float *>(values.data()), lens); | |||
| AddressPtr beta2_power = GenInputAddrPtr<float>(kSparseAdam, "beta2_power", const_cast<float *>(values.data()), lens); | |||
| AddressPtr learning_rate = GenInputAddrPtr<float>(kSparseAdam, "lr", const_cast<float *>(values.data()), lens); | |||
| AddressPtr beta1 = GenInputAddrPtr<float>(kSparseAdam, "beta1", const_cast<float *>(values.data()), lens); | |||
| AddressPtr beta2 = GenInputAddrPtr<float>(kSparseAdam, "beta2", const_cast<float *>(values.data()), lens); | |||
| AddressPtr epsilon = GenInputAddrPtr<float>(kSparseAdam, "eps", const_cast<float *>(values.data()), lens); | |||
| AddressPtr grad = GenInputAddrPtr<float>(kSparseAdam, "grad", const_cast<float *>(values.data()), lens, inputs_shape); | |||
| AddressPtr indices = | |||
| GenInputAddrPtr<float>(kSparseAdam, "indices", const_cast<float *>(values.data()), lens, inputs_shape); | |||
| return new SparseAdamOptimInfo(weight_addr, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, | |||
| grad, indices, sharded); | |||
| } | |||
| @@ -218,8 +219,9 @@ OptimizerInfo *SparseFtrlOptimInfoBuilder::BuildInputs(const WeightPtr &weight, | |||
| } | |||
| linear->size = weight->size() * sizeof(float); | |||
| AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", values.data(), lens, inputs_shape); | |||
| AddressPtr indices = GenInputAddrPtr<float>(kSparseFtrl, "indices", values.data(), lens, inputs_shape); | |||
| AddressPtr grad = GenInputAddrPtr<float>(kSparseFtrl, "grad", const_cast<float *>(values.data()), lens, inputs_shape); | |||
| AddressPtr indices = | |||
| GenInputAddrPtr<float>(kSparseFtrl, "indices", const_cast<float *>(values.data()), lens, inputs_shape); | |||
| return new SparseFtrlOptimInfo(weight_addr, accum, linear, grad, indices, sharded); | |||
| } | |||
| } // namespace ps | |||
| @@ -14,12 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/internal/parameter_server.h" | |||
| #include "ps/parameter_server.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace internal { | |||
| void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_LOG(INFO) << "PServer starts connecting to scheduler and workers..."; | |||
| @@ -44,8 +42,8 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) { | |||
| } | |||
| bool ParameterServer::Init(const FuncGraphPtr &func_graph) { | |||
| pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); | |||
| worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); | |||
| pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); | |||
| worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); | |||
| func_graph_ = func_graph; | |||
| handler_.reset(new ServerHandler(this)); | |||
| handler_->Init(); | |||
| @@ -257,12 +255,21 @@ void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Le | |||
| std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key]; | |||
| // Create or update the optimizer info | |||
| std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key]; | |||
| if (pserver_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; | |||
| if (optim_info == nullptr) { | |||
| const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]]; | |||
| std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key]; | |||
| if (pserver_kernel == nullptr) { | |||
| MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pserver_kernel); | |||
| OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, | |||
| optim_inputs_shape_[key], worker_num_, is_embedding_[key]); | |||
| optim_info.reset(optim); | |||
| optim_infos_[key] = optim_info; | |||
| } else { | |||
| optim_info->Update(values, lengths); | |||
| optim_info->Accumulate(values, lengths); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pserver_kernel); | |||
| optim_infos_[key] = optim_info; | |||
| } | |||
| grads_accum_counter_[key] += 1; | |||
| @@ -373,7 +380,7 @@ inline bool ParameterServer::ReadyForPush(const Key &key) { | |||
| MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send " | |||
| "kInitWeightsCmd command. 2.The Server failed to initialize weights."; | |||
| } | |||
| MS_LOG(INFO) << "the grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size() | |||
| MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size() | |||
| << " the token:" << (tokens_[key] <= 0); | |||
| return grad_accum_count_ < weights_.size() && tokens_[key] <= 0; | |||
| } | |||
| @@ -544,11 +551,9 @@ void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size | |||
| for (int i = 0; i < key_num; i++) { | |||
| Key key = input.keys()[i]; | |||
| size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i]; | |||
| MS_LOG(DEBUG) << "The data len:" << data_len; | |||
| if (!ps_->HasWeight(key)) { | |||
| WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len)); | |||
| MS_LOG(DEBUG) << "The weight ptr:" << *weight_ptr; | |||
| MS_EXCEPTION_IF_NULL(weight_ptr); | |||
| ps_->InitWeight(key, weight_ptr); | |||
| @@ -637,7 +642,7 @@ void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_ | |||
| input.ParseFromArray(data.get(), size); | |||
| const Key &key = input.keys()[0]; | |||
| bool ready = ps_->ReadyForPush(key); | |||
| MS_LOG(INFO) << "the ready is:" << ready; | |||
| MS_LOG(INFO) << "The ready is:" << ready; | |||
| KVMessage res_data; | |||
| res_data.add_keys(key); | |||
| res_data.add_values(ready); | |||
| @@ -671,7 +676,6 @@ void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t | |||
| EmbeddingTableLookup input; | |||
| input.ParseFromArray(data.get(), size); | |||
| const Key &key = input.key(); | |||
| MS_LOG(DEBUG) << "The key is:" << key; | |||
| KVMessage res_data; | |||
| std::vector<Key> keys = {input.keys().begin(), input.keys().end()}; | |||
| @@ -701,6 +705,5 @@ void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, V | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| ps_->Finalize(); | |||
| } | |||
| } // namespace internal | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -145,7 +145,6 @@ const size_t &PsCacheManager::QueryHashTableSize(const std::string ¶m_name) | |||
| void PsCacheManager::Initialize() { | |||
| MS_LOG(INFO) << "PS cache initialize."; | |||
| if (!worker.running()) { | |||
| Util::SetInternalEnvVar(); | |||
| worker.Run(); | |||
| } | |||
| embedding_device_cache_ = std::make_shared<EmbeddingDeviceCache>(batch_elements_, vocab_cache_size_); | |||
| @@ -177,22 +176,19 @@ void PsCacheManager::InitParameterServer() { | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| std::vector<size_t> keys{key, key, key, key, key, key}; | |||
| std::vector<float> values{ | |||
| SizeToFloat(item.second.vocab_size), SizeToFloat(item.second.embedding_size), 1, 1, 1, 1, 1}; | |||
| std::vector<int64_t> lens{2, 2, 3}; | |||
| const auto &hash_table_info = item.second; | |||
| const auto ¶m_init_info = hash_table_info.param_init_info_; | |||
| if (param_init_info.param_type_ == kWeight) { | |||
| lens.push_back(1); | |||
| } else if (param_init_info.param_type_ == kAccumulation) { | |||
| lens.push_back(2); | |||
| } | |||
| values.push_back(param_init_info.init_val_); | |||
| lens.push_back(param_init_info.global_seed_); | |||
| lens.push_back(param_init_info.op_seed_); | |||
| std::vector<size_t> input_shape = {item.second.vocab_size, item.second.embedding_size}; | |||
| std::vector<size_t> indices_shape = {1, 1}; | |||
| std::vector<size_t> output_shape = {1, 1, 1}; | |||
| ParamInitInfoMessage info; | |||
| info.set_param_type(param_init_info.param_type_); | |||
| info.set_init_val(param_init_info.init_val_); | |||
| info.set_global_seed(param_init_info.global_seed_); | |||
| info.set_op_seed(param_init_info.op_seed_); | |||
| // if worker role | |||
| worker.InitPSEmbeddingTable(keys, values, lens); | |||
| worker.InitPSEmbeddingTable(key, input_shape, indices_shape, output_shape, info); | |||
| } | |||
| finish_init_parameter_server_ = true; | |||
| @@ -245,7 +241,7 @@ void PsCacheManager::AllocMemForHashTable() { | |||
| } | |||
| void PsCacheManager::SetLocalIdRank() { | |||
| auto worker_num = ::ps::NumWorkers(); | |||
| auto worker_num = PSContext::instance()->initial_worker_num(); | |||
| auto local_shard_size = FloatToInt(std::ceil(SizeToFloat(vocab_size_) / worker_num)); | |||
| vocab_cache_size_diff_ = local_shard_size - SizeToInt(vocab_cache_size_); | |||
| emb_table_slice_bounds_.first = local_shard_size * rank_id_; | |||
| @@ -829,8 +825,8 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_ | |||
| if (swap_indices_size == 0) { | |||
| return true; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||
| ::ps::SArray<float> swap_out_data; | |||
| std::vector<int> lookup_ids(swap_indices_size, 0); | |||
| std::vector<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data.resize(swap_indices_size * embedding_size); | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| @@ -857,22 +853,21 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ | |||
| } | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto embedding_size = hash_info.embedding_size; | |||
| ::ps::SArray<int> lengths{swap_indices_size}; | |||
| ::ps::SArray<float> lookup_result(swap_indices_size * embedding_size, 0); | |||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||
| std::vector<float> lookup_result(swap_indices_size * embedding_size, 0); | |||
| std::vector<int> lookup_ids(swap_indices_size, 0); | |||
| auto copy_len = swap_indices_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, | |||
| lookup_result.data(), host_hash_table_addr)); | |||
| return true; | |||
| } | |||
| bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||
| bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, | |||
| const HashTableInfo &hash_info) { | |||
| MS_ERROR_IF_NULL(swap_out_index); | |||
| MS_ERROR_IF_NULL(swap_out_data); | |||
| @@ -912,16 +907,15 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons | |||
| auto cache_vocab_size = hash_info.cache_vocab_size; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). | |||
| ::ps::SArray<int> lengths{swap_in_ids_size}; | |||
| ::ps::SArray<float> lookup_result(swap_in_ids_size * embedding_size, 0); | |||
| ::ps::SArray<int> lookup_ids(swap_in_ids_size, 0); | |||
| std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0); | |||
| std::vector<int> lookup_ids(swap_in_ids_size, 0); | |||
| auto copy_len = swap_in_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| worker.DoPSEmbeddingLookup(key, lookup_ids, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| // Hash swap-in in device. | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), | |||
| @@ -934,7 +928,7 @@ bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons | |||
| return true; | |||
| } | |||
| bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||
| bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(swap_out_ids); | |||
| @@ -942,7 +936,7 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da | |||
| if (swap_out_ids_size == 0) { | |||
| return true; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | |||
| std::vector<int> lookup_ids(swap_out_ids_size, 0); | |||
| auto copy_len = swap_out_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | |||
| if (ret != EOK) { | |||
| @@ -994,8 +988,8 @@ bool PsCacheManager::SyncHostEmbeddingTable() { | |||
| continue; | |||
| } | |||
| auto key = worker.GetParamKey(item.first); | |||
| ::ps::SArray<int> lookup_ids(swap_indices_lens, 0); | |||
| ::ps::SArray<float> swap_out_data; | |||
| std::vector<int> lookup_ids(swap_indices_lens, 0); | |||
| std::vector<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data.resize(swap_indices_lens * embedding_size); | |||
| auto host_hash_table_addr = hash_info.host_address.get(); | |||
| @@ -1038,8 +1032,8 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { | |||
| continue; | |||
| } | |||
| auto key = worker.GetParamKey(item.first); | |||
| ::ps::SArray<int> lookup_ids(swap_indices_lens, 0); | |||
| ::ps::SArray<float> swap_out_data; | |||
| std::vector<int> lookup_ids(swap_indices_lens, 0); | |||
| std::vector<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data.resize(swap_indices_lens * embedding_size); | |||
| std::unique_ptr<float[]> device_hash_table_addr_tmp = | |||
| @@ -29,9 +29,9 @@ | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ir/tensor.h" | |||
| #include "ps/ps.h" | |||
| #include "ps/common.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_context.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/ps_cache/embedding_hash_map.h" | |||
| #include "ps/ps_cache/ps_cache_factory.h" | |||
| @@ -155,7 +155,7 @@ class PsCacheManager { | |||
| bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index); | |||
| bool ParseHostDataHostToDevice(size_t id); | |||
| bool ParseHostDataDeviceToHost(); | |||
| bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| bool HashSwapHostToDevice(const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||
| @@ -165,7 +165,7 @@ class PsCacheManager { | |||
| float *hash_table_addr); | |||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool CheckFinishInsertInitInfo() const; | |||
| @@ -48,10 +48,10 @@ void PSContext::SetPSEnable(bool enabled) { | |||
| MS_LOG(WARNING) << "MS_ROLE is " << ms_role << ", which is invalid."; | |||
| } | |||
| worker_num_ = std::strtol(common::GetEnv("MS_WORKER_NUM").c_str(), nullptr, 10); | |||
| server_num_ = std::strtol(common::GetEnv("MS_SERVER_NUM").c_str(), nullptr, 10); | |||
| scheduler_host_ = common::GetEnv("MS_SCHED_HOST"); | |||
| scheduler_port_ = std::strtol(common::GetEnv("MS_SCHED_PORT").c_str(), nullptr, 10); | |||
| worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10); | |||
| server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10); | |||
| scheduler_host_ = common::GetEnv(kEnvSchedulerHost); | |||
| scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, 10); | |||
| } else { | |||
| MS_LOG(INFO) << "PS mode is disabled."; | |||
| is_worker_ = false; | |||
| @@ -19,6 +19,7 @@ | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ps/constants.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| @@ -15,13 +15,16 @@ | |||
| */ | |||
| #include "ps/scheduler.h" | |||
| #include "ps/ps.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| void Scheduler::Run() { | |||
| ::ps::Start(0); | |||
| ::ps::Finalize(0, true); | |||
| core::ClusterMetadata::instance()->Init( | |||
| PSContext::instance()->initial_worker_num(), PSContext::instance()->initial_server_num(), | |||
| PSContext::instance()->scheduler_host(), PSContext::instance()->scheduler_port()); | |||
| scheduler_node_.Start(); | |||
| scheduler_node_.Finish(); | |||
| scheduler_node_.Stop(); | |||
| exit(1); | |||
| } | |||
| } // namespace ps | |||
| @@ -16,6 +16,11 @@ | |||
| #ifndef MINDSPORE_CCSRC_PS_SCHEDULER_H_ | |||
| #define MINDSPORE_CCSRC_PS_SCHEDULER_H_ | |||
| #include "ps/core/scheduler_node.h" | |||
| #include "ps/util.h" | |||
| #include "ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| class Scheduler { | |||
| @@ -32,6 +37,7 @@ class Scheduler { | |||
| ~Scheduler() = default; | |||
| Scheduler(const Scheduler &) = delete; | |||
| Scheduler &operator=(const Scheduler &) = delete; | |||
| core::SchedulerNode scheduler_node_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -17,7 +17,7 @@ | |||
| #include "ps/util.h" | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ps/common.h" | |||
| #include "ps/constants.h" | |||
| #include "ps/ps_context.h" | |||
| #include "utils/ms_utils.h" | |||
| @@ -46,50 +46,10 @@ std::unordered_map<int64_t, std::string> Util::id_to_optimizer_nodes{ | |||
| {3, kSparseFtrlOp}, | |||
| }; | |||
| bool Util::IsParamServerMode() { return PSContext::instance()->is_ps_mode(); } | |||
| bool Util::IsRoleOfWorker() { return PSContext::instance()->is_worker(); } | |||
| bool Util::IsRoleOfPServer() { return PSContext::instance()->is_server(); } | |||
| bool Util::IsRoleOfScheduler() { return PSContext::instance()->is_scheduler(); } | |||
| void Util::SetInternalEnvVar() { | |||
| if (IsParamServerMode()) { | |||
| auto comm_type = common::GetEnv(kEnvCommType); | |||
| if (!comm_type.empty()) { | |||
| (void)common::SetEnv(kDmlcCommType, comm_type.c_str()); | |||
| } | |||
| auto interface = common::GetEnv(kEnvInterface); | |||
| if (!interface.empty()) { | |||
| (void)common::SetEnv(kDmlcInterface, interface.c_str()); | |||
| } | |||
| auto server_num = common::GetEnv(kEnvPServerNum); | |||
| if (!server_num.empty()) { | |||
| (void)common::SetEnv(kDmlcPServerNum, server_num.c_str()); | |||
| } | |||
| auto worker_num = common::GetEnv(kEnvWorkerNum); | |||
| if (!worker_num.empty()) { | |||
| (void)common::SetEnv(kDmlcWorkerNum, worker_num.c_str()); | |||
| } | |||
| if (IsRoleOfScheduler()) { | |||
| (void)common::SetEnv(kDmlcRole, kRoleOfScheduler); | |||
| } else if (IsRoleOfPServer()) { | |||
| (void)common::SetEnv(kDmlcRole, kRoleOfPServer); | |||
| } else if (IsRoleOfWorker()) { | |||
| (void)common::SetEnv(kDmlcRole, kRoleOfWorker); | |||
| } | |||
| auto scheduler_host = common::GetEnv(kEnvSchedulerHost); | |||
| if (!scheduler_host.empty()) { | |||
| (void)common::SetEnv(kDmlcSchedulerHost, scheduler_host.c_str()); | |||
| } | |||
| auto scheduler_port = common::GetEnv(kEnvSchedulerPort); | |||
| if (!scheduler_port.empty()) { | |||
| (void)common::SetEnv(kDmlcSchedulerPort, scheduler_port.c_str()); | |||
| } | |||
| } | |||
| } | |||
| int64_t Util::optimizer_id(std::string name) { | |||
| if (optimizer_to_ids.count(name) > 0) { | |||
| return optimizer_to_ids[name]; | |||
| @@ -37,11 +37,8 @@ struct ParamInitInfo { | |||
| class Util { | |||
| public: | |||
| static bool IsParamServerMode(); | |||
| static bool IsRoleOfWorker(); | |||
| static bool IsRoleOfPServer(); | |||
| static bool IsRoleOfScheduler(); | |||
| static void SetInternalEnvVar(); | |||
| static int64_t optimizer_id(std::string name); | |||
| static std::string optimizer_name(int64_t id); | |||
| static std::string optimizer_node_name(int64_t id); | |||
| @@ -14,11 +14,10 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "ps/internal/worker.h" | |||
| #include "ps/worker.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| namespace internal { | |||
| void Worker::Run() { | |||
| std::lock_guard<std::mutex> lock(running_mutex_); | |||
| core::ClusterMetadata::instance()->Init( | |||
| @@ -198,7 +197,8 @@ void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) { | |||
| } | |||
| void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape, | |||
| const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape) { | |||
| const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape, | |||
| const ParamInitInfoMessage &info) { | |||
| bool has_init = IsKeyInit(key); | |||
| if (has_init) { | |||
| MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized."; | |||
| @@ -210,6 +210,7 @@ void Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> & | |||
| *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()}; | |||
| *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()}; | |||
| *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()}; | |||
| *embedding_table_meta.mutable_info() = info; | |||
| std::string kv_data = embedding_table_meta.SerializeAsString(); | |||
| @@ -295,19 +296,18 @@ void Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ | |||
| int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size()); | |||
| std::unordered_map<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map; | |||
| std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>(); | |||
| int64_t value_offset = 0; | |||
| for (size_t i = 0; i < resp.size(); ++i) { | |||
| KVMessage message; | |||
| message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()); | |||
| int64_t offset = 0; | |||
| values->clear(); | |||
| for (auto j = 0; j < message.values_size(); j++) { | |||
| values->push_back(message.values(j)); | |||
| } | |||
| MS_LOG(DEBUG) << "the embedding resp:" << values; | |||
| MS_LOG(DEBUG) << "The embedding resp:" << values; | |||
| for (auto k = 0; k < message.keys_size(); k++) { | |||
| const Key &key = message.keys(k); | |||
| float *addr = values->data() + offset; | |||
| offset += single_id_len; | |||
| float *addr = values->data() + value_offset; | |||
| value_offset += single_id_len; | |||
| id_addr_map[key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len)); | |||
| } | |||
| } | |||
| @@ -969,6 +969,5 @@ void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &pa | |||
| } | |||
| } | |||
| } | |||
| } // namespace internal | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -25,24 +25,38 @@ | |||
| #include <functional> | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include "ps/ps.h" | |||
| #include <mutex> | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include "utils/log_adapter.h" | |||
| #include "ir/tensor.h" | |||
| #include "ps/util.h" | |||
| #include "ps/common.h" | |||
| #include "ps/worker_proxy.h" | |||
| #include "ps/constants.h" | |||
| #include "utils/shape_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/core/worker_node.h" | |||
| #include "ps/embedding_table_shard_metadata.h" | |||
| #include "proto/comm.pb.h" | |||
| #include "proto/ps.pb.h" | |||
| #include "ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| template <typename T> | |||
| class Worker { | |||
| public: | |||
| static Worker &GetInstance() { | |||
| static Worker instance; | |||
| return instance; | |||
| } | |||
| using Callback = std::function<void()>; | |||
| using PartitionEmbeddingMessages = std::vector<std::pair<bool, EmbeddingTableLookup>>; | |||
| using PartitionKVMessages = std::vector<std::pair<bool, KVMessage>>; | |||
| using EmbeddingPartitioner = std::function<void( | |||
| const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map<int64_t, int64_t> &attrs)>; | |||
| using KVPartitioner = | |||
| std::function<void(const KVMessage &send, PartitionKVMessages *partition, const std::map<int64_t, int64_t> &attrs)>; | |||
| void Run(); | |||
| void Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes); | |||
| @@ -53,340 +67,89 @@ class Worker { | |||
| bool GetParamInitInServer(const std::string ¶m_name); | |||
| void SetKeyOptimId(size_t key, const std::string &optimizer_name); | |||
| void SetOptimInputShapes(size_t key, const ShapeVector &shape); | |||
| void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | |||
| void InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes); | |||
| void AddEmbeddingTable(const Key &key, const size_t &row_count); | |||
| void InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape, | |||
| const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape, | |||
| const ParamInitInfoMessage &info); | |||
| void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); | |||
| void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd); | |||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals); | |||
| void DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result, | |||
| int64_t cmd); | |||
| void UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids, | |||
| const std::vector<float> &vals); | |||
| bool running() { return running_; } | |||
| void Finalize(); | |||
| private: | |||
| Worker() : kv_worker_(nullptr), running_(false), key_cnt_(0) {} | |||
| Worker() : running_(false), key_cnt_(0) {} | |||
| ~Worker() = default; | |||
| Worker(const Worker &) = delete; | |||
| Worker &operator=(const Worker &) = delete; | |||
| void Initialize(); | |||
| bool IsKeyInit(const size_t key); | |||
| void AddKeyToServerId(const Key &key); | |||
| void AddKeyByHashMod(const Key &key); | |||
| void InitPSOptimId(const size_t param_key); | |||
| void InitPSOptimInputShapes(const size_t key); | |||
| void InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size); | |||
| static void EmbeddingLookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {} | |||
| std::shared_ptr<WorkerProxy<T>> kv_worker_; | |||
| bool IsReadyForPush(const Key &key); | |||
| bool IsReadyForPull(const Key &key); | |||
| void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids, | |||
| const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice, | |||
| const size_t segment_size, float *gradient, int *indices); | |||
| void BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index, | |||
| const float *original_data, const float *grads, int *indices, std::vector<float> *reduced_data); | |||
| void PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens = {}, | |||
| int command = 0, int64_t priority = 0); | |||
| void PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens, | |||
| size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); | |||
| void PullData(const std::vector<Key> &keys, std::vector<float> *vals, std::vector<int> *lens = nullptr, int cmd = 0, | |||
| int64_t priority = 0); | |||
| void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, | |||
| const std::map<int64_t, int64_t> &attrs); | |||
| void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, | |||
| const std::map<int64_t, int64_t> &attrs, std::vector<float> *vals, std::vector<int> *lens); | |||
| int64_t server_num_; | |||
| bool running_; | |||
| std::mutex running_mutex_; | |||
| size_t key_cnt_; | |||
| std::map<std::string, size_t> param_to_key_; | |||
| std::map<size_t, bool> init_keys_; | |||
| std::map<size_t, int64_t> key_to_optimId_; | |||
| std::map<size_t, std::vector<ShapeVector>> key_to_optim_shapes_; | |||
| std::map<std::string, bool> param_to_init_in_server_; | |||
| core::WorkerNode worker_node_; | |||
| EmbeddingPartitioner lookup_partitioner_; | |||
| KVPartitioner sparse_partitioner_; | |||
| KVPartitioner round_robin_partitioner_; | |||
| KVPartitioner worker_init_embedding_partitioner_; | |||
| KVPartitioner update_embedding_partitioner_; | |||
| KVPartitioner broadcast_partitioner_; | |||
| std::unordered_map<Key, int64_t> key_to_server_id_; | |||
| std::unordered_map<Key, size_t> embedding_row_cnt_; | |||
| std::unordered_map<Key, std::shared_ptr<std::vector<EmbeddingTableShardMetadata>>> embedding_table_ranges_; | |||
| }; | |||
| template <typename T> | |||
| void Worker<T>::Run() { | |||
| if (running_) { | |||
| MS_LOG(INFO) << "'Worker is already running."; | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Worker starts connecting to scheduler and server..."; | |||
| ::ps::Start(0); | |||
| MS_LOG(INFO) << "Worker connected successfully."; | |||
| if (!::ps::IsWorker()) { | |||
| MS_LOG(EXCEPTION) << "The role is not worker."; | |||
| } | |||
| kv_worker_ = std::make_shared<WorkerProxy<T>>(0, 0, 1, 2); | |||
| running_ = true; | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) { | |||
| if (keys.size() == 0) { | |||
| MS_LOG(EXCEPTION) << "key size should be greater than zero"; | |||
| } | |||
| if (key_to_optimId_.count(keys[0]) == 0) { | |||
| MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0]; | |||
| } | |||
| Key key = keys[0]; | |||
| int64_t optim_id = key_to_optimId_[key]; | |||
| bool is_sparse = false; | |||
| if (optim_id == 1 || optim_id == 2 || optim_id == 3) { | |||
| is_sparse = true; | |||
| } | |||
| int64_t grad_index = -1; | |||
| int64_t indice_index = -1; | |||
| // Sparse adam gradient | |||
| if (optim_id == 1 || optim_id == 2) { | |||
| grad_index = 6; | |||
| indice_index = 7; | |||
| // Sparse ftrl gradient | |||
| } else if (optim_id == 3) { | |||
| grad_index = 0; | |||
| indice_index = 1; | |||
| } | |||
| size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>()); | |||
| ::ps::SArray<T> total_buffer(total_size, 0); | |||
| size_t offset = 0; | |||
| size_t dst_size = 0; | |||
| size_t src_size = 0; | |||
| for (size_t i = 0; i < sizes.size(); i++) { | |||
| void *dst_data = total_buffer.data() + offset / sizeof(T); | |||
| void *src_data = reinterpret_cast<void *>(addrs[i]); | |||
| MS_EXCEPTION_IF_NULL(dst_data); | |||
| MS_EXCEPTION_IF_NULL(src_data); | |||
| dst_size = sizes[i] * sizeof(T); | |||
| src_size = sizes[i] * sizeof(T); | |||
| auto ret = memcpy_s(dst_data, dst_size, src_data, src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return; | |||
| } | |||
| offset += sizes[i] * sizeof(T); | |||
| } | |||
| while (!kv_worker_->IsReadyForPush(keys[0])) { | |||
| continue; | |||
| } | |||
| std::vector<int> sizes_int; | |||
| (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| if (!is_sparse) { | |||
| kv_worker_->PushData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int)); | |||
| } else { | |||
| std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0]; | |||
| int64_t first_dim_size = var_shape[0]; | |||
| int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>()); | |||
| kv_worker_->PushSparseData(::ps::SArray<::ps::Key>(keys), total_buffer, ::ps::SArray<int>(sizes_int), grad_index, | |||
| indice_index, first_dim_size, outer_dim_size); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::Pull(const size_t key, void *dev_addr, const size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dev_addr); | |||
| ::ps::SArray<T> variables(size / sizeof(T), 0); | |||
| while (!kv_worker_->IsReadyForPull(key)) { | |||
| continue; | |||
| } | |||
| kv_worker_->PullData({key}, &variables); | |||
| size_t dst_size = size; | |||
| size_t src_size = size; | |||
| auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size); | |||
| if (ret != 0) { | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; | |||
| return; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *lookup_result, int64_t cmd) { | |||
| MS_EXCEPTION_IF_NULL(lookup_result); | |||
| kv_worker_->EmbeddingLookup(keys, lookup_ids, lens, lookup_result, cmd); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals) { | |||
| kv_worker_->UpdateEmbeddingTable(keys, lookup_ids, vals); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::Finalize() { | |||
| if (running_) { | |||
| MS_LOG(INFO) << "Worker starts finalizing..."; | |||
| kv_worker_->Finalize(); | |||
| kv_worker_.reset(); | |||
| running_ = false; | |||
| MS_LOG(INFO) << "Worker finalized successfully."; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSParamData(const std::vector<size_t> &keys, void *origin_addr, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(origin_addr); | |||
| ::ps::SArray<T> addr(reinterpret_cast<T *>(origin_addr), size / sizeof(T)); | |||
| ::ps::SArray<::ps::Key> key(keys); | |||
| ::ps::SArray<int> lens; | |||
| lens.push_back(addr.size()); | |||
| kv_worker_->PushData(key, addr, lens, kInitWeightsCmd); | |||
| init_keys_[key[0]] = true; | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::SetOptimInputShapes(size_t key, const ShapeVector &shape) { | |||
| if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { | |||
| key_to_optim_shapes_[key] = {shape}; | |||
| } else { | |||
| key_to_optim_shapes_[key].push_back(shape); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSOptimInputShapes(const size_t key) { | |||
| ::ps::SArray<::ps::Key> keys; | |||
| ::ps::SArray<int> shape_len; | |||
| ::ps::SArray<T> all_shape; | |||
| std::vector<ShapeVector> shapes = key_to_optim_shapes_[key]; | |||
| for (auto shape : shapes) { | |||
| keys.push_back(key); | |||
| if (shape.size() == 0) { | |||
| shape_len.push_back(1); | |||
| all_shape.push_back(1); | |||
| } else { | |||
| shape_len.push_back(SizeToLong(shape.size())); | |||
| for (auto dim : shape) { | |||
| all_shape.push_back(static_cast<T>(dim)); | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "keys:" << keys; | |||
| MS_LOG(INFO) << "shape_len:" << shape_len; | |||
| MS_LOG(INFO) << "all_shape:" << all_shape; | |||
| if (!init_keys_[key]) { | |||
| init_keys_[key] = true; | |||
| } | |||
| kv_worker_->PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd); | |||
| } | |||
| template <typename T> | |||
| bool Worker<T>::IsKeyInit(const size_t key) { | |||
| if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| template <typename T> | |||
| size_t Worker<T>::SetParamKey(const std::string ¶m_name) { | |||
| size_t key = UINT64_MAX; | |||
| if (param_to_key_.count(param_name)) { | |||
| key = param_to_key_[param_name]; | |||
| MS_LOG(INFO) << param_name << " key is already set: key value is " << key; | |||
| } else { | |||
| key = key_cnt_++; | |||
| param_to_key_[param_name] = key; | |||
| MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name; | |||
| } | |||
| return key; | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::SetParamInitInServer(const std::string ¶m_name, bool init_in_server) { | |||
| MS_LOG(INFO) << "Set parameter " << param_name << " init_in_server:" << init_in_server; | |||
| param_to_init_in_server_[param_name] = init_in_server; | |||
| } | |||
| template <typename T> | |||
| bool Worker<T>::GetParamInitInServer(const std::string ¶m_name) { | |||
| if (param_to_init_in_server_.count(param_name) == 0) { | |||
| return false; | |||
| } | |||
| return param_to_init_in_server_[param_name]; | |||
| } | |||
| template <typename T> | |||
| size_t Worker<T>::GetParamKey(const std::string ¶m_name) { | |||
| size_t key = kInvalidKey; | |||
| if (param_to_key_.find(param_name) != param_to_key_.end()) { | |||
| key = param_to_key_[param_name]; | |||
| MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key; | |||
| } | |||
| return key; | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::SetKeyOptimId(size_t key, const std::string &optimizer_name) { | |||
| key_to_optimId_[key] = Util::optimizer_id(optimizer_name); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSOptimId(const size_t param_key) { | |||
| if (key_to_optimId_.count(param_key) == 0) { | |||
| MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key; | |||
| } | |||
| int64_t optim_id = key_to_optimId_[param_key]; | |||
| ::ps::SArray<::ps::Key> keys = {param_key}; | |||
| ::ps::SArray<T> optim_id_vals = {static_cast<T>(optim_id)}; | |||
| ::ps::SArray<int> optim_id_lens = {optim_id_vals.size()}; | |||
| kv_worker_->PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSEmbeddingTable(const std::vector<size_t> &keys, std::vector<T> shapes, const ShapeVector &sizes) { | |||
| bool has_init = IsKeyInit(keys[0]); | |||
| if (has_init) { | |||
| MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; | |||
| return; | |||
| } | |||
| ::ps::SArray<T> shapes_val; | |||
| for (auto dim : shapes) { | |||
| shapes_val.push_back(dim); | |||
| } | |||
| std::vector<int> sizes_int; | |||
| (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int), | |||
| [](const int64_t &value) { return static_cast<int>(value); }); | |||
| kv_worker_->Wait( | |||
| kv_worker_->InitEmbeddingTable(::ps::SArray<::ps::Key>(keys), shapes_val, ::ps::SArray<int>(sizes_int))); | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) { | |||
| MS_EXCEPTION_IF_NULL(tensor); | |||
| MS_EXCEPTION_IF_NULL(input_node); | |||
| auto pk_node = input_node->cast<ParameterPtr>(); | |||
| MS_EXCEPTION_IF_NULL(pk_node); | |||
| const std::string ¶m_name = pk_node->fullname_with_scope(); | |||
| void *param_data = tensor->data_c(); | |||
| size_t param_size = LongToSize(tensor->data().nbytes()); | |||
| size_t param_key = GetParamKey(param_name); | |||
| if (param_key == kInvalidKey) { | |||
| MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned."; | |||
| return; | |||
| } | |||
| bool init_in_server = false; | |||
| auto param_info_ptr = pk_node->param_info(); | |||
| if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) { | |||
| init_in_server = true; | |||
| } | |||
| SetParamInitInServer(param_name, init_in_server); | |||
| bool init = IsKeyInit(param_key); | |||
| if (!init) { | |||
| MS_LOG(INFO) << "Init paramter and optimizer in parameter server side for " << param_name | |||
| << ", whether init in server: " << init_in_server; | |||
| kv_worker_->AddKeyToServerId(param_key); | |||
| if (!PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (!init_in_server) { | |||
| if (param_size > INT_MAX) { | |||
| MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " | |||
| << param_size; | |||
| } | |||
| InitPSParamData({param_key}, param_data, param_size); | |||
| } | |||
| InitPSOptimId(param_key); | |||
| InitPSOptimInputShapes(param_key); | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void Worker<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { | |||
| bool has_init = IsKeyInit(key); | |||
| if (has_init) { | |||
| return; | |||
| } | |||
| kv_worker_->AddEmbeddingTable(key, row_count); | |||
| } | |||
| static Worker<float> &worker = Worker<float>::GetInstance(); | |||
| static Worker &worker = Worker::GetInstance(); | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_WORKER_H_ | |||
| @@ -1,873 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PS_WORKER_PROXY_H_ | |||
| #define MINDSPORE_CCSRC_PS_WORKER_PROXY_H_ | |||
| #include <map> | |||
| #include <numeric> | |||
| #include <functional> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "ps/ps.h" | |||
| #include "ps/util.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "ps/ps_context.h" | |||
| namespace mindspore { | |||
| namespace ps { | |||
| template <typename T> | |||
| class WorkerProxy : public ::ps::KVWorker<T> { | |||
| public: | |||
| using Worker = ::ps::KVWorker<T>; | |||
| using Callback = std::function<void()>; | |||
| using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>; | |||
| using Slicer = std::function<void(int64_t ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, | |||
| SlicedKVs *sliced, const std::map<int64_t, int64_t> &attrs)>; | |||
| using ::ps::SimpleApp::obj_; | |||
| explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id) | |||
| : Worker(app_id, customer_id) { | |||
| server_num_ = ::ps::NumServers(); | |||
| MS_LOG(INFO) << "Server num:" << server_num_; | |||
| PSContext::instance()->SetPSRankId(::ps::MyRank()); | |||
| using std::placeholders::_1; | |||
| using std::placeholders::_2; | |||
| using std::placeholders::_3; | |||
| using std::placeholders::_4; | |||
| using std::placeholders::_5; | |||
| lookup_customer_ = std::unique_ptr<::ps::Customer>( | |||
| new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1))); | |||
| general_customer_ = std::unique_ptr<::ps::Customer>( | |||
| new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1))); | |||
| lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5); | |||
| sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5); | |||
| broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5); | |||
| round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5); | |||
| worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5); | |||
| update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5); | |||
| } | |||
| ~WorkerProxy() override = default; | |||
| void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | |||
| void AddKeyToServerId(const ::ps::Key &key); | |||
| void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd = 0, | |||
| const Callback &cb = nullptr, int64_t priority = 0); | |||
| int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0); | |||
| void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, | |||
| const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0); | |||
| bool IsReadyForPush(const Key &key); | |||
| bool IsReadyForPull(const Key &key); | |||
| void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, | |||
| int64_t cmd = 0, int64_t priority = 0); | |||
| void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens, | |||
| size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); | |||
| void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr, | |||
| int64_t cmd = 0, int64_t priority = 0); | |||
| void Finalize(); | |||
| private: | |||
| template <typename C> | |||
| int64_t AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int64_t cmd, | |||
| const Callback &cb); | |||
| int64_t AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens, | |||
| int64_t cmd, const Callback &cb); | |||
| void LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs); | |||
| void SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs); | |||
| void BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs); | |||
  // Slicer: routes each key to the server chosen in key_to_server_id_ (round-robin/hash placement).
  void RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                        std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                        const std::map<int64_t, int64_t> &attrs);
  // Slicer: splits an initial embedding-table payload by the per-server row ranges.
  void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                 std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                 const std::map<int64_t, int64_t> &attrs);
  // Slicer: splits an embedding-update request by row ranges.
  // NOTE(review): takes `int timestamp` while the other slicers take int64_t — worth unifying.
  void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                             std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                             const std::map<int64_t, int64_t> &attrs);
  // Receive handlers registered with the ps-lite customers.
  void ProcessLookupResult(const ::ps::Message &msg);
  void ProcessResponse(const ::ps::Message &msg);
  // Slices `kvs` with `slicer` and sends one message per selected server.
  void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
            const Slicer &slicer, std::map<int64_t, int64_t> attrs = {});
  // Assigns `key` to a server by `key % server_num_`.
  void AddKeyByHashMod(const ::ps::Key &key);
  // Helpers used by SparseSlicer to build the reduced sparse payload.
  void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
                             const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
                             const size_t segment_size, T *gradient, int *indice);
  void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
                        const T *original_data, const T *grads, int *indices, ::ps::SArray<T> *reduced_data);
  int64_t server_num_;                                  // number of parameter servers
  std::unique_ptr<::ps::Customer> lookup_customer_;     // customer for embedding lookups
  std::unique_ptr<::ps::Customer> general_customer_;    // customer for all other requests
  // Per-key row ranges: embedding_table_ranges_[key][i] is the row range owned by server i.
  std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
  // Per-timestamp lookup replies, gathered by ProcessLookupResult (guarded by mutex_).
  std::unordered_map<int64_t, std::vector<::ps::KVPairs<T>>> lookup_results_;
  // Per-timestamp replies keyed by responding server rank (guarded by mutex_).
  std::unordered_map<int64_t, std::map<int64_t, ::ps::KVPairs<T>>> gathered_response_;
  std::mutex mutex_;  // protects lookup_results_ and gathered_response_
  Slicer lookup_slicer_;
  Slicer sparse_slicer_;
  Slicer broadcast_slicer_;
  Slicer round_robin_slicer_;
  Slicer worker_init_embedding_slicer_;
  Slicer update_embedding_slicer_;
  std::unordered_map<int64_t, Callback> lookup_callbacks_;   // completion callbacks per lookup timestamp
  std::unordered_map<int64_t, Callback> general_callbacks_;  // completion callbacks per general timestamp
  // Per-timestamp count of servers the slicer actually targeted (incremented by slicers).
  std::unordered_map<int64_t, int64_t> expected_result_count_;
  std::unordered_map<::ps::Key, int64_t> key_to_server_id_;  // hash-mod key placement
  std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;  // total rows per embedding key
};
| template <typename T> | |||
| void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { | |||
| uint64_t begin = 0; | |||
| uint64_t end = 0; | |||
| for (int64_t i = 0; i < server_num_; i++) { | |||
| int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_); | |||
| if (i == 0) { | |||
| end = local_row_cnt - 1; | |||
| } else { | |||
| begin = end + 1; | |||
| end += local_row_cnt; | |||
| } | |||
| ::ps::Range range(begin, end); | |||
| if (embedding_table_ranges_.count(key) == 0) { | |||
| embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>(); | |||
| MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]); | |||
| } | |||
| embedding_table_ranges_[key]->push_back(range); | |||
| } | |||
| embedding_row_cnt_[key] = row_count; | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) { | |||
| if (server_num_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Server number is invalid:0"; | |||
| } | |||
| key_to_server_id_[key] = static_cast<int64_t>(key % server_num_); | |||
| MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key]; | |||
| } | |||
// Records which server owns `key`. Currently delegates to hash-mod placement.
template <typename T>
void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
  AddKeyByHashMod(key);
}
// Blocking embedding lookup: registers a completion callback, sends the request
// to the servers owning the looked-up rows, and waits until all replies arrive.
// NOTE(review): the `lens` parameter is not used in this body — presumably kept
// for signature compatibility; confirm against the declaration/callers.
template <typename T>
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                                     const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd,
                                     const Callback &cb, int64_t priority) {
  int64_t ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
  ::ps::KVPairs<T> kvs;
  kvs.keys = keys;
  // The lookup ids are transported in the `lens` field of KVPairs.
  kvs.lens = lookup_ids;
  kvs.priority = priority;
  expected_result_count_[ts] = 0;
  // The slicer increments expected_result_count_[ts] once per server it targets.
  Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
  int64_t expect_rt_count = expected_result_count_[ts];
  // Pre-credit responses for the servers that were skipped, so WaitRequest only
  // waits for the servers that were actually contacted.
  lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
  lookup_customer_->WaitRequest(ts);
  expected_result_count_.erase(ts);
}
| template <typename T> | |||
| int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens, const Callback &cb, int64_t priority) { | |||
| int64_t ts = obj_->NewRequest(::ps::kServerGroup); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = keys; | |||
| kvs.vals = vals; | |||
| kvs.lens = lens; | |||
| kvs.priority = priority; | |||
| Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_); | |||
| return ts; | |||
| } | |||
// Blocking embedding update: pushes new values for the given lookup ids and waits
// until all targeted servers acknowledge.
// NOTE(review): `ts` is declared `int` although AddGeneralRspCB returns int64_t —
// a potential narrowing; also `cb` is unused. Worth confirming/unifying.
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                                          const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
  int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
  ::ps::KVPairs<T> kvs;
  kvs.keys = keys;
  // Lookup ids travel in the `lens` field, values in `vals`.
  kvs.lens = lookup_ids;
  kvs.vals = vals;
  kvs.priority = priority;
  expected_result_count_[ts] = 0;
  Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
  // Pre-credit responses for servers the slicer skipped.
  if (expected_result_count_[ts] < server_num_) {
    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  }
  general_customer_->WaitRequest(ts);
  expected_result_count_.erase(ts);
}
| template <typename T> | |||
| bool WorkerProxy<T>::IsReadyForPush(const Key &key) { | |||
| ::ps::SArray<T> result(1, 0); | |||
| PullData({key}, &result, nullptr, kCheckReadyForPushCmd); | |||
| if (result[0] > 0) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| template <typename T> | |||
| bool WorkerProxy<T>::IsReadyForPull(const Key &key) { | |||
| ::ps::SArray<T> result(1, 0); | |||
| PullData({key}, &result, nullptr, kCheckReadyForPullCmd); | |||
| if (result[0] > 0) { | |||
| return true; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
// Blocking push: routes the payload with the appropriate slicer and waits for
// all targeted servers. Embedding keys (present in embedding_table_ranges_) are
// either sharded by row range (weight init) or broadcast; other keys go to their
// owning server via round-robin placement.
template <typename T>
void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
                              const ::ps::SArray<int> &lens, int64_t cmd, int64_t priority) {
  int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
  ::ps::KVPairs<T> kvs;
  kvs.keys = keys;
  kvs.vals = vals;
  kvs.lens = lens;
  kvs.priority = priority;
  // Routing is decided by the first key only — assumes all keys in one call share a route.
  if (embedding_table_ranges_.count(keys[0])) {
    if (cmd == kInitWeightsCmd) {
      Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
    } else {
      Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
    }
  } else {
    Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  }
  // Pre-credit responses for servers the slicer skipped (expected_result_count_[ts]
  // defaults to 0 if no slicer incremented it).
  if (expected_result_count_[ts] < server_num_) {
    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  }
  general_customer_->WaitRequest(ts);
}
// Blocking sparse push: sends a sparse gradient (values + indices packed in
// `vals`/`lens`) to the servers. For embedding keys the sparse slicer reduces
// and re-slices the gradient per server row range; the slice layout is passed
// via the attrs map (0: grad index, 1: indice index, 2: first dim, 3: outer dim).
template <typename T>
void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
                                    const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
                                    size_t first_dim_size, size_t outer_dim_size) {
  int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
  ::ps::KVPairs<T> kvs;
  kvs.keys = keys;
  kvs.vals = vals;
  kvs.lens = lens;
  const int64_t cmd = 0;
  if (embedding_table_ranges_.count(keys[0])) {
    std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
    Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
  } else {
    Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  }
  // Pre-credit responses for servers the slicer skipped.
  if (expected_result_count_[ts] < server_num_) {
    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  }
  general_customer_->WaitRequest(ts);
}
// Blocking pull: requests values (and optionally lens) for `keys` and waits for
// all targeted servers; the registered callback gathers replies into *vals/*lens.
// Embedding keys are broadcast to all servers; others go to their owning server.
template <typename T>
void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
                              int64_t cmd, int64_t priority) {
  MS_EXCEPTION_IF_NULL(vals);
  int64_t ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
  ::ps::KVPairs<T> kvs;
  kvs.keys = keys;
  kvs.priority = priority;
  if (embedding_table_ranges_.count(keys[0])) {
    Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
  } else {
    Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
  }
  // Pre-credit responses for servers the slicer skipped.
  if (expected_result_count_[ts] < server_num_) {
    general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  }
  general_customer_->WaitRequest(ts);
}
| template <typename T> | |||
| void WorkerProxy<T>::Finalize() { | |||
| int64_t ts = obj_->NewRequest(::ps::kServerGroup); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys.push_back(0); | |||
| kvs.vals.push_back(0.0f); | |||
| Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_); | |||
| obj_->WaitRequest(ts); | |||
| ::ps::Finalize(0, true); | |||
| } | |||
// Registers a completion callback for an embedding lookup and returns the new
// request timestamp. The callback (run when all replies arrived) stitches the
// per-server lookup replies back into `lookup_result` in the order of
// `lookup_ids`, then invokes the user callback `cb`.
template <typename T>
template <typename C>
int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                                    C *lookup_result, int64_t cmd, const Callback &cb) {
  MS_EXCEPTION_IF_NULL(lookup_result);
  int64_t ts = lookup_customer_->NewRequest(::ps::kServerGroup);
  const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
    // Take a reference to this timestamp's gathered replies under the lock.
    // NOTE(review): only the map access is locked; the reference is then read
    // unlocked — assumes no concurrent mutation of lookup_results_[ts] after
    // all responses arrived. Confirm against ProcessLookupResult's call site.
    mutex_.lock();
    auto &kvs = lookup_results_[ts];
    mutex_.unlock();
    if (lookup_ids.empty()) {
      MS_LOG(EXCEPTION) << "Lookup id is empty.";
    }
    // Each id occupies an equal slice of the output buffer.
    int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
    // Map each returned key to the address of its row inside the reply payload.
    std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map;
    for (const auto &s : kvs) {
      int64_t offset = 0;
      for (size_t i = 0; i < s.keys.size(); i++) {
        const Key &key = s.keys[i];
        T *addr = s.vals.data() + offset;
        offset += single_id_len;
        id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len));
        MS_EXCEPTION_IF_NULL(id_addr_map[key]);
      }
    }
    // Copy each id's row into the output buffer at its original position;
    // ids with no reply leave their slot untouched (offset still advances).
    T *result_addr = lookup_result->data();
    MS_EXCEPTION_IF_NULL(result_addr);
    int64_t offset = 0;
    size_t dst_size = 0;
    size_t src_size = 0;
    void *dst_data = nullptr;
    void *src_data = nullptr;
    for (size_t i = 0; i < lookup_ids.size(); i++) {
      if (id_addr_map.count(lookup_ids[i]) == 0) {
        offset += single_id_len;
        continue;
      }
      auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
      int64_t size = single_id_len * sizeof(T);
      dst_size = size;
      src_size = size;
      dst_data = result_addr + offset;
      src_data = pair->first;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      offset += single_id_len;
    }
    // Drop the gathered replies for this timestamp.
    mutex_.lock();
    lookup_results_.erase(ts);
    mutex_.unlock();
    if (cb) cb();
  };
  lookup_callbacks_[ts] = callback;
  return ts;
}
// Registers a completion callback for a general request and returns the new
// timestamp. The callback concatenates the per-server replies (ordered by
// server rank, since gathered_response_[ts] is a std::map) into *vals/*lens.
template <typename T>
int64_t WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals,
                                        ::ps::SArray<int> *lens, int64_t cmd, const Callback &cb) {
  int64_t ts = general_customer_->NewRequest(::ps::kServerGroup);
  const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
    // Copy the replies out under the lock.
    mutex_.lock();
    std::map<int64_t, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
    mutex_.unlock();
    // NOTE(review): vals is dereferenced unconditionally here — callers that
    // pass vals == nullptr (e.g. push paths) rely on meta.pull being false so
    // this callback body is effectively not reached with data; confirm.
    vals->clear();
    for (auto kvs : server_kvs) {
      for (auto val : kvs.second.vals) {
        vals->push_back(val);
      }
      if (lens) {
        for (auto len : kvs.second.lens) {
          lens->push_back(len);
        }
      }
    }
    mutex_.lock();
    gathered_response_.erase(ts);
    mutex_.unlock();
    if (cb) {
      cb();
    }
  };
  general_callbacks_[ts] = callback;
  return ts;
}
// Slices an embedding-lookup request by server row ranges: for each server,
// builds a KVPairs containing the table key followed by the unique lookup ids
// that fall inside that server's range. Servers with no matching ids are
// marked inactive (first == false) and are not sent to.
template <typename T>
void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                    std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                    const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(sliced);
  // Lookup ids are transported in the `lens` field (see EmbeddingLookup).
  int32_t *lookup_ids = send.lens.data();
  size_t id_size = send.lens.size();
  const Key &key = send.keys[0];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  sliced->resize(ranges.size());
  for (size_t i = 0; i < ranges.size(); i++) {
    const ::ps::Range &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    std::unordered_set<int64_t> unique_ids;
    auto &kvs = sliced->at(i).second;
    // First key/val slot carries the table key itself with a dummy value.
    kvs.keys.push_back(key);
    kvs.vals.push_back(0.0f);
    for (size_t j = 0; j < id_size; j++) {
      auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
      // If lookup_id is out of range, like negative number, unique_ids will not contain it.
      // Servers always get lookup_ids in its embedding table range.
      if (lookup_id >= begin && lookup_id <= end) {
        unique_ids.insert(lookup_id);
      }
    }
    for (const auto &lookup_id : unique_ids) {
      kvs.keys.push_back(lookup_id);
      kvs.vals.push_back(0.0f);
    }
    // Only the table key present => nothing to look up on this server.
    if (kvs.keys.size() <= 1) {
      sliced->at(i).first = false;
    } else {
      sliced->at(i).first = true;
      expected_result_count_[timestamp] += 1;
    }
  }
}
// Slices a sparse-gradient push by server row ranges: for each server, selects
// the indices in its range, reduces duplicate indices (summing their gradient
// segments), and rebuilds the value/length layout for that server. Every server
// slice is marked active; servers with no matching indices get a placeholder
// payload (single value -100, empty lens).
template <typename T>
void WorkerProxy<T>::SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                  std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                  const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(sliced);
  // Init variables
  T *data = send.vals.data();
  // attrs layout (see PushSparseData): 0 grad index, 1 indice index,
  // 2 first dim size, 3 outer dim size.
  if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
    MS_LOG(EXCEPTION) << "Invalid attrs keys";
  }
  auto iter = attrs.find(0);
  size_t grad_index = static_cast<size_t>(iter->second);
  iter = attrs.find(1);
  size_t indice_index = static_cast<size_t>(iter->second);
  iter = attrs.find(2);
  size_t first_dim_size = static_cast<size_t>(iter->second);
  iter = attrs.find(3);
  size_t outer_dim_size = static_cast<size_t>(iter->second);
  int grad_size = send.lens[grad_index];
  int indice_size = send.lens[indice_index];
  // Elements of gradient per index.
  int segment_size = grad_size / indice_size;
  // Locate the gradient and indice sections inside the flat value buffer.
  int64_t grad_offset = 0;
  int64_t indice_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += send.lens[i];
  }
  for (size_t j = 0; j < indice_index; j++) {
    indice_offset += send.lens[j];
  }
  T *grad_data = data + grad_offset;
  // Indices are stored reinterpreted as ints inside the T buffer.
  int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
  // Build the mappings of indice to gradient
  std::vector<std::pair<int, T *>> indice_to_grads;
  for (int i = 0; i < indice_size; i++) {
    int indice = indice_data[i];
    T *grad = grad_data + i * segment_size;
    indice_to_grads.push_back(std::make_pair(indice, grad));
  }
  const Key &key = send.keys[0];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  sliced->resize(ranges.size());
  // Construct reduced sparse data for each server
  for (size_t i = 0; i < ranges.size(); i++) {
    const ::ps::Range &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = sliced->at(i).second;
    kvs.keys = send.keys;
    kvs.lens = send.lens;
    // Prepare the sparse gradient and indice
    std::vector<int> indice_ids;
    std::unordered_set<int> distinct_ids;
    for (int j = 0; j < indice_size; j++) {
      size_t indice = static_cast<size_t>(indice_data[j]);
      if (indice >= begin && indice <= end) {
        indice_ids.push_back(indice);
        distinct_ids.insert(indice);
      }
    }
    size_t indices_size = indice_ids.size();
    if (indices_size > 0) {
      int slice_segment_size = indices_size * segment_size;
      std::vector<T> src_grad_data(slice_segment_size);
      std::vector<int> src_indice_data(indices_size);
      PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
                            src_indice_data.data());
      // Reduce the sparse gradient and indice
      std::vector<T> new_grad(slice_segment_size);
      std::vector<int> new_indices(indices_size);
      mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
      Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
                                 first_dim_size, outer_dim_size, &unique_sparse_grad);
      // Update the length of reduce sparse gradient and indice
      ::ps::SArray<int> reduced_lens;
      reduced_lens.CopyFrom(kvs.lens);
      reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
      reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
      // Build the sparse value to be sent
      size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
      ::ps::SArray<T> reduced_data(total_size, 0);
      BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
                       unique_sparse_grad.indices_, &reduced_data);
      kvs.lens = reduced_lens;
      kvs.vals = reduced_data;
    }
    if (indices_size <= 0) {
      // Placeholder payload for a server with no indices in range.
      // NOTE(review): `no_keys` is populated but never assigned to kvs.keys —
      // kvs.keys stays equal to send.keys; `no_keys` looks like dead code.
      ::ps::SArray<T> no_keys;
      ::ps::SArray<T> no_vals;
      ::ps::SArray<T> no_lens;
      no_keys.push_back(key);
      no_vals.push_back(-100);
      kvs.vals = no_vals;
      kvs.lens = no_lens;
    }
    sliced->at(i).first = true;
    expected_result_count_[timestamp] += 1;
  }
}
// Copies the gradient segments whose index belongs to `distinct_ids` into the
// contiguous `gradient` buffer and the matching ids into `indices`, preserving
// the original order of `indice_to_grads`.
// NOTE(review): the `begin`/`end` and `all_indice` parameters are not used in
// this body — filtering is done purely via distinct_ids; confirm intent.
template <typename T>
void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end,
                                           const std::unordered_set<int> &distinct_ids,
                                           const std::vector<std::pair<int, T *>> &indice_to_grads,
                                           const int *all_indice, const size_t segment_size, T *gradient,
                                           int *indices) {
  MS_EXCEPTION_IF_NULL(all_indice);
  MS_EXCEPTION_IF_NULL(gradient);
  MS_EXCEPTION_IF_NULL(indices);
  int64_t offset = 0;
  int64_t index = 0;
  size_t segment_data_size = segment_size * sizeof(T);
  size_t dst_size;
  size_t src_size;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (auto &pair : indice_to_grads) {
    if (distinct_ids.count(pair.first) == 0) {
      continue;
    }
    indices[index++] = pair.first;
    dst_size = segment_data_size;
    src_size = segment_data_size;
    dst_data = gradient + offset;
    src_data = pair.second;
    MS_EXCEPTION_IF_NULL(dst_data);
    MS_EXCEPTION_IF_NULL(src_data);
    auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
    if (ret != 0) {
      // On copy failure the remaining segments are left uncopied (logged, not thrown).
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    offset += segment_size;
  }
}
// Rebuilds the flat value buffer for a sparse push: copies the unchanged
// sections from `original_data`, then writes the reduced gradient (`grads`)
// and reduced `indices` into their slots as described by `lengths`.
template <typename T>
void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
                                      const size_t indice_index, const T *original_data, const T *grads, int *indices,
                                      ::ps::SArray<T> *reduced_data) {
  MS_EXCEPTION_IF_NULL(original_data);
  MS_EXCEPTION_IF_NULL(grads);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(reduced_data);
  // Copy every section other than gradient/indices verbatim.
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (size_t i = 0; i < lengths.size(); i++) {
    if (i != grad_index && i != indice_index) {
      int data_size = lengths[i] * sizeof(T);
      dst_size = data_size;
      src_size = data_size;
      dst_data = reduced_data->data() + offset;
      src_data = const_cast<T *>(original_data) + offset;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
    }
    offset += lengths[i];
  }
  // Fill the reduced gradient
  int64_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += lengths[i];
  }
  int64_t data_size = lengths[grad_index] * sizeof(T);
  dst_size = data_size;
  src_size = data_size;
  dst_data = reduced_data->data() + grad_offset;
  src_data = const_cast<T *>(grads);
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
  // Fill the reduced indice
  // NOTE(review): the layout assumes the indice section immediately follows
  // the gradient section, and the int indices are copied with sizeof(T)
  // element size — correct only while sizeof(T) == sizeof(int); confirm.
  int64_t indice_offset = grad_offset + lengths[grad_index];
  data_size = lengths[indice_index] * sizeof(T);
  T *indice_data = reduced_data->data() + indice_offset;
  dst_size = data_size;
  src_size = data_size;
  dst_data = indice_data;
  src_data = indices;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}
| template <typename T> | |||
| void WorkerProxy<T>::BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, | |||
| const std::map<int64_t, int64_t> &attr) { | |||
| MS_EXCEPTION_IF_NULL(sliced); | |||
| sliced->resize(server_num_); | |||
| for (int64_t i = 0; i < server_num_; i++) { | |||
| sliced->at(i).first = true; | |||
| sliced->at(i).second = send; | |||
| expected_result_count_[timestamp] += 1; | |||
| } | |||
| } | |||
// Ownership slicer: each key goes to the server recorded in key_to_server_id_,
// along with its value segment and length. Only servers that receive at least
// one key are marked active.
template <typename T>
void WorkerProxy<T>::RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                      const std::map<int64_t, int64_t> &attr) {
  MS_EXCEPTION_IF_NULL(sliced);
  sliced->resize(server_num_);
  auto keys = send.keys;
  auto vals = send.vals;
  auto lens = send.lens;
  int64_t server_id, len;
  ::ps::Key param_key;
  for (size_t i = 0; i < keys.size(); i++) {
    param_key = keys[i];
    server_id = key_to_server_id_[param_key];
    if (!sliced->at(server_id).first) {
      // First key for this server: activate the slice and count it.
      sliced->at(server_id).first = true;
      expected_result_count_[timestamp] += 1;
    }
    ::ps::KVPairs<T> &server_kv_pairs = sliced->at(server_id).second;
    server_kv_pairs.keys.push_back(param_key);
    // Pull requests carry no values; only keys are routed.
    if (vals.empty()) {
      continue;
    }
    len = lens[i];
    // The value segment for key i starts after the lengths of all prior keys.
    int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
    auto val_begin = vals.begin() + offset;
    auto val_end = val_begin + len;
    for (auto iter = val_begin; iter != val_end; iter++) {
      server_kv_pairs.vals.push_back(*iter);
    }
    server_kv_pairs.lens.push_back(len);
  }
}
// Init-weights slicer for embedding tables: splits the flat value buffer into
// per-server row-range segments (rows * columns), one slice per range.
template <typename T>
void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
                                               const std::vector<::ps::Range> &,
                                               std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                               const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(sliced);
  sliced->resize(server_num_);
  auto keys = send.keys;
  auto vals = send.vals;
  auto lens = send.lens;
  // Columns per row, derived from the total element count and the row count.
  size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
  for (size_t i = 0; i < ranges.size(); i++) {
    size_t offset_begin = ranges[i].begin() * col_cnt;
    size_t offset_end = (ranges[i].end() + 1) * col_cnt;
    ::ps::KVPairs<T> kvs;
    kvs.keys = keys;
    kvs.vals = vals.segment(offset_begin, offset_end);
    kvs.lens.push_back(offset_end - offset_begin);
    sliced->at(i).first = true;
    sliced->at(i).second = kvs;
  }
}
// Update slicer for embedding tables: for each server range, collects the
// lookup ids that fall inside it together with their embedding rows.
// Servers receiving no ids are marked inactive.
// NOTE(review): `timestamp` is `int` here while sibling slicers use int64_t.
template <typename T>
void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
                                           const std::vector<::ps::Range> &,
                                           std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
                                           const std::map<int64_t, int64_t> &attrs) {
  MS_EXCEPTION_IF_NULL(sliced);
  T *embedding_vals = send.vals.data();
  // Lookup ids travel in the `lens` field (see UpdateEmbeddingTable).
  int *lookup_ids = send.lens.data();
  size_t val_size = send.vals.size();
  size_t id_size = send.lens.size();
  // Row width, assuming all ids carry equally sized embedding rows.
  size_t embedding_dim = val_size / id_size;
  const Key &key = send.keys[0];
  const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  sliced->resize(ranges.size());
  for (size_t i = 0; i < ranges.size(); i++) {
    const ::ps::Range &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = sliced->at(i).second;
    kvs.keys.push_back(key);
    for (size_t j = 0; j < id_size; j++) {
      auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
      if (lookup_id >= begin && lookup_id <= end) {
        kvs.keys.push_back(lookup_id);
        for (size_t k = 0; k < embedding_dim; k++) {
          kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
        }
      }
    }
    // Only the table key present => nothing to update on this server.
    if (kvs.keys.size() <= 1) {
      sliced->at(i).first = false;
    } else {
      sliced->at(i).first = true;
      expected_result_count_[timestamp] += 1;
    }
  }
}
// Receive handler for lookup replies: stores the reply payload under the
// request timestamp and, once replies from all servers have arrived, runs the
// stitching callback registered by AddLookupCB.
template <typename T>
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
  int64_t ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> kvs;
    kvs.keys = msg.data[0];
    kvs.vals = msg.data[1];
    if (msg.data.size() > (size_t)2) {
      kvs.lens = msg.data[2];
    }
    mutex_.lock();
    lookup_results_[ts].push_back(kvs);
    mutex_.unlock();
  }
  // +1 accounts for this not-yet-recorded response.
  if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
    const auto &cb = lookup_callbacks_[ts];
    cb();
    lookup_callbacks_.erase(ts);
  }
}
// Receive handler for general replies: records the reply under the sender's
// server rank and, once all servers have responded, runs the gathering
// callback registered by AddGeneralRspCB.
// NOTE(review): unlike ProcessLookupResult, the completion check is inside the
// `pull` branch — push-only requests never reach it here; confirm intended.
template <typename T>
void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
  int64_t ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> kvs;
    kvs.keys = msg.data[0];
    kvs.vals = msg.data[1];
    if (msg.data.size() > (size_t)2) {
      kvs.lens = msg.data[2];
    }
    mutex_.lock();
    // Key replies by the responding server's rank so gathering is ordered.
    int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
    gathered_response_[ts][rsp_server_rank] = kvs;
    mutex_.unlock();
    // +1 accounts for this not-yet-recorded response.
    if (general_customer_->NumResponse(ts) + 1 == server_num_) {
      const auto &cb = general_callbacks_[ts];
      cb();
      general_callbacks_.erase(ts);
    }
  }
}
| template <typename T> | |||
| void WorkerProxy<T>::Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, | |||
| const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int64_t, int64_t> attrs) { | |||
| MS_EXCEPTION_IF_NULL(customer); | |||
| SlicedKVs sliced; | |||
| slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs); | |||
| for (size_t i = 0; i < sliced.size(); i++) { | |||
| const auto &s = sliced[i]; | |||
| if (!s.first) continue; | |||
| ::ps::Message msg; | |||
| msg.meta.app_id = customer->app_id(); | |||
| msg.meta.customer_id = customer->customer_id(); | |||
| msg.meta.request = true; | |||
| msg.meta.push = push; | |||
| msg.meta.pull = pull; | |||
| msg.meta.head = cmd; | |||
| msg.meta.timestamp = timestamp; | |||
| msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); | |||
| msg.meta.priority = kvs.priority; | |||
| const auto &kvs = s.second; | |||
| if (kvs.keys.size()) { | |||
| msg.AddData(kvs.keys); | |||
| msg.AddData(kvs.vals); | |||
| if (kvs.lens.size()) { | |||
| msg.AddData(kvs.lens); | |||
| } | |||
| } | |||
| ::ps::Postoffice::Get()->van()->Send(msg); | |||
| } | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PS_WORKER_PROXY_H_ | |||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||
| namespace device { | |||
| void KernelRuntimeManager::ClearRuntimeResource() { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| if (ps::PSContext::instance()->is_worker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::ps_cache_instance.SyncEmbeddingTable(); | |||
| } | |||
| #endif | |||
| @@ -78,7 +78,6 @@ export DEVICE_NUM=8 | |||
| export RANK_SIZE=8 | |||
| export RANK_TABLE_FILE=$PATH1 | |||
| export MS_COMM_TYPE=zmq | |||
| export MS_SCHED_NUM=1 | |||
| export MS_WORKER_NUM=$RANK_SIZE | |||
| export MS_SERVER_NUM=8 | |||
| @@ -70,7 +70,6 @@ fi | |||
| export DEVICE_NUM=8 | |||
| export RANK_SIZE=8 | |||
| export MS_COMM_TYPE=zmq | |||
| export MS_SCHED_NUM=1 | |||
| export MS_WORKER_NUM=8 | |||
| export MS_SERVER_NUM=8 | |||
| @@ -27,7 +27,6 @@ export EPOCH_SIZE=$2 | |||
| export DEVICE_TARGET=$3 | |||
| export DATASET=$4 | |||
| export MS_COMM_TYPE=zmq | |||
| export MS_SCHED_NUM=1 | |||
| export MS_WORKER_NUM=$RANK_SIZE | |||
| export LOCAL_WORKER_NUM=$5 | |||
| @@ -25,7 +25,6 @@ export RANK_SIZE=$1 | |||
| export EPOCH_SIZE=$2 | |||
| export DEVICE_TARGET=$3 | |||
| export DATASET=$4 | |||
| export MS_COMM_TYPE=zmq | |||
| export MS_SCHED_NUM=1 | |||
| export MS_WORKER_NUM=$RANK_SIZE | |||
| export MS_SERVER_NUM=$5 | |||
| @@ -23,7 +23,6 @@ self_path=$(dirname "${script_self}") | |||
| export EPOCH_SIZE=$1 | |||
| export DEVICE_TARGET=$2 | |||
| export DATASET=$3 | |||
| export MS_COMM_TYPE=zmq | |||
| export MS_SCHED_NUM=1 | |||
| export MS_WORKER_NUM=1 | |||
| export MS_SERVER_NUM=$4 | |||
| @@ -15,8 +15,7 @@ | |||
| # ============================================================================ | |||
| execute_path=$(pwd) | |||
| self_path=$(dirname "${script_self}") | |||
| export MS_COMM_TYPE=zmq | |||
| self_path=$(dirname $0) | |||
| export MS_SCHED_NUM=1 | |||
| DEVICE_TARGET=$1 | |||
| export MS_WORKER_NUM=$2 | |||
| @@ -15,8 +15,7 @@ | |||
| # ============================================================================ | |||
| execute_path=$(pwd) | |||
| self_path=$(dirname "${script_self}") | |||
| export MS_COMM_TYPE=zmq | |||
| self_path=$(dirname $0) | |||
| export MS_SCHED_NUM=1 | |||
| DEVICE_TARGET=$1 | |||
| DATASET_PATH=$2 | |||
| @@ -15,8 +15,7 @@ | |||
| # ============================================================================ | |||
| execute_path=$(pwd) | |||
| self_path=$(dirname "${script_self}") | |||
| export MS_COMM_TYPE=zmq | |||
| self_path=$(dirname $0) | |||
| export MS_SCHED_NUM=1 | |||
| DEVICE_TARGET=$1 | |||
| export MS_WORKER_NUM=$2 | |||
| @@ -15,8 +15,7 @@ | |||
| # ============================================================================ | |||
| execute_path=$(pwd) | |||
| self_path=$(dirname "${script_self}") | |||
| export MS_COMM_TYPE=zmq | |||
| self_path=$(dirname $0) | |||
| export MS_SCHED_NUM=1 | |||
| DEVICE_TARGET=$1 | |||
| DATASET_PATH=$2 | |||
| @@ -150,6 +150,8 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/internal/parame | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/scheduler.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/optimizer_info_builder.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/worker.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/parameter_server.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") | |||
| list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") | |||
| @@ -1,255 +0,0 @@ | |||
| diff -Npur ps-lite-master/include/dmlc/base.h ps-lite-master-new/include/dmlc/base.h | |||
| --- ps-lite-master/include/dmlc/base.h 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/include/dmlc/base.h 2020-07-01 11:56:50.444833389 +0800 | |||
| @@ -8,7 +8,7 @@ | |||
| /*! \brief whether use glog for logging */ | |||
| #ifndef DMLC_USE_GLOG | |||
| -#define DMLC_USE_GLOG 0 | |||
| +#define DMLC_USE_GLOG 1 | |||
| #endif | |||
| /*! | |||
| diff -Npur ps-lite-master/include/dmlc/logging.h ps-lite-master-new/include/dmlc/logging.h | |||
| --- ps-lite-master/include/dmlc/logging.h 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/include/dmlc/logging.h 2020-07-08 21:35:33.334584767 +0800 | |||
| @@ -52,7 +52,7 @@ struct Error : public std::runtime_error | |||
| namespace dmlc { | |||
| inline void InitLogging(const char* argv0) { | |||
| - google::InitGoogleLogging(argv0); | |||
| + //google::InitGoogleLogging(argv0); | |||
| } | |||
| } // namespace dmlc | |||
| diff -Npur ps-lite-master/make/deps.mk ps-lite-master-new/make/deps.mk | |||
| --- ps-lite-master/make/deps.mk 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/make/deps.mk 2020-06-17 10:35:46.253837426 +0800 | |||
| @@ -1,69 +1,7 @@ | |||
| # Install dependencies | |||
| - | |||
| -URL1=https://raw.githubusercontent.com/mli/deps/master/build | |||
| -URL2=https://github.com/google/protobuf/releases/download/v3.5.1 | |||
| -ifndef WGET | |||
| -WGET = wget | |||
| -endif | |||
| - | |||
| -# protobuf | |||
| -PROTOBUF = ${DEPS_PATH}/include/google/protobuf/message.h | |||
| -${PROTOBUF}: | |||
| - $(eval FILE=protobuf-cpp-3.5.1.tar.gz) | |||
| - $(eval DIR=protobuf-3.5.1) | |||
| - rm -rf $(FILE) $(DIR) | |||
| - $(WGET) $(URL2)/$(FILE) && tar --no-same-owner -zxf $(FILE) | |||
| - cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install | |||
| - rm -rf $(FILE) $(DIR) | |||
| - | |||
| # zmq | |||
| -ZMQ = ${DEPS_PATH}/include/zmq.h | |||
| +ZMQ = $(MS_ZMQ_INSTALL_PATH)/lib/libzmq.a | |||
| ${ZMQ}: | |||
| - $(eval FILE=zeromq-4.1.4.tar.gz) | |||
| - $(eval DIR=zeromq-4.1.4) | |||
| - rm -rf $(FILE) $(DIR) | |||
| - $(WGET) $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE) | |||
| - cd $(DIR) && export CFLAGS=-fPIC && export CXXFLAGS=-fPIC && ./configure -prefix=$(DEPS_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install | |||
| - rm -rf $(FILE) $(DIR) | |||
| - | |||
| -# lz4 | |||
| -LZ4 = ${DEPS_PATH}/include/lz4.h | |||
| -${LZ4}: | |||
| - $(eval FILE=lz4-r129.tar.gz) | |||
| - $(eval DIR=lz4-r129) | |||
| - rm -rf $(FILE) $(DIR) | |||
| - wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE) | |||
| - cd $(DIR) && $(MAKE) && PREFIX=$(DEPS_PATH) $(MAKE) install | |||
| - rm -rf $(FILE) $(DIR) | |||
| - | |||
| -# cityhash | |||
| -CITYHASH = ${DEPS_PATH}/include/city.h | |||
| -${CITYHASH}: | |||
| - $(eval FILE=cityhash-1.1.1.tar.gz) | |||
| - $(eval DIR=cityhash-1.1.1) | |||
| - rm -rf $(FILE) $(DIR) | |||
| - wget $(URL1)/$(FILE) && tar --no-same-owner -zxf $(FILE) | |||
| - cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --enable-sse4.2 && $(MAKE) CXXFLAGS="-g -O3 -msse4.2" && $(MAKE) install | |||
| - rm -rf $(FILE) $(DIR) | |||
| - | |||
| - | |||
| -# # gflags | |||
| -# ${DEPS_PATH}/include/google/gflags.h: | |||
| -# $(eval FILE=gflags-2.0-no-svn-files.tar.gz) | |||
| -# $(eval DIR=gflags-2.0) | |||
| -# rm -rf $(FILE) $(DIR) | |||
| -# wget $(URL)/$(FILE) && tar -zxf $(FILE) | |||
| -# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) && $(MAKE) && $(MAKE) install | |||
| -# rm -rf $(FILE) $(DIR) | |||
| -# gflags: | ${DEPS_PATH}/include/google/gflags.h | |||
| + cd $(MS_ZMQ_DIR) && export CFLAGS="-fPIC -D_GLIBCXX_USE_CXX11_ABI=0" && export CXXFLAGS=-fPIC && ./configure -prefix=$(MS_ZMQ_INSTALL_PATH) --with-libsodium=no --with-libgssapi_krb5=no && $(MAKE) && $(MAKE) install | |||
| -# # glog | |||
| -# ${DEPS_PATH}/include/glog/logging.h: | ${DEPS_PATH}/include/google/gflags.h | |||
| -# $(eval FILE=v0.3.4.tar.gz) | |||
| -# $(eval DIR=glog-0.3.4) | |||
| -# rm -rf $(FILE) $(DIR) | |||
| -# wget https://github.com/google/glog/archive/$(FILE) && tar -zxf $(FILE) | |||
| -# cd $(DIR) && ./configure -prefix=$(DEPS_PATH) --with-gflags=$(DEPS_PATH) && $(MAKE) && $(MAKE) install | |||
| -# rm -rf $(FILE) $(DIR) | |||
| -# glog: | ${DEPS_PATH}/include/glog/logging.h | |||
| diff -Npur ps-lite-master/make/ps.mk ps-lite-master-new/make/ps.mk | |||
| --- ps-lite-master/make/ps.mk 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/make/ps.mk 2020-06-05 09:28:35.337740291 +0800 | |||
| @@ -9,5 +9,5 @@ ifeq ($(USE_KEY32), 1) | |||
| ADD_CFLAGS += -DUSE_KEY32=1 | |||
| endif | |||
| -PS_LDFLAGS_SO = -L$(DEPS_PATH)/lib -lprotobuf-lite -lzmq | |||
| -PS_LDFLAGS_A = $(addprefix $(DEPS_PATH)/lib/, libprotobuf-lite.a libzmq.a) | |||
| +PS_LDFLAGS_SO = -L$(MS_ZMQ_INSTALL_PATH)/lib -lzmq -L$(MS_PROTO_LIB_DIR) -lprotobuf-lite | |||
| +PS_LDFLAGS_A = $(addprefix $(MS_ZMQ_INSTALL_PATH)/lib -L$(MS_PROTO_LIB_DIR), libprotobuf-lite.a libzmq.a) | |||
| diff -Npur ps-lite-master/Makefile ps-lite-master-new/Makefile | |||
| --- ps-lite-master/Makefile 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/Makefile 2020-06-17 11:09:20.240322660 +0800 | |||
| @@ -12,13 +12,24 @@ ifndef DEPS_PATH | |||
| DEPS_PATH = $(shell pwd)/deps | |||
| endif | |||
| +MS_PROTO_DIR = @protobuf_DIRPATH@ | |||
| +MS_GLOG_DIR = @glog_DIRPATH@ | |||
| +MS_ZMQ_DIR = @zeromq_DIRPATH@ | |||
| + | |||
| +MS_PROTO_LIB_DIR = @protobuf_LIBPATH@ | |||
| +MS_GLOG_LIB_DIR = @glog_LIBPATH@ | |||
| +MS_ZMQ_INSTALL_PATH = $(MS_ZMQ_DIR)/zmq_install | |||
| ifndef PROTOC | |||
| -PROTOC = ${DEPS_PATH}/bin/protoc | |||
| +PROTOC = $(MS_PROTO_DIR)/bin/protoc | |||
| endif | |||
| -INCPATH = -I./src -I./include -I$(DEPS_PATH)/include | |||
| -CFLAGS = -std=c++11 -msse2 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) | |||
| +INCPATH = -I./src -I./include -I$(MS_ZMQ_INSTALL_PATH)/include | |||
| +INCPATH += -I$(MS_PROTO_DIR)/include | |||
| +INCPATH += -I$(MS_GLOG_DIR)/include | |||
| + | |||
| +CXXFLAGS = -D_GLIBCXX_USE_CXX11_ABI=0 | |||
| +CFLAGS = -std=c++11 -fPIC -O3 -ggdb -Wall -finline-functions $(INCPATH) $(ADD_CFLAGS) -D_GLIBCXX_USE_CXX11_ABI=0 | |||
| LIBS = -pthread | |||
| ifdef USE_IBVERBS | |||
| @@ -30,6 +41,7 @@ ifdef ASAN | |||
| CFLAGS += -fsanitize=address -fno-omit-frame-pointer -fno-optimize-sibling-calls | |||
| endif | |||
| +LIBS += -L$(MS_GLOG_LIB_DIR) -lglog | |||
| all: ps test | |||
| @@ -51,9 +63,9 @@ build/libps.a: $(OBJS) | |||
| build/%.o: src/%.cc ${ZMQ} src/meta.pb.h | |||
| @mkdir -p $(@D) | |||
| $(CXX) $(INCPATH) -std=c++11 -MM -MT build/$*.o $< >build/$*.d | |||
| - $(CXX) $(CFLAGS) $(LIBS) -c $< -o $@ | |||
| + $(CXX) $(CFLAGS) $(CXXFLAGS) $(LIBS) -c $< -o $@ | |||
| -src/%.pb.cc src/%.pb.h : src/%.proto ${PROTOBUF} | |||
| +src/%.pb.cc src/%.pb.h : src/%.proto | |||
| $(PROTOC) --cpp_out=./src --proto_path=./src $< | |||
| -include build/*.d | |||
| diff -Npur ps-lite-master/src/ibverbs_van.h ps-lite-master-new/src/ibverbs_van.h | |||
| --- ps-lite-master/src/ibverbs_van.h 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/src/ibverbs_van.h 2020-06-02 20:52:11.076230014 +0800 | |||
| @@ -145,15 +145,15 @@ class SimpleMempool { | |||
| total_allocated_size += new_mem_size; | |||
| } | |||
| - CHECK_NE(free_list.end(), it) << "Not enough memory"; | |||
| + //CHECK_NE(free_list.end(), it) << "Not enough memory"; | |||
| CHECK_GE(it->first, proper_size); | |||
| char *addr = it->second; | |||
| size_t space_left = it->first - proper_size; | |||
| free_list.erase(it); | |||
| - CHECK_EQ(used_list.find(addr), used_list.end()) | |||
| - << "Address is already allocated"; | |||
| + //CHECK_EQ(used_list.find(addr), used_list.end()) | |||
| + //<< "Address is already allocated"; | |||
| used_list.emplace(addr, proper_size); | |||
| @@ -173,8 +173,8 @@ class SimpleMempool { | |||
| std::lock_guard<std::mutex> lk(mu_); | |||
| auto it = used_list.find(addr); | |||
| - CHECK_NE(used_list.end(), it) | |||
| - << "Cannot find info about address: " << (uintptr_t)addr; | |||
| + //CHECK_NE(used_list.end(), it) | |||
| + //<< "Cannot find info about address: " << (uintptr_t)addr; | |||
| size_t size = it->second; | |||
| used_list.erase(it); | |||
| @@ -208,7 +208,7 @@ class SimpleMempool { | |||
| // Convert the memory address to its associated RDMA memory region | |||
| inline struct ibv_mr *Addr2MR(char *addr) { | |||
| auto it = mr_list.lower_bound(addr); | |||
| - CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; | |||
| + //CHECK_NE(it, mr_list.end()) << "cannot find the associated memory region"; | |||
| return it->second; | |||
| } | |||
| }; | |||
| @@ -330,7 +330,7 @@ class AddressPool { | |||
| CHECK(ptr); | |||
| uint32_t idx = indices_.front(); | |||
| indices_.pop(); | |||
| - CHECK_EQ(table_[idx], nullptr); | |||
| + //CHECK_EQ(table_[idx], nullptr); | |||
| table_[idx] = ptr; | |||
| return idx; | |||
| } | |||
| @@ -636,7 +636,7 @@ class IBVerbsVan : public Van { | |||
| PBMeta meta; | |||
| PackMetaPB(msg.meta, &meta); | |||
| - CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); | |||
| + //CHECK_NE(endpoints_.find(remote_id), endpoints_.end()); | |||
| Endpoint *endpoint = endpoints_[remote_id].get(); | |||
| MessageBuffer *msg_buf = new MessageBuffer(); | |||
| diff -Npur ps-lite-master/src/van.cc ps-lite-master-new/src/van.cc | |||
| --- ps-lite-master/src/van.cc 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/src/van.cc 2020-06-02 20:52:43.330405828 +0800 | |||
| @@ -448,6 +448,7 @@ void Van::PackMetaPB(const Meta& meta, P | |||
| if (meta.timestamp != Meta::kEmpty) pb->set_timestamp(meta.timestamp); | |||
| if (meta.body.size()) pb->set_body(meta.body); | |||
| pb->set_push(meta.push); | |||
| + pb->set_pull(meta.pull); | |||
| pb->set_request(meta.request); | |||
| pb->set_simple_app(meta.simple_app); | |||
| pb->set_priority(meta.priority); | |||
| diff -Npur ps-lite-master/tests/test.mk ps-lite-master-new/tests/test.mk | |||
| --- ps-lite-master/tests/test.mk 2020-02-29 13:59:55.000000000 +0800 | |||
| +++ ps-lite-master-new/tests/test.mk 2020-06-16 19:15:06.025087897 +0800 | |||
| @@ -1,10 +1,10 @@ | |||
| -TEST_SRC = $(wildcard tests/test_*.cc) | |||
| -TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC)) | |||
| +#TEST_SRC = $(wildcard tests/test_*.cc) | |||
| +#TEST = $(patsubst tests/test_%.cc, tests/test_%, $(TEST_SRC)) | |||
| -# -ltcmalloc_and_profiler | |||
| -LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread | |||
| -tests/% : tests/%.cc build/libps.a | |||
| - $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d | |||
| - $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) | |||
| - | |||
| --include tests/*.d | |||
| +## -ltcmalloc_and_profiler | |||
| +#LDFLAGS = -Wl,-rpath,$(DEPS_PATH)/lib $(PS_LDFLAGS_SO) -pthread | |||
| +#tests/% : tests/%.cc build/libps.a | |||
| +# $(CXX) $(CFLAGS) $(LIBS) -MM -MT tests/$* $< >tests/$*.d | |||
| +# $(CXX) $(CFLAGS) $(LIBS) -o $@ $(filter %.cc %.a, $^) $(LDFLAGS) | |||
| +# | |||
| +#-include tests/*.d | |||