| @@ -0,0 +1,311 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ | |||
| #define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ | |||
#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ps/ps.h"
#include "parallel/ps/util.h"
| namespace mindspore { | |||
| namespace parallel { | |||
| namespace ps { | |||
| template <typename T> | |||
| class WorkerProxy : public ::ps::KVWorker<T> { | |||
| public: | |||
| using Worker = ::ps::KVWorker<T>; | |||
| using Callback = std::function<void()>; | |||
| using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>; | |||
| using Slicer = | |||
| std::function<void(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; | |||
| using ::ps::SimpleApp::obj_; | |||
| explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { | |||
| using _1 = std::placeholders::_1; | |||
| using _2 = std::placeholders::_2; | |||
| using _3 = std::placeholders::_3; | |||
| lookup_customer_ = std::unique_ptr<::ps::Customer>( | |||
| new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1))); | |||
| lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3); | |||
| init_embedding_slicer_ = std::bind(&WorkerProxy<T>::EmbeddingTableInitSlicer, this, _1, _2, _3); | |||
| push_slicer_ = std::bind(&WorkerProxy<T>::PushSlicer, this, _1, _2, _3); | |||
| broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3); | |||
| } | |||
| ~WorkerProxy() override = default; | |||
| void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); | |||
| void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr, | |||
| int priority = 0); | |||
| int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0); | |||
| void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, | |||
| int cmd = 0, int priority = 0); | |||
| private: | |||
| template <typename C> | |||
| int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, C *vals, int cmd, | |||
| const Callback &cb); | |||
| void LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); | |||
| void EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); | |||
| void PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); | |||
| void BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); | |||
| void ProcessLookupResult(const ::ps::Message &msg); | |||
| void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs, | |||
| const Slicer &slicer); | |||
| std::unique_ptr<::ps::Customer> lookup_customer_; | |||
| std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_; | |||
| std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_; | |||
| std::mutex mutex_; | |||
| Slicer lookup_slicer_; | |||
| Slicer init_embedding_slicer_; | |||
| Slicer push_slicer_; | |||
| Slicer broadcast_slicer_; | |||
| std::unordered_map<int, Callback> lookup_callbacks_; | |||
| }; | |||
| template <typename T> | |||
| void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { | |||
| uint64_t begin = 0; | |||
| uint64_t end = 0; | |||
| int server_num = ::ps::NumServers(); | |||
| for (int i = 0; i < server_num; i++) { | |||
| int local_row_cnt = Util::LocalShard(row_count, i, server_num); | |||
| if (i == 0) { | |||
| end = local_row_cnt - 1; | |||
| } else { | |||
| begin = end + 1; | |||
| end += local_row_cnt; | |||
| } | |||
| ::ps::Range range(begin, end); | |||
| if (embedding_table_ranges_.count(key) == 0) { | |||
| embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>(); | |||
| } | |||
| embedding_table_ranges_[key]->push_back(range); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, | |||
| const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb, | |||
| int priority) { | |||
| int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = keys; | |||
| kvs.vals = lookup_ids; | |||
| kvs.lens = lens; | |||
| kvs.priority = priority; | |||
| Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); | |||
| lookup_customer_->WaitRequest(ts); | |||
| } | |||
| template <typename T> | |||
| int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens, const Callback &cb, int priority) { | |||
| int ts = obj_->NewRequest(::ps::kServerGroup); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = keys; | |||
| kvs.vals = vals; | |||
| kvs.lens = lens; | |||
| kvs.priority = priority; | |||
| Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); | |||
| return ts; | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, | |||
| const ::ps::SArray<int> &lens, int cmd, int priority) { | |||
| int ts = obj_->NewRequest(::ps::kServerGroup); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = keys; | |||
| kvs.vals = vals; | |||
| kvs.lens = lens; | |||
| kvs.priority = priority; | |||
| Send(obj_, ts, true, false, cmd, kvs, push_slicer_); | |||
| obj_->WaitRequest(ts); | |||
| } | |||
| template <typename T> | |||
| template <typename C> | |||
| int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, | |||
| C *lookup_result, int cmd, const Callback &cb) { | |||
| int ts = lookup_customer_->NewRequest(::ps::kServerGroup); | |||
| const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { | |||
| mutex_.lock(); | |||
| auto &kvs = lookup_results_[ts]; | |||
| mutex_.unlock(); | |||
| size_t total_len = 0; | |||
| const auto &s = kvs[0]; | |||
| for (size_t i = 0; i < s.lens.size(); i++) { | |||
| total_len += s.lens[i]; | |||
| } | |||
| lookup_result->resize(total_len, 0); | |||
| T *result_addr = lookup_result->data(); | |||
| for (const auto &s : kvs) { | |||
| size_t offset = 0; | |||
| for (size_t i = 0; i < s.vals.size(); i++) { | |||
| result_addr[offset++] += s.vals[i]; | |||
| } | |||
| } | |||
| mutex_.lock(); | |||
| lookup_results_.erase(ts); | |||
| mutex_.unlock(); | |||
| if (cb) cb(); | |||
| }; | |||
| lookup_callbacks_[ts] = callback; | |||
| return ts; | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { | |||
| int *data = send.lens.data(); | |||
| size_t size = send.lens.size(); | |||
| std::vector<int> lookup_ids(data, data + size); | |||
| std::sort(lookup_ids.begin(), lookup_ids.end()); | |||
| const Key &key = send.keys[0]; | |||
| const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); | |||
| sliced->resize(ranges.size()); | |||
| size_t index = 0; | |||
| for (size_t i = 0; i < ranges.size(); i++) { | |||
| const ::ps::Range &range = ranges[i]; | |||
| const auto &begin = range.begin(); | |||
| const auto &end = range.end(); | |||
| auto &kvs = sliced->at(i).second; | |||
| auto lookup_id = static_cast<uint64_t>(lookup_ids[index]); | |||
| while (lookup_id >= begin && lookup_id <= end) { | |||
| kvs.vals.push_back(lookup_id); | |||
| if (++index >= lookup_ids.size()) { | |||
| break; | |||
| } | |||
| lookup_id = static_cast<uint64_t>(lookup_ids[index]); | |||
| } | |||
| kvs.keys.push_back(key); | |||
| kvs.lens.push_back(kvs.vals.size()); | |||
| if (kvs.vals.size() == 0) { | |||
| sliced->at(i).first = false; | |||
| } else { | |||
| sliced->at(i).first = true; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { | |||
| const Key &key = send.keys[0]; | |||
| const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); | |||
| sliced->resize(ranges.size()); | |||
| for (size_t i = 0; i < ranges.size(); i++) { | |||
| sliced->at(i).first = true; | |||
| sliced->at(i).second = send; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { | |||
| auto server_num = ::ps::Postoffice::Get()->num_servers(); | |||
| sliced->resize(server_num); | |||
| for (int i = 0; i < server_num; i++) { | |||
| sliced->at(i).first = true; | |||
| sliced->at(i).second = send; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, | |||
| std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { | |||
| auto server_num = ::ps::Postoffice::Get()->num_servers(); | |||
| sliced->resize(server_num); | |||
| for (int i = 0; i < server_num; i++) { | |||
| sliced->at(i).first = true; | |||
| sliced->at(i).second = send; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) { | |||
| int ts = msg.meta.timestamp; | |||
| if (msg.meta.pull) { | |||
| CHECK_GE(msg.data.size(), (size_t)2); | |||
| ::ps::KVPairs<T> kvs; | |||
| kvs.keys = msg.data[0]; | |||
| kvs.vals = msg.data[1]; | |||
| if (msg.data.size() > (size_t)2) { | |||
| kvs.lens = msg.data[2]; | |||
| } | |||
| mutex_.lock(); | |||
| lookup_results_[ts].push_back(kvs); | |||
| mutex_.unlock(); | |||
| } | |||
| if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { | |||
| const auto &cb = lookup_callbacks_[ts]; | |||
| cb(); | |||
| lookup_callbacks_.erase(ts); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, | |||
| const ::ps::KVPairs<T> &kvs, const Slicer &slicer) { | |||
| SlicedKVs sliced; | |||
| slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); | |||
| for (size_t i = 0; i < sliced.size(); i++) { | |||
| const auto &s = sliced[i]; | |||
| if (!s.first) continue; | |||
| ::ps::Message msg; | |||
| msg.meta.app_id = customer->app_id(); | |||
| msg.meta.customer_id = customer->customer_id(); | |||
| msg.meta.request = true; | |||
| msg.meta.push = push; | |||
| msg.meta.pull = pull; | |||
| msg.meta.head = cmd; | |||
| msg.meta.timestamp = timestamp; | |||
| msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); | |||
| msg.meta.priority = kvs.priority; | |||
| const auto &kvs = s.second; | |||
| if (kvs.keys.size()) { | |||
| msg.AddData(kvs.keys); | |||
| msg.AddData(kvs.vals); | |||
| if (kvs.lens.size()) { | |||
| msg.AddData(kvs.lens); | |||
| } | |||
| } | |||
| ::ps::Postoffice::Get()->van()->Send(msg); | |||
| } | |||
| } | |||
| } // namespace ps | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ | |||