|
|
|
@@ -0,0 +1,311 @@ |
|
|
|
/** |
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
* |
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
* you may not use this file except in compliance with the License. |
|
|
|
* You may obtain a copy of the License at |
|
|
|
* |
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
* |
|
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
* See the License for the specific language governing permissions and |
|
|
|
* limitations under the License. |
|
|
|
*/ |
|
|
|
|
|
|
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ |
|
|
|
#define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ |
|
|
|
|
|
|
|
#include <algorithm>
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ps/ps.h"
#include "parallel/ps/util.h"
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace parallel { |
|
|
|
namespace ps { |
|
|
|
template <typename T> |
|
|
|
class WorkerProxy : public ::ps::KVWorker<T> { |
|
|
|
public: |
|
|
|
using Worker = ::ps::KVWorker<T>; |
|
|
|
using Callback = std::function<void()>; |
|
|
|
using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>; |
|
|
|
using Slicer = |
|
|
|
std::function<void(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges, SlicedKVs *sliced)>; |
|
|
|
using ::ps::SimpleApp::obj_; |
|
|
|
explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id) : Worker(app_id, customer_id) { |
|
|
|
using _1 = std::placeholders::_1; |
|
|
|
using _2 = std::placeholders::_2; |
|
|
|
using _3 = std::placeholders::_3; |
|
|
|
lookup_customer_ = std::unique_ptr<::ps::Customer>( |
|
|
|
new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1))); |
|
|
|
lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3); |
|
|
|
init_embedding_slicer_ = std::bind(&WorkerProxy<T>::EmbeddingTableInitSlicer, this, _1, _2, _3); |
|
|
|
push_slicer_ = std::bind(&WorkerProxy<T>::PushSlicer, this, _1, _2, _3); |
|
|
|
broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3); |
|
|
|
} |
|
|
|
~WorkerProxy() override = default; |
|
|
|
|
|
|
|
void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); |
|
|
|
void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, |
|
|
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr, |
|
|
|
int priority = 0); |
|
|
|
int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, |
|
|
|
const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0); |
|
|
|
void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {}, |
|
|
|
int cmd = 0, int priority = 0); |
|
|
|
|
|
|
|
private: |
|
|
|
template <typename C> |
|
|
|
int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, C *vals, int cmd, |
|
|
|
const Callback &cb); |
|
|
|
void LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); |
|
|
|
void EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); |
|
|
|
void PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); |
|
|
|
void BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced); |
|
|
|
void ProcessLookupResult(const ::ps::Message &msg); |
|
|
|
void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs, |
|
|
|
const Slicer &slicer); |
|
|
|
|
|
|
|
std::unique_ptr<::ps::Customer> lookup_customer_; |
|
|
|
std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_; |
|
|
|
std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_; |
|
|
|
std::mutex mutex_; |
|
|
|
Slicer lookup_slicer_; |
|
|
|
Slicer init_embedding_slicer_; |
|
|
|
Slicer push_slicer_; |
|
|
|
Slicer broadcast_slicer_; |
|
|
|
std::unordered_map<int, Callback> lookup_callbacks_; |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) { |
|
|
|
uint64_t begin = 0; |
|
|
|
uint64_t end = 0; |
|
|
|
int server_num = ::ps::NumServers(); |
|
|
|
for (int i = 0; i < server_num; i++) { |
|
|
|
int local_row_cnt = Util::LocalShard(row_count, i, server_num); |
|
|
|
if (i == 0) { |
|
|
|
end = local_row_cnt - 1; |
|
|
|
} else { |
|
|
|
begin = end + 1; |
|
|
|
end += local_row_cnt; |
|
|
|
} |
|
|
|
::ps::Range range(begin, end); |
|
|
|
if (embedding_table_ranges_.count(key) == 0) { |
|
|
|
embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>(); |
|
|
|
} |
|
|
|
embedding_table_ranges_[key]->push_back(range); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, |
|
|
|
const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb, |
|
|
|
int priority) { |
|
|
|
int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb); |
|
|
|
::ps::KVPairs<T> kvs; |
|
|
|
kvs.keys = keys; |
|
|
|
kvs.vals = lookup_ids; |
|
|
|
kvs.lens = lens; |
|
|
|
kvs.priority = priority; |
|
|
|
Send(lookup_customer_.get(), ts, true, true, cmd, kvs, broadcast_slicer_); |
|
|
|
lookup_customer_->WaitRequest(ts); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, |
|
|
|
const ::ps::SArray<int> &lens, const Callback &cb, int priority) { |
|
|
|
int ts = obj_->NewRequest(::ps::kServerGroup); |
|
|
|
::ps::KVPairs<T> kvs; |
|
|
|
kvs.keys = keys; |
|
|
|
kvs.vals = vals; |
|
|
|
kvs.lens = lens; |
|
|
|
kvs.priority = priority; |
|
|
|
Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, init_embedding_slicer_); |
|
|
|
return ts; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, |
|
|
|
const ::ps::SArray<int> &lens, int cmd, int priority) { |
|
|
|
int ts = obj_->NewRequest(::ps::kServerGroup); |
|
|
|
::ps::KVPairs<T> kvs; |
|
|
|
kvs.keys = keys; |
|
|
|
kvs.vals = vals; |
|
|
|
kvs.lens = lens; |
|
|
|
kvs.priority = priority; |
|
|
|
Send(obj_, ts, true, false, cmd, kvs, push_slicer_); |
|
|
|
obj_->WaitRequest(ts); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
template <typename C> |
|
|
|
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &lookup_ids, |
|
|
|
C *lookup_result, int cmd, const Callback &cb) { |
|
|
|
int ts = lookup_customer_->NewRequest(::ps::kServerGroup); |
|
|
|
const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable { |
|
|
|
mutex_.lock(); |
|
|
|
auto &kvs = lookup_results_[ts]; |
|
|
|
mutex_.unlock(); |
|
|
|
|
|
|
|
size_t total_len = 0; |
|
|
|
const auto &s = kvs[0]; |
|
|
|
for (size_t i = 0; i < s.lens.size(); i++) { |
|
|
|
total_len += s.lens[i]; |
|
|
|
} |
|
|
|
lookup_result->resize(total_len, 0); |
|
|
|
T *result_addr = lookup_result->data(); |
|
|
|
|
|
|
|
for (const auto &s : kvs) { |
|
|
|
size_t offset = 0; |
|
|
|
for (size_t i = 0; i < s.vals.size(); i++) { |
|
|
|
result_addr[offset++] += s.vals[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
mutex_.lock(); |
|
|
|
lookup_results_.erase(ts); |
|
|
|
mutex_.unlock(); |
|
|
|
if (cb) cb(); |
|
|
|
}; |
|
|
|
lookup_callbacks_[ts] = callback; |
|
|
|
return ts; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::LookupIdSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { |
|
|
|
int *data = send.lens.data(); |
|
|
|
size_t size = send.lens.size(); |
|
|
|
std::vector<int> lookup_ids(data, data + size); |
|
|
|
std::sort(lookup_ids.begin(), lookup_ids.end()); |
|
|
|
|
|
|
|
const Key &key = send.keys[0]; |
|
|
|
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); |
|
|
|
sliced->resize(ranges.size()); |
|
|
|
|
|
|
|
size_t index = 0; |
|
|
|
for (size_t i = 0; i < ranges.size(); i++) { |
|
|
|
const ::ps::Range &range = ranges[i]; |
|
|
|
const auto &begin = range.begin(); |
|
|
|
const auto &end = range.end(); |
|
|
|
auto &kvs = sliced->at(i).second; |
|
|
|
|
|
|
|
auto lookup_id = static_cast<uint64_t>(lookup_ids[index]); |
|
|
|
while (lookup_id >= begin && lookup_id <= end) { |
|
|
|
kvs.vals.push_back(lookup_id); |
|
|
|
if (++index >= lookup_ids.size()) { |
|
|
|
break; |
|
|
|
} |
|
|
|
lookup_id = static_cast<uint64_t>(lookup_ids[index]); |
|
|
|
} |
|
|
|
kvs.keys.push_back(key); |
|
|
|
kvs.lens.push_back(kvs.vals.size()); |
|
|
|
|
|
|
|
if (kvs.vals.size() == 0) { |
|
|
|
sliced->at(i).first = false; |
|
|
|
} else { |
|
|
|
sliced->at(i).first = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::EmbeddingTableInitSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { |
|
|
|
const Key &key = send.keys[0]; |
|
|
|
const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]); |
|
|
|
sliced->resize(ranges.size()); |
|
|
|
for (size_t i = 0; i < ranges.size(); i++) { |
|
|
|
sliced->at(i).first = true; |
|
|
|
sliced->at(i).second = send; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::PushSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { |
|
|
|
auto server_num = ::ps::Postoffice::Get()->num_servers(); |
|
|
|
sliced->resize(server_num); |
|
|
|
for (int i = 0; i < server_num; i++) { |
|
|
|
sliced->at(i).first = true; |
|
|
|
sliced->at(i).second = send; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::BroadcastSlicer(const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &, |
|
|
|
std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced) { |
|
|
|
auto server_num = ::ps::Postoffice::Get()->num_servers(); |
|
|
|
sliced->resize(server_num); |
|
|
|
for (int i = 0; i < server_num; i++) { |
|
|
|
sliced->at(i).first = true; |
|
|
|
sliced->at(i).second = send; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) { |
|
|
|
int ts = msg.meta.timestamp; |
|
|
|
if (msg.meta.pull) { |
|
|
|
CHECK_GE(msg.data.size(), (size_t)2); |
|
|
|
::ps::KVPairs<T> kvs; |
|
|
|
kvs.keys = msg.data[0]; |
|
|
|
kvs.vals = msg.data[1]; |
|
|
|
if (msg.data.size() > (size_t)2) { |
|
|
|
kvs.lens = msg.data[2]; |
|
|
|
} |
|
|
|
mutex_.lock(); |
|
|
|
lookup_results_[ts].push_back(kvs); |
|
|
|
mutex_.unlock(); |
|
|
|
} |
|
|
|
if (lookup_customer_->NumResponse(ts) == ::ps::Postoffice::Get()->num_servers() - 1) { |
|
|
|
const auto &cb = lookup_callbacks_[ts]; |
|
|
|
cb(); |
|
|
|
lookup_callbacks_.erase(ts); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, |
|
|
|
const ::ps::KVPairs<T> &kvs, const Slicer &slicer) { |
|
|
|
SlicedKVs sliced; |
|
|
|
slicer(kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced); |
|
|
|
|
|
|
|
for (size_t i = 0; i < sliced.size(); i++) { |
|
|
|
const auto &s = sliced[i]; |
|
|
|
if (!s.first) continue; |
|
|
|
::ps::Message msg; |
|
|
|
msg.meta.app_id = customer->app_id(); |
|
|
|
msg.meta.customer_id = customer->customer_id(); |
|
|
|
msg.meta.request = true; |
|
|
|
msg.meta.push = push; |
|
|
|
msg.meta.pull = pull; |
|
|
|
msg.meta.head = cmd; |
|
|
|
msg.meta.timestamp = timestamp; |
|
|
|
msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i); |
|
|
|
msg.meta.priority = kvs.priority; |
|
|
|
const auto &kvs = s.second; |
|
|
|
if (kvs.keys.size()) { |
|
|
|
msg.AddData(kvs.keys); |
|
|
|
msg.AddData(kvs.vals); |
|
|
|
if (kvs.lens.size()) { |
|
|
|
msg.AddData(kvs.lens); |
|
|
|
} |
|
|
|
} |
|
|
|
::ps::Postoffice::Get()->van()->Send(msg); |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ps |
|
|
|
} // namespace parallel |
|
|
|
} // namespace mindspore |
|
|
|
#endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_WORKER_PROXY_H_ |