|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
- #define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
-
#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ps/ps.h"
#include "parallel/ps/util.h"
-
- namespace mindspore {
- namespace parallel {
- namespace ps {
- template <typename T>
- class WorkerProxy : public ::ps::KVWorker<T> {
- public:
- using Worker = ::ps::KVWorker<T>;
- using Callback = std::function<void()>;
- using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
- using Slicer =
- std::function<void(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>;
- using ::ps::SimpleApp::obj_;
- explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) {
- using _1 = std::placeholders::_1;
- using _2 = std::placeholders::_2;
- using _3 = std::placeholders::_3;
- lookup_customer_ = std::unique_ptr<::ps::Customer>(
- new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
- lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3);
- init_embedding_slicer_ = std::bind(&WorkerProxy<T>::EmbeddingTableInitSlicer, this, _1, _2, _3);
- push_slicer_ = std::bind(&WorkerProxy<T>::PushSlicer, this, _1, _2, _3);
- broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3);
- }
- ~WorkerProxy() override = default;
-
- void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
- void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
- const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr,
- int priority = 0);
- int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
- const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0);
- void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
- int cmd = 0, int priority = 0);
-
- private:
- template <typename C>
- int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, C *vals, int cmd,
- const Callback &cb);
- void LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
- void EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
- void PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
- void BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced);
- void ProcessLookupResult(const ::ps::Message &msg);
- void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs,
- const Slicer &slicer);
-
- std::unique_ptr<::ps::Customer> lookup_customer_;
- std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
- std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_;
- std::mutex mutex_;
- Slicer lookup_slicer_;
- Slicer init_embedding_slicer_;
- Slicer push_slicer_;
- Slicer broadcast_slicer_;
- std::unordered_map<int, Callback> lookup_callbacks_;
- };
-
- template <typename T>
- void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
- uint64_t begin = 0;
- uint64_t end = 0;
- int server_num = ::ps::NumServers();
- for (int i = 0; i < server_num; i++) {
- int local_row_cnt = Util::LocalShard(row_count, i, server_num);
- if (i == 0) {
- end = local_row_cnt - 1;
- } else {
- begin = end + 1;
- end += local_row_cnt;
- }
- ::ps::Range range(begin, end);
- if (embedding_table_ranges_.count(key) == 0) {
- embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
- }
- embedding_table_ranges_[key]->push_back(range);
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
- const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb,
- int priority) {
- int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
- ::ps::KVPairs<T> kvs;
- kvs.keys = keys;
- kvs.vals = lookup_ids;
- kvs.lens = lens;
- kvs.priority = priority;
- Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_);
- lookup_customer_->WaitRequest(ts);
- }
-
- template <typename T>
- int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
- const ::ps::SArray<int> &lens, const Callback &cb, int priority) {
- int ts = obj_->NewRequest(::ps::kServerGroup);
- ::ps::KVPairs<T> kvs;
- kvs.keys = keys;
- kvs.vals = vals;
- kvs.lens = lens;
- kvs.priority = priority;
- Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_);
- return ts;
- }
-
- template <typename T>
- void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
- const ::ps::SArray<int> &lens, int cmd, int priority) {
- int ts = obj_->NewRequest(::ps::kServerGroup);
- ::ps::KVPairs<T> kvs;
- kvs.keys = keys;
- kvs.vals = vals;
- kvs.lens = lens;
- kvs.priority = priority;
- Send(obj_, ts, true, false, cmd, kvs, push_slicer_);
- obj_->WaitRequest(ts);
- }
-
- template <typename T>
- template <typename C>
- int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids,
- C *lookup_result, int cmd, const Callback &cb) {
- int ts = lookup_customer_->NewRequest(::ps::kServerGroup);
- const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
- mutex_.lock();
- auto &kvs = lookup_results_[ts];
- mutex_.unlock();
-
- size_t total_len = 0;
- const auto &s = kvs[0];
- for (size_t i = 0; i < s.lens.size(); i++) {
- total_len += s.lens[i];
- }
- lookup_result->resize(total_len, 0);
- T *result_addr = lookup_result->data();
-
- for (const auto &s : kvs) {
- size_t offset = 0;
- for (size_t i = 0; i < s.vals.size(); i++) {
- result_addr[offset++] += s.vals[i];
- }
- }
-
- mutex_.lock();
- lookup_results_.erase(ts);
- mutex_.unlock();
- if (cb) cb();
- };
- lookup_callbacks_[ts] = callback;
- return ts;
- }
-
- template <typename T>
- void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
- int *data = send.lens.data();
- size_t size = send.lens.size();
- std::vector<int> lookup_ids(data, data + size);
- std::sort(lookup_ids.begin(), lookup_ids.end());
-
- const Key &key = send.keys[0];
- const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
- sliced->resize(ranges.size());
-
- size_t index = 0;
- for (size_t i = 0; i < ranges.size(); i++) {
- const ::ps::Range &range = ranges[i];
- const auto &begin = range.begin();
- const auto &end = range.end();
- auto &kvs = sliced->at(i).second;
-
- auto lookup_id = static_cast<uint64_t>(lookup_ids[index]);
- while (lookup_id >= begin && lookup_id <= end) {
- kvs.vals.push_back(lookup_id);
- if (++index >= lookup_ids.size()) {
- break;
- }
- lookup_id = static_cast<uint64_t>(lookup_ids[index]);
- }
- kvs.keys.push_back(key);
- kvs.lens.push_back(kvs.vals.size());
-
- if (kvs.vals.size() == 0) {
- sliced->at(i).first = false;
- } else {
- sliced->at(i).first = true;
- }
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
- const Key &key = send.keys[0];
- const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
- sliced->resize(ranges.size());
- for (size_t i = 0; i < ranges.size(); i++) {
- sliced->at(i).first = true;
- sliced->at(i).second = send;
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
- auto server_num = ::ps::Postoffice::Get()->num_servers();
- sliced->resize(server_num);
- for (int i = 0; i < server_num; i++) {
- sliced->at(i).first = true;
- sliced->at(i).second = send;
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
- std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) {
- auto server_num = ::ps::Postoffice::Get()->num_servers();
- sliced->resize(server_num);
- for (int i = 0; i < server_num; i++) {
- sliced->at(i).first = true;
- sliced->at(i).second = send;
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
- int ts = msg.meta.timestamp;
- if (msg.meta.pull) {
- CHECK_GE(msg.data.size(), (size_t)2);
- ::ps::KVPairs<T> kvs;
- kvs.keys = msg.data[0];
- kvs.vals = msg.data[1];
- if (msg.data.size() > (size_t)2) {
- kvs.lens = msg.data[2];
- }
- mutex_.lock();
- lookup_results_[ts].push_back(kvs);
- mutex_.unlock();
- }
- if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) {
- const auto &cb = lookup_callbacks_[ts];
- cb();
- lookup_callbacks_.erase(ts);
- }
- }
-
- template <typename T>
- void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd,
- const ::ps::KVPairs<T> &kvs, const Slicer &slicer) {
- SlicedKVs sliced;
- slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced);
-
- for (size_t i = 0; i < sliced.size(); i++) {
- const auto &s = sliced[i];
- if (!s.first) continue;
- ::ps::Message msg;
- msg.meta.app_id = customer->app_id();
- msg.meta.customer_id = customer->customer_id();
- msg.meta.request = true;
- msg.meta.push = push;
- msg.meta.pull = pull;
- msg.meta.head = cmd;
- msg.meta.timestamp = timestamp;
- msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
- msg.meta.priority = kvs.priority;
- const auto &kvs = s.second;
- if (kvs.keys.size()) {
- msg.AddData(kvs.keys);
- msg.AddData(kvs.vals);
- if (kvs.lens.size()) {
- msg.AddData(kvs.lens);
- }
- }
- ::ps::Postoffice::Get()->van()->Send(msg);
- }
- }
- } // namespace ps
- } // namespace parallel
- } // namespace mindspore
- #endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_
|