You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

worker_proxy.h 34 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
  17. #define MINDSPORE_CCSRC_PS_WORKER_PROXY_H_
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "ps/ps.h"
#include "ps/util.h"
#include "backend/kernel_compiler/common_utils.h"
#include "ps/ps_context.h"
  31. namespace mindspore {
  32. namespace ps {
  33. template <typename T>
  34. class WorkerProxy : public ::ps::KVWorker<T> {
  35. public:
  36. using Worker = ::ps::KVWorker<T>;
  37. using Callback = std::function<void()>;
  38. using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
  39. using Slicer = std::function<void(int64_t ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
  40. SlicedKVs *sliced, const std::map<int64_t, int64_t> &attrs)>;
  41. using ::ps::SimpleApp::obj_;
  42. explicit WorkerProxy(int64_t app_id, int64_t customer_id, int64_t lookup_customer_id, int64_t general_customer_id)
  43. : Worker(app_id, customer_id) {
  44. server_num_ = ::ps::NumServers();
  45. MS_LOG(INFO) << "Server num:" << server_num_;
  46. PSContext::instance()->SetPSRankId(::ps::MyRank());
  47. using std::placeholders::_1;
  48. using std::placeholders::_2;
  49. using std::placeholders::_3;
  50. using std::placeholders::_4;
  51. using std::placeholders::_5;
  52. lookup_customer_ = std::unique_ptr<::ps::Customer>(
  53. new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
  54. general_customer_ = std::unique_ptr<::ps::Customer>(
  55. new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1)));
  56. lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5);
  57. sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5);
  58. broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
  59. round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
  60. worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
  61. update_embedding_slicer_ = std::bind(&WorkerProxy<T>::UpdateEmbeddingSlicer, this, _1, _2, _3, _4, _5);
  62. }
  63. ~WorkerProxy() override = default;
  64. void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
  65. void AddKeyToServerId(const ::ps::Key &key);
  66. void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  67. const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd = 0,
  68. const Callback &cb = nullptr, int64_t priority = 0);
  69. int64_t InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  70. const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int64_t priority = 0);
  71. void UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  72. const ::ps::SArray<T> &vals, const Callback &cb = nullptr, int64_t priority = 0);
  73. bool IsReadyForPush(const Key &key);
  74. bool IsReadyForPull(const Key &key);
  75. void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
  76. int64_t cmd = 0, int64_t priority = 0);
  77. void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens,
  78. size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
  79. void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr,
  80. int64_t cmd = 0, int64_t priority = 0);
  81. void Finalize();
  82. private:
  83. template <typename C>
  84. int64_t AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int64_t cmd,
  85. const Callback &cb);
  86. int64_t AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
  87. int64_t cmd, const Callback &cb);
  88. void LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  89. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
  90. void SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  91. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
  92. void BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  93. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, const std::map<int64_t, int64_t> &attrs);
  94. void RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  95. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  96. const std::map<int64_t, int64_t> &attrs);
  97. void WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  98. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  99. const std::map<int64_t, int64_t> &attrs);
  100. void UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  101. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  102. const std::map<int64_t, int64_t> &attrs);
  103. void ProcessLookupResult(const ::ps::Message &msg);
  104. void ProcessResponse(const ::ps::Message &msg);
  105. void Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd, const ::ps::KVPairs<T> &kvs,
  106. const Slicer &slicer, std::map<int64_t, int64_t> attrs = {});
  107. void AddKeyByHashMod(const ::ps::Key &key);
  108. void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set<int> &distinct_ids,
  109. const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
  110. const size_t segment_size, T *gradient, int *indice);
  111. void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
  112. const T *original_data, const T *grads, int *indices, ::ps::SArray<T> *reduced_data);
  113. int64_t server_num_;
  114. std::unique_ptr<::ps::Customer> lookup_customer_;
  115. std::unique_ptr<::ps::Customer> general_customer_;
  116. std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
  117. std::unordered_map<int64_t, std::vector<::ps::KVPairs<T>>> lookup_results_;
  118. std::unordered_map<int64_t, std::map<int64_t, ::ps::KVPairs<T>>> gathered_response_;
  119. std::mutex mutex_;
  120. Slicer lookup_slicer_;
  121. Slicer sparse_slicer_;
  122. Slicer broadcast_slicer_;
  123. Slicer round_robin_slicer_;
  124. Slicer worker_init_embedding_slicer_;
  125. Slicer update_embedding_slicer_;
  126. std::unordered_map<int64_t, Callback> lookup_callbacks_;
  127. std::unordered_map<int64_t, Callback> general_callbacks_;
  128. std::unordered_map<int64_t, int64_t> expected_result_count_;
  129. std::unordered_map<::ps::Key, int64_t> key_to_server_id_;
  130. std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;
  131. };
  132. template <typename T>
  133. void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
  134. uint64_t begin = 0;
  135. uint64_t end = 0;
  136. for (int64_t i = 0; i < server_num_; i++) {
  137. int64_t local_row_cnt = Util::LocalShard(row_count, i, server_num_);
  138. if (i == 0) {
  139. end = local_row_cnt - 1;
  140. } else {
  141. begin = end + 1;
  142. end += local_row_cnt;
  143. }
  144. ::ps::Range range(begin, end);
  145. if (embedding_table_ranges_.count(key) == 0) {
  146. embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
  147. MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
  148. }
  149. embedding_table_ranges_[key]->push_back(range);
  150. }
  151. embedding_row_cnt_[key] = row_count;
  152. }
  153. template <typename T>
  154. void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) {
  155. if (server_num_ == 0) {
  156. MS_LOG(EXCEPTION) << "Server number is invalid:0";
  157. }
  158. key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
  159. MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
  160. }
  161. template <typename T>
  162. void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
  163. AddKeyByHashMod(key);
  164. }
  165. template <typename T>
  166. void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  167. const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int64_t cmd,
  168. const Callback &cb, int64_t priority) {
  169. int64_t ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
  170. ::ps::KVPairs<T> kvs;
  171. kvs.keys = keys;
  172. kvs.lens = lookup_ids;
  173. kvs.priority = priority;
  174. expected_result_count_[ts] = 0;
  175. Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
  176. int64_t expect_rt_count = expected_result_count_[ts];
  177. lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
  178. lookup_customer_->WaitRequest(ts);
  179. expected_result_count_.erase(ts);
  180. }
  181. template <typename T>
  182. int64_t WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  183. const ::ps::SArray<int> &lens, const Callback &cb, int64_t priority) {
  184. int64_t ts = obj_->NewRequest(::ps::kServerGroup);
  185. ::ps::KVPairs<T> kvs;
  186. kvs.keys = keys;
  187. kvs.vals = vals;
  188. kvs.lens = lens;
  189. kvs.priority = priority;
  190. Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
  191. return ts;
  192. }
  193. template <typename T>
  194. void WorkerProxy<T>::UpdateEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  195. const ::ps::SArray<T> &vals, const Callback &cb, int64_t priority) {
  196. int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
  197. ::ps::KVPairs<T> kvs;
  198. kvs.keys = keys;
  199. kvs.lens = lookup_ids;
  200. kvs.vals = vals;
  201. kvs.priority = priority;
  202. expected_result_count_[ts] = 0;
  203. Send(general_customer_.get(), ts, true, false, kUpdateEmbeddingsCmd, kvs, update_embedding_slicer_);
  204. if (expected_result_count_[ts] < server_num_) {
  205. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  206. }
  207. general_customer_->WaitRequest(ts);
  208. expected_result_count_.erase(ts);
  209. }
  210. template <typename T>
  211. bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
  212. ::ps::SArray<T> result(1, 0);
  213. PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  214. if (result[0] > 0) {
  215. return true;
  216. } else {
  217. return false;
  218. }
  219. }
  220. template <typename T>
  221. bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
  222. ::ps::SArray<T> result(1, 0);
  223. PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  224. if (result[0] > 0) {
  225. return true;
  226. } else {
  227. return false;
  228. }
  229. }
  230. template <typename T>
  231. void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  232. const ::ps::SArray<int> &lens, int64_t cmd, int64_t priority) {
  233. int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
  234. ::ps::KVPairs<T> kvs;
  235. kvs.keys = keys;
  236. kvs.vals = vals;
  237. kvs.lens = lens;
  238. kvs.priority = priority;
  239. if (embedding_table_ranges_.count(keys[0])) {
  240. if (cmd == kInitWeightsCmd) {
  241. Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
  242. } else {
  243. Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
  244. }
  245. } else {
  246. Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  247. }
  248. if (expected_result_count_[ts] < server_num_) {
  249. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  250. }
  251. general_customer_->WaitRequest(ts);
  252. }
  253. template <typename T>
  254. void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  255. const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
  256. size_t first_dim_size, size_t outer_dim_size) {
  257. int64_t ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
  258. ::ps::KVPairs<T> kvs;
  259. kvs.keys = keys;
  260. kvs.vals = vals;
  261. kvs.lens = lens;
  262. const int64_t cmd = 0;
  263. if (embedding_table_ranges_.count(keys[0])) {
  264. std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
  265. Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
  266. } else {
  267. Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  268. }
  269. if (expected_result_count_[ts] < server_num_) {
  270. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  271. }
  272. general_customer_->WaitRequest(ts);
  273. }
  274. template <typename T>
  275. void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
  276. int64_t cmd, int64_t priority) {
  277. MS_EXCEPTION_IF_NULL(vals);
  278. int64_t ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
  279. ::ps::KVPairs<T> kvs;
  280. kvs.keys = keys;
  281. kvs.priority = priority;
  282. if (embedding_table_ranges_.count(keys[0])) {
  283. Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
  284. } else {
  285. Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
  286. }
  287. if (expected_result_count_[ts] < server_num_) {
  288. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  289. }
  290. general_customer_->WaitRequest(ts);
  291. }
  292. template <typename T>
  293. void WorkerProxy<T>::Finalize() {
  294. int64_t ts = obj_->NewRequest(::ps::kServerGroup);
  295. ::ps::KVPairs<T> kvs;
  296. kvs.keys.push_back(0);
  297. kvs.vals.push_back(0.0f);
  298. Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
  299. obj_->WaitRequest(ts);
  300. ::ps::Finalize(0, true);
  301. }
  302. template <typename T>
  303. template <typename C>
  304. int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  305. C *lookup_result, int64_t cmd, const Callback &cb) {
  306. MS_EXCEPTION_IF_NULL(lookup_result);
  307. int64_t ts = lookup_customer_->NewRequest(::ps::kServerGroup);
  308. const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
  309. mutex_.lock();
  310. auto &kvs = lookup_results_[ts];
  311. mutex_.unlock();
  312. if (lookup_ids.empty()) {
  313. MS_LOG(EXCEPTION) << "Lookup id is empty.";
  314. }
  315. int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
  316. std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map;
  317. for (const auto &s : kvs) {
  318. int64_t offset = 0;
  319. for (size_t i = 0; i < s.keys.size(); i++) {
  320. const Key &key = s.keys[i];
  321. T *addr = s.vals.data() + offset;
  322. offset += single_id_len;
  323. id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len));
  324. MS_EXCEPTION_IF_NULL(id_addr_map[key]);
  325. }
  326. }
  327. T *result_addr = lookup_result->data();
  328. MS_EXCEPTION_IF_NULL(result_addr);
  329. int64_t offset = 0;
  330. size_t dst_size = 0;
  331. size_t src_size = 0;
  332. void *dst_data = nullptr;
  333. void *src_data = nullptr;
  334. for (size_t i = 0; i < lookup_ids.size(); i++) {
  335. if (id_addr_map.count(lookup_ids[i]) == 0) {
  336. offset += single_id_len;
  337. continue;
  338. }
  339. auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
  340. int64_t size = single_id_len * sizeof(T);
  341. dst_size = size;
  342. src_size = size;
  343. dst_data = result_addr + offset;
  344. src_data = pair->first;
  345. MS_EXCEPTION_IF_NULL(dst_data);
  346. MS_EXCEPTION_IF_NULL(src_data);
  347. auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  348. if (ret != 0) {
  349. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  350. return;
  351. }
  352. offset += single_id_len;
  353. }
  354. mutex_.lock();
  355. lookup_results_.erase(ts);
  356. mutex_.unlock();
  357. if (cb) cb();
  358. };
  359. lookup_callbacks_[ts] = callback;
  360. return ts;
  361. }
  362. template <typename T>
  363. int64_t WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals,
  364. ::ps::SArray<int> *lens, int64_t cmd, const Callback &cb) {
  365. int64_t ts = general_customer_->NewRequest(::ps::kServerGroup);
  366. const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
  367. mutex_.lock();
  368. std::map<int64_t, ::ps::KVPairs<T>> server_kvs = gathered_response_[ts];
  369. mutex_.unlock();
  370. vals->clear();
  371. for (auto kvs : server_kvs) {
  372. for (auto val : kvs.second.vals) {
  373. vals->push_back(val);
  374. }
  375. if (lens) {
  376. for (auto len : kvs.second.lens) {
  377. lens->push_back(len);
  378. }
  379. }
  380. }
  381. mutex_.lock();
  382. gathered_response_.erase(ts);
  383. mutex_.unlock();
  384. if (cb) {
  385. cb();
  386. }
  387. };
  388. general_callbacks_[ts] = callback;
  389. return ts;
  390. }
  391. template <typename T>
  392. void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  393. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  394. const std::map<int64_t, int64_t> &attrs) {
  395. MS_EXCEPTION_IF_NULL(sliced);
  396. int32_t *lookup_ids = send.lens.data();
  397. size_t id_size = send.lens.size();
  398. const Key &key = send.keys[0];
  399. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  400. sliced->resize(ranges.size());
  401. for (size_t i = 0; i < ranges.size(); i++) {
  402. const ::ps::Range &range = ranges[i];
  403. const auto &begin = range.begin();
  404. const auto &end = range.end();
  405. std::unordered_set<int64_t> unique_ids;
  406. auto &kvs = sliced->at(i).second;
  407. kvs.keys.push_back(key);
  408. kvs.vals.push_back(0.0f);
  409. for (size_t j = 0; j < id_size; j++) {
  410. auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
  411. // If lookup_id is out of range, like negative number, unique_ids will not contain it.
  412. // Servers always get lookup_ids in its embedding table range.
  413. if (lookup_id >= begin && lookup_id <= end) {
  414. unique_ids.insert(lookup_id);
  415. }
  416. }
  417. for (const auto &lookup_id : unique_ids) {
  418. kvs.keys.push_back(lookup_id);
  419. kvs.vals.push_back(0.0f);
  420. }
  421. if (kvs.keys.size() <= 1) {
  422. sliced->at(i).first = false;
  423. } else {
  424. sliced->at(i).first = true;
  425. expected_result_count_[timestamp] += 1;
  426. }
  427. }
  428. }
  429. template <typename T>
  430. void WorkerProxy<T>::SparseSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  431. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  432. const std::map<int64_t, int64_t> &attrs) {
  433. MS_EXCEPTION_IF_NULL(sliced);
  434. // Init variables
  435. T *data = send.vals.data();
  436. if (attrs.count(0) == 0 || attrs.count(1) == 0 || attrs.count(2) == 0 || attrs.count(3) == 0) {
  437. MS_LOG(EXCEPTION) << "Invalid attrs keys";
  438. }
  439. auto iter = attrs.find(0);
  440. size_t grad_index = static_cast<size_t>(iter->second);
  441. iter = attrs.find(1);
  442. size_t indice_index = static_cast<size_t>(iter->second);
  443. iter = attrs.find(2);
  444. size_t first_dim_size = static_cast<size_t>(iter->second);
  445. iter = attrs.find(3);
  446. size_t outer_dim_size = static_cast<size_t>(iter->second);
  447. int grad_size = send.lens[grad_index];
  448. int indice_size = send.lens[indice_index];
  449. int segment_size = grad_size / indice_size;
  450. int64_t grad_offset = 0;
  451. int64_t indice_offset = 0;
  452. for (size_t i = 0; i < grad_index; i++) {
  453. grad_offset += send.lens[i];
  454. }
  455. for (size_t j = 0; j < indice_index; j++) {
  456. indice_offset += send.lens[j];
  457. }
  458. T *grad_data = data + grad_offset;
  459. int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
  460. // Build the mappings of indice to gradient
  461. std::vector<std::pair<int, T *>> indice_to_grads;
  462. for (int i = 0; i < indice_size; i++) {
  463. int indice = indice_data[i];
  464. T *grad = grad_data + i * segment_size;
  465. indice_to_grads.push_back(std::make_pair(indice, grad));
  466. }
  467. const Key &key = send.keys[0];
  468. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  469. sliced->resize(ranges.size());
  470. // Construct reduced sparse data for each server
  471. for (size_t i = 0; i < ranges.size(); i++) {
  472. const ::ps::Range &range = ranges[i];
  473. const auto &begin = range.begin();
  474. const auto &end = range.end();
  475. auto &kvs = sliced->at(i).second;
  476. kvs.keys = send.keys;
  477. kvs.lens = send.lens;
  478. // Prepare the sparse gradient and indice
  479. std::vector<int> indice_ids;
  480. std::unordered_set<int> distinct_ids;
  481. for (int j = 0; j < indice_size; j++) {
  482. size_t indice = static_cast<size_t>(indice_data[j]);
  483. if (indice >= begin && indice <= end) {
  484. indice_ids.push_back(indice);
  485. distinct_ids.insert(indice);
  486. }
  487. }
  488. size_t indices_size = indice_ids.size();
  489. if (indices_size > 0) {
  490. int slice_segment_size = indices_size * segment_size;
  491. std::vector<T> src_grad_data(slice_segment_size);
  492. std::vector<int> src_indice_data(indices_size);
  493. PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
  494. src_indice_data.data());
  495. // Reduce the sparse gradient and indice
  496. std::vector<T> new_grad(slice_segment_size);
  497. std::vector<int> new_indices(indices_size);
  498. mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
  499. Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
  500. first_dim_size, outer_dim_size, &unique_sparse_grad);
  501. // Update the length of reduce sparse gradient and indice
  502. ::ps::SArray<int> reduced_lens;
  503. reduced_lens.CopyFrom(kvs.lens);
  504. reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
  505. reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
  506. // Build the sparse value to be sent
  507. size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
  508. ::ps::SArray<T> reduced_data(total_size, 0);
  509. BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
  510. unique_sparse_grad.indices_, &reduced_data);
  511. kvs.lens = reduced_lens;
  512. kvs.vals = reduced_data;
  513. }
  514. if (indices_size <= 0) {
  515. ::ps::SArray<T> no_keys;
  516. ::ps::SArray<T> no_vals;
  517. ::ps::SArray<T> no_lens;
  518. no_keys.push_back(key);
  519. no_vals.push_back(-100);
  520. kvs.vals = no_vals;
  521. kvs.lens = no_lens;
  522. }
  523. sliced->at(i).first = true;
  524. expected_result_count_[timestamp] += 1;
  525. }
  526. }
  527. template <typename T>
  528. void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end,
  529. const std::unordered_set<int> &distinct_ids,
  530. const std::vector<std::pair<int, T *>> &indice_to_grads,
  531. const int *all_indice, const size_t segment_size, T *gradient,
  532. int *indices) {
  533. MS_EXCEPTION_IF_NULL(all_indice);
  534. MS_EXCEPTION_IF_NULL(gradient);
  535. MS_EXCEPTION_IF_NULL(indices);
  536. int64_t offset = 0;
  537. int64_t index = 0;
  538. size_t segment_data_size = segment_size * sizeof(T);
  539. size_t dst_size;
  540. size_t src_size;
  541. void *dst_data = nullptr;
  542. void *src_data = nullptr;
  543. for (auto &pair : indice_to_grads) {
  544. if (distinct_ids.count(pair.first) == 0) {
  545. continue;
  546. }
  547. indices[index++] = pair.first;
  548. dst_size = segment_data_size;
  549. src_size = segment_data_size;
  550. dst_data = gradient + offset;
  551. src_data = pair.second;
  552. MS_EXCEPTION_IF_NULL(dst_data);
  553. MS_EXCEPTION_IF_NULL(src_data);
  554. auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
  555. if (ret != 0) {
  556. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  557. return;
  558. }
  559. offset += segment_size;
  560. }
  561. }
  562. template <typename T>
  563. void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
  564. const size_t indice_index, const T *original_data, const T *grads, int *indices,
  565. ::ps::SArray<T> *reduced_data) {
  566. MS_EXCEPTION_IF_NULL(original_data);
  567. MS_EXCEPTION_IF_NULL(grads);
  568. MS_EXCEPTION_IF_NULL(indices);
  569. MS_EXCEPTION_IF_NULL(reduced_data);
  570. int64_t offset = 0;
  571. size_t dst_size = 0;
  572. size_t src_size = 0;
  573. void *dst_data = nullptr;
  574. void *src_data = nullptr;
  575. for (size_t i = 0; i < lengths.size(); i++) {
  576. if (i != grad_index && i != indice_index) {
  577. int data_size = lengths[i] * sizeof(T);
  578. dst_size = data_size;
  579. src_size = data_size;
  580. dst_data = reduced_data->data() + offset;
  581. src_data = const_cast<T *>(original_data) + offset;
  582. MS_EXCEPTION_IF_NULL(dst_data);
  583. MS_EXCEPTION_IF_NULL(src_data);
  584. auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  585. if (ret != 0) {
  586. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  587. return;
  588. }
  589. }
  590. offset += lengths[i];
  591. }
  592. // Fill the reduced gradient
  593. int64_t grad_offset = 0;
  594. for (size_t i = 0; i < grad_index; i++) {
  595. grad_offset += lengths[i];
  596. }
  597. int64_t data_size = lengths[grad_index] * sizeof(T);
  598. dst_size = data_size;
  599. src_size = data_size;
  600. dst_data = reduced_data->data() + grad_offset;
  601. src_data = const_cast<T *>(grads);
  602. MS_EXCEPTION_IF_NULL(dst_data);
  603. MS_EXCEPTION_IF_NULL(src_data);
  604. auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  605. if (ret != 0) {
  606. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  607. return;
  608. }
  609. // Fill the reduced indice
  610. int64_t indice_offset = grad_offset + lengths[grad_index];
  611. data_size = lengths[indice_index] * sizeof(T);
  612. T *indice_data = reduced_data->data() + indice_offset;
  613. dst_size = data_size;
  614. src_size = data_size;
  615. dst_data = indice_data;
  616. src_data = indices;
  617. MS_EXCEPTION_IF_NULL(dst_data);
  618. MS_EXCEPTION_IF_NULL(src_data);
  619. ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  620. if (ret != 0) {
  621. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  622. return;
  623. }
  624. }
  625. template <typename T>
  626. void WorkerProxy<T>::BroadcastSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  627. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  628. const std::map<int64_t, int64_t> &attr) {
  629. MS_EXCEPTION_IF_NULL(sliced);
  630. sliced->resize(server_num_);
  631. for (int64_t i = 0; i < server_num_; i++) {
  632. sliced->at(i).first = true;
  633. sliced->at(i).second = send;
  634. expected_result_count_[timestamp] += 1;
  635. }
  636. }
  637. template <typename T>
  638. void WorkerProxy<T>::RoundRobinSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  639. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  640. const std::map<int64_t, int64_t> &attr) {
  641. MS_EXCEPTION_IF_NULL(sliced);
  642. sliced->resize(server_num_);
  643. auto keys = send.keys;
  644. auto vals = send.vals;
  645. auto lens = send.lens;
  646. int64_t server_id, len;
  647. ::ps::Key param_key;
  648. for (size_t i = 0; i < keys.size(); i++) {
  649. param_key = keys[i];
  650. server_id = key_to_server_id_[param_key];
  651. if (!sliced->at(server_id).first) {
  652. sliced->at(server_id).first = true;
  653. expected_result_count_[timestamp] += 1;
  654. }
  655. ::ps::KVPairs<T> &server_kv_pairs = sliced->at(server_id).second;
  656. server_kv_pairs.keys.push_back(param_key);
  657. if (vals.empty()) {
  658. continue;
  659. }
  660. len = lens[i];
  661. int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
  662. auto val_begin = vals.begin() + offset;
  663. auto val_end = val_begin + len;
  664. for (auto iter = val_begin; iter != val_end; iter++) {
  665. server_kv_pairs.vals.push_back(*iter);
  666. }
  667. server_kv_pairs.lens.push_back(len);
  668. }
  669. }
  670. template <typename T>
  671. void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int64_t timestamp, const ::ps::KVPairs<T> &send,
  672. const std::vector<::ps::Range> &,
  673. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  674. const std::map<int64_t, int64_t> &attrs) {
  675. MS_EXCEPTION_IF_NULL(sliced);
  676. sliced->resize(server_num_);
  677. auto keys = send.keys;
  678. auto vals = send.vals;
  679. auto lens = send.lens;
  680. size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  681. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
  682. for (size_t i = 0; i < ranges.size(); i++) {
  683. size_t offset_begin = ranges[i].begin() * col_cnt;
  684. size_t offset_end = (ranges[i].end() + 1) * col_cnt;
  685. ::ps::KVPairs<T> kvs;
  686. kvs.keys = keys;
  687. kvs.vals = vals.segment(offset_begin, offset_end);
  688. kvs.lens.push_back(offset_end - offset_begin);
  689. sliced->at(i).first = true;
  690. sliced->at(i).second = kvs;
  691. }
  692. }
  693. template <typename T>
  694. void WorkerProxy<T>::UpdateEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
  695. const std::vector<::ps::Range> &,
  696. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  697. const std::map<int64_t, int64_t> &attrs) {
  698. MS_EXCEPTION_IF_NULL(sliced);
  699. T *embedding_vals = send.vals.data();
  700. int *lookup_ids = send.lens.data();
  701. size_t val_size = send.vals.size();
  702. size_t id_size = send.lens.size();
  703. size_t embedding_dim = val_size / id_size;
  704. const Key &key = send.keys[0];
  705. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  706. sliced->resize(ranges.size());
  707. for (size_t i = 0; i < ranges.size(); i++) {
  708. const ::ps::Range &range = ranges[i];
  709. const auto &begin = range.begin();
  710. const auto &end = range.end();
  711. auto &kvs = sliced->at(i).second;
  712. kvs.keys.push_back(key);
  713. for (size_t j = 0; j < id_size; j++) {
  714. auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
  715. if (lookup_id >= begin && lookup_id <= end) {
  716. kvs.keys.push_back(lookup_id);
  717. for (size_t k = 0; k < embedding_dim; k++) {
  718. kvs.vals.push_back(embedding_vals[j * embedding_dim + k]);
  719. }
  720. }
  721. }
  722. if (kvs.keys.size() <= 1) {
  723. sliced->at(i).first = false;
  724. } else {
  725. sliced->at(i).first = true;
  726. expected_result_count_[timestamp] += 1;
  727. }
  728. }
  729. }
  730. template <typename T>
  731. void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
  732. int64_t ts = msg.meta.timestamp;
  733. if (msg.meta.pull) {
  734. CHECK_GE(msg.data.size(), (size_t)2);
  735. ::ps::KVPairs<T> kvs;
  736. kvs.keys = msg.data[0];
  737. kvs.vals = msg.data[1];
  738. if (msg.data.size() > (size_t)2) {
  739. kvs.lens = msg.data[2];
  740. }
  741. mutex_.lock();
  742. lookup_results_[ts].push_back(kvs);
  743. mutex_.unlock();
  744. }
  745. if (lookup_customer_->NumResponse(ts) + 1 == server_num_) {
  746. const auto &cb = lookup_callbacks_[ts];
  747. cb();
  748. lookup_callbacks_.erase(ts);
  749. }
  750. }
  751. template <typename T>
  752. void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
  753. int64_t ts = msg.meta.timestamp;
  754. if (msg.meta.pull) {
  755. CHECK_GE(msg.data.size(), (size_t)2);
  756. ::ps::KVPairs<T> kvs;
  757. kvs.keys = msg.data[0];
  758. kvs.vals = msg.data[1];
  759. if (msg.data.size() > (size_t)2) {
  760. kvs.lens = msg.data[2];
  761. }
  762. mutex_.lock();
  763. int rsp_server_rank = ::ps::Postoffice::Get()->IDtoRank(msg.meta.sender);
  764. gathered_response_[ts][rsp_server_rank] = kvs;
  765. mutex_.unlock();
  766. if (general_customer_->NumResponse(ts) + 1 == server_num_) {
  767. const auto &cb = general_callbacks_[ts];
  768. cb();
  769. general_callbacks_.erase(ts);
  770. }
  771. }
  772. }
  773. template <typename T>
  774. void WorkerProxy<T>::Send(::ps::Customer *customer, int64_t timestamp, bool push, bool pull, int64_t cmd,
  775. const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int64_t, int64_t> attrs) {
  776. MS_EXCEPTION_IF_NULL(customer);
  777. SlicedKVs sliced;
  778. slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs);
  779. for (size_t i = 0; i < sliced.size(); i++) {
  780. const auto &s = sliced[i];
  781. if (!s.first) continue;
  782. ::ps::Message msg;
  783. msg.meta.app_id = customer->app_id();
  784. msg.meta.customer_id = customer->customer_id();
  785. msg.meta.request = true;
  786. msg.meta.push = push;
  787. msg.meta.pull = pull;
  788. msg.meta.head = cmd;
  789. msg.meta.timestamp = timestamp;
  790. msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
  791. msg.meta.priority = kvs.priority;
  792. const auto &kvs = s.second;
  793. if (kvs.keys.size()) {
  794. msg.AddData(kvs.keys);
  795. msg.AddData(kvs.vals);
  796. if (kvs.lens.size()) {
  797. msg.AddData(kvs.lens);
  798. }
  799. }
  800. ::ps::Postoffice::Get()->van()->Send(msg);
  801. }
  802. }
  803. } // namespace ps
  804. } // namespace mindspore
  805. #endif // MINDSPORE_CCSRC_PS_WORKER_PROXY_H_