/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ #define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ #include #include #include #include #include #include "ps/ps.h" #include "parallel/ps/util.h" namespace mindspore { namespace parallel { namespace ps { template class WorkerProxy : public ::ps::KVWorker { public: using Worker = ::ps::KVWorker; using Callback = std::function; using SlicedKVs = std::vector>>; using Slicer = std::function &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; using ::ps::SimpleApp::obj_; explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { using _1 = std::placeholders::_1; using _2 = std::placeholders::_2; using _3 = std::placeholders::_3; lookup_customer_ = std::unique_ptr<::ps::Customer>( new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy::ProcessLookupResult, this, _1))); lookup_slicer_ = std::bind(&WorkerProxy::LookupIdSlicer, this, _1, _2, _3); init_embedding_slicer_ = std::bind(&WorkerProxy::EmbeddingTableInitSlicer, this, _1, _2, _3); push_slicer_ = std::bind(&WorkerProxy::PushSlicer, this, _1, _2, _3); broadcast_slicer_ = std::bind(&WorkerProxy::BroadcastSlicer, this, _1, _2, _3); } ~WorkerProxy() override = default; void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *outs, int cmd = 0, const Callback &cb = nullptr, int priority = 0); int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, const Callback &cb = nullptr, int priority = 0); void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens = {}, int cmd = 0, int priority = 0); private: template int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *vals, int cmd, const Callback &cb); void LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); void EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); void PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); void BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced); void ProcessLookupResult(const ::ps::Message &msg); void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, const Slicer &slicer); std::unique_ptr<::ps::Customer> lookup_customer_; std::unordered_map<::ps::Key, std::shared_ptr>> embedding_table_ranges_; std::unordered_map>> lookup_results_; std::mutex mutex_; Slicer lookup_slicer_; Slicer init_embedding_slicer_; Slicer push_slicer_; Slicer broadcast_slicer_; std::unordered_map lookup_callbacks_; }; template void WorkerProxy::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { uint64_t begin = 0; uint64_t end = 0; int server_num = ::ps::NumServers(); for (int i = 0; i < server_num; i++) { int local_row_cnt = Util::LocalShard(row_count, i, server_num); if (i == 0) { end = local_row_cnt - 1; } else { begin = end + 1; end += local_row_cnt; } ::ps::Range range(begin, end); if (embedding_table_ranges_.count(key) == 0) { embedding_table_ranges_[key] = std::make_shared>(); } embedding_table_ranges_[key]->push_back(range); } } template void WorkerProxy::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *outs, int cmd, const Callback &cb, int priority) { int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); ::ps::KVPairs kvs; kvs.keys = keys; kvs.vals = lookup_ids; kvs.lens = lens; kvs.priority = priority; Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); lookup_customer_->WaitRequest(ts); } template int WorkerProxy::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens, const Callback &cb, int priority) { int ts = obj_->NewRequest(::ps::kServerGroup); ::ps::KVPairs kvs; kvs.keys = keys; kvs.vals = vals; kvs.lens = lens; kvs.priority = priority; Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); return ts; } template void WorkerProxy::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &vals, const ::ps::SArray &lens, int cmd, int priority) { int ts = obj_->NewRequest(::ps::kServerGroup); ::ps::KVPairs kvs; kvs.keys = keys; kvs.vals = vals; kvs.lens = lens; kvs.priority = priority; Send(obj_, ts, true, false, cmd, kvs, push_slicer_); obj_->WaitRequest(ts); } template template int WorkerProxy::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, C *lookup_result, int cmd, const Callback &cb) { int ts = lookup_customer_->NewRequest(::ps::kServerGroup); const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { mutex_.lock(); auto &kvs = lookup_results_[ts]; mutex_.unlock(); size_t total_len = 0; const auto &s = kvs[0]; for (size_t i = 0; i < s.lens.size(); i++) { total_len += s.lens[i]; } lookup_result->resize(total_len, 0); T *result_addr = lookup_result->data(); for (const auto &s : kvs) { size_t offset = 0; for (size_t i = 0; i < s.vals.size(); i++) { result_addr[offset++] += s.vals[i]; } } mutex_.lock(); lookup_results_.erase(ts); mutex_.unlock(); if (cb) cb(); }; lookup_callbacks_[ts] = callback; return ts; } template void WorkerProxy::LookupIdSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { int *data = send.lens.data(); size_t size = send.lens.size(); std::vector lookup_ids(data, data + size); std::sort(lookup_ids.begin(), lookup_ids.end()); const Key &key = send.keys[0]; const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); sliced->resize(ranges.size()); size_t index = 0; for (size_t i = 0; i < ranges.size(); i++) { const ::ps::Range &range = ranges[i]; const auto &begin = range.begin(); const auto &end = range.end(); auto &kvs = sliced->at(i).second; auto lookup_id = static_cast(lookup_ids[index]); while (lookup_id >= begin && lookup_id <= end) { kvs.vals.push_back(lookup_id); if (++index >= lookup_ids.size()) { break; } lookup_id = static_cast(lookup_ids[index]); } kvs.keys.push_back(key); kvs.lens.push_back(kvs.vals.size()); if (kvs.vals.size() == 0) { sliced->at(i).first = false; } else { sliced->at(i).first = true; } } } template void WorkerProxy::EmbeddingTableInitSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { const Key &key = send.keys[0]; const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); sliced->resize(ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { sliced->at(i).first = true; sliced->at(i).second = send; } } template void WorkerProxy::PushSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { auto server_num = ::ps::Postoffice::Get()->num_servers(); sliced->resize(server_num); for (int i = 0; i < server_num; i++) { sliced->at(i).first = true; sliced->at(i).second = send; } } template void WorkerProxy::BroadcastSlicer(const ::ps::KVPairs &send, const std::vector<::ps::Range> &, std::vector>> *sliced) { auto server_num = ::ps::Postoffice::Get()->num_servers(); sliced->resize(server_num); for (int i = 0; i < server_num; i++) { sliced->at(i).first = true; sliced->at(i).second = send; } } template void WorkerProxy::ProcessLookupResult(const ::ps::Message &msg) { int ts = msg.meta.timestamp; if (msg.meta.pull) { CHECK_GE(msg.data.size(), (size_t)2); ::ps::KVPairs kvs; kvs.keys = msg.data[0]; kvs.vals = msg.data[1]; if (msg.data.size() > (size_t)2) { kvs.lens = msg.data[2]; } mutex_.lock(); lookup_results_[ts].push_back(kvs); mutex_.unlock(); } if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { const auto &cb = lookup_callbacks_[ts]; cb(); lookup_callbacks_.erase(ts); } } template void WorkerProxy::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs &kvs, const Slicer &slicer) { SlicedKVs sliced; slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); for (size_t i = 0; i < sliced.size(); i++) { const auto &s = sliced[i]; if (!s.first) continue; ::ps::Message msg; msg.meta.app_id = customer->app_id(); msg.meta.customer_id = customer->customer_id(); msg.meta.request = true; msg.meta.push = push; msg.meta.pull = pull; msg.meta.head = cmd; msg.meta.timestamp = timestamp; msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); msg.meta.priority = kvs.priority; const auto &kvs = s.second; if (kvs.keys.size()) { msg.AddData(kvs.keys); msg.AddData(kvs.vals); if (kvs.lens.size()) { msg.AddData(kvs.lens); } } ::ps::Postoffice::Get()->van()->Send(msg); } } } // namespace ps } // namespace parallel } // namespace mindspore #endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_