From 43251ab17088c04475dc274a0b2676a373f25546 Mon Sep 17 00:00:00 2001 From: cristoval Date: Wed, 19 Aug 2020 20:13:54 +0800 Subject: [PATCH] add embedding lookup unique id --- .../frontend/parallel/ps/parameter_server.h | 2 +- .../ccsrc/frontend/parallel/ps/worker_proxy.h | 39 ++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index a16a3577d2..495866bd43 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -314,7 +314,7 @@ template void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { const Key &key = req_data.keys[0]; - for (size_t i = 0; i < req_data.keys.size(); i++) { + for (size_t i = 1; i < req_data.keys.size(); i++) { res->keys.push_back(req_data.keys[i]); } ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res); diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h index 9546c313d9..244d5b4d08 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker_proxy.h @@ -259,13 +259,29 @@ int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps: auto &kvs = lookup_results_[ts]; mutex_.unlock(); - ::ps::SArray result(kvs[0].vals.size(), 0); - for (auto k : kvs) { - for (size_t i = 0; i < k.vals.size(); i++) { - result[i] += k.vals[i]; + std::unordered_map>> id_addr_map; + for (const auto &s : kvs) { + int offset = 0; + int len = s.vals.size() / s.keys.size(); + for (size_t i = 0; i < s.keys.size(); i++) { + const Key &key = s.keys[i]; + T *addr = s.vals.data() + offset; + offset += len; + id_addr_map[key] = std::make_shared>(std::make_pair(addr, len)); } } - *lookup_result = result; + + T *result_addr = lookup_result->data(); + int offset = 0; + for (size_t i = 0; i < lookup_ids.size(); i++) { + auto &pair = id_addr_map[static_cast(lookup_ids[i])]; + int size = pair->second * sizeof(T); + auto ret = memcpy_s(result_addr + offset, size, pair->first, size); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + offset += pair->second; + } mutex_.lock(); lookup_results_.erase(ts); @@ -312,12 +328,23 @@ void WorkerProxy::LookupIdSlicer(int timestamp, const ::ps::KVPairs &send, sliced->resize(ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { + const ::ps::Range &range = ranges[i]; + const auto &begin = range.begin(); + const auto &end = range.end(); + std::unordered_set unique_ids; auto &kvs = sliced->at(i).second; kvs.keys.push_back(key); kvs.vals.push_back(0.0f); + for (size_t j = 0; j < id_size; j++) { - kvs.keys.push_back(lookup_ids[j]); + auto lookup_id = static_cast(lookup_ids[j]); + if (lookup_id >= begin && lookup_id <= end) { + unique_ids.insert(lookup_id); + } + } + for (const auto &lookup_id : unique_ids) { + kvs.keys.push_back(lookup_id); kvs.vals.push_back(0.0f); }