You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_
  17. #define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_
  18. #include <unordered_map>
  19. #include <algorithm>
  20. #include <utility>
  21. #include <memory>
  22. #include <vector>
  23. #include <unordered_set>
  24. #include "ps/ps.h"
  25. #include "frontend/parallel/ps/util.h"
  26. #include "backend/kernel_compiler/common_utils.h"
  27. namespace mindspore {
  28. namespace parallel {
  29. namespace ps {
// Worker-side proxy over the ps-lite KVWorker. It owns the slicing
// strategies that decide which server receives which part of a request:
// embedding tables are range-partitioned across servers (lookup / broadcast /
// sparse slicers), while ordinary parameters are placed on a single server
// chosen by hash-mod (round-robin slicer).
template <typename T>
class WorkerProxy : public ::ps::KVWorker<T> {
 public:
  using Worker = ::ps::KVWorker<T>;
  // Invoked after all expected responses of a request have been gathered.
  using Callback = std::function<void()>;
  // One entry per server: (send-to-this-server?, payload for that server).
  using SlicedKVs = std::vector<std::pair<bool, ::ps::KVPairs<T>>>;
  // Splits one request into per-server slices. `attrs` carries
  // slicer-specific integer parameters (used by the sparse slicer).
  using Slicer = std::function<void(int ts, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &ranges,
                                    SlicedKVs *sliced, std::map<int, int> &attrs)>;
  using ::ps::SimpleApp::obj_;
  // Creates dedicated customers for lookup and general traffic and wires up
  // the slicing strategies as bound member functions.
  explicit WorkerProxy(int app_id, int customer_id, int lookup_customer_id, int general_customer_id)
      : Worker(app_id, customer_id) {
    server_num_ = ::ps::NumServers();
    Util::SetRankId(::ps::MyRank());
    using std::placeholders::_1;
    using std::placeholders::_2;
    using std::placeholders::_3;
    using std::placeholders::_4;
    using std::placeholders::_5;
    // Separate customers so lookup replies and general replies are counted
    // and dispatched independently.
    lookup_customer_ = std::unique_ptr<::ps::Customer>(
      new ::ps::Customer(app_id, lookup_customer_id, std::bind(&WorkerProxy<T>::ProcessLookupResult, this, _1)));
    general_customer_ = std::unique_ptr<::ps::Customer>(
      new ::ps::Customer(app_id, general_customer_id, std::bind(&WorkerProxy<T>::ProcessResponse, this, _1)));
    lookup_slicer_ = std::bind(&WorkerProxy<T>::LookupIdSlicer, this, _1, _2, _3, _4, _5);
    sparse_slicer_ = std::bind(&WorkerProxy<T>::SparseSlicer, this, _1, _2, _3, _4, _5);
    broadcast_slicer_ = std::bind(&WorkerProxy<T>::BroadcastSlicer, this, _1, _2, _3, _4, _5);
    round_robin_slicer_ = std::bind(&WorkerProxy<T>::RoundRobinSlicer, this, _1, _2, _3, _4, _5);
    worker_init_embedding_slicer_ = std::bind(&WorkerProxy<T>::WorkerInitEmbeddingSlicer, this, _1, _2, _3, _4, _5);
  }
  ~WorkerProxy() override = default;
  // Records the per-server row ranges of the embedding table identified by key.
  void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count);
  // Assigns a (non-embedding) key to a server via hash-mod placement.
  void AddKeyToServerId(const ::ps::Key &key);
  // Blocking embedding lookup: gathers the rows for lookup_ids into *outs.
  void EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                       const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd = 0, const Callback &cb = nullptr,
                       int priority = 0);
  // Broadcasts initial embedding values to all servers; returns the request
  // timestamp the caller can wait on.
  int InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
                         const ::ps::SArray<int> &lens = {}, const Callback &cb = nullptr, int priority = 0);
  // Polls the servers whether the key is ready for push / pull.
  bool IsReadyForPush(const Key &key);
  bool IsReadyForPull(const Key &key);
  // Blocking push of dense data (weights/gradients).
  void PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens = {},
                int cmd = 0, int priority = 0);
  // Blocking push of sparse gradients; grad_index/indice_index locate the
  // gradient and indice segments inside lens.
  void PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals, const ::ps::SArray<int> &lens,
                      size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size);
  // Blocking pull; results are copied into *vals (and *lens when non-null).
  void PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens = nullptr,
                int cmd = 0, int priority = 0);
  // Sends the finalize command to every server and shuts down ps-lite.
  void Finalize();

 private:
  // Registers the callback that assembles lookup replies into *vals.
  template <typename C>
  int AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids, C *vals, int cmd,
                  const Callback &cb);
  // Registers the callback that copies a gathered response into vals/lens.
  int AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens, int cmd,
                      const Callback &cb);
  // Slicers: each fills *sliced with one (send?, payload) entry per server
  // and bumps expected_result_count_[timestamp] per contacted server.
  void LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                      std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs);
  void SparseSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                    std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs);
  void BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                       std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs);
  void RoundRobinSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                        std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs);
  void WorkerInitEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
                                 std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs);
  // Customer receive handlers.
  void ProcessLookupResult(const ::ps::Message &msg);
  void ProcessResponse(const ::ps::Message &msg);
  // Slices kvs with `slicer` and ships one message per contacted server.
  void Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd, const ::ps::KVPairs<T> &kvs,
            const Slicer &slicer, std::map<int, int> attrs = {});
  void AddKeyByHashMod(const ::ps::Key &key);
  void PrepareSparseGradient(const size_t begin, const size_t end, const std::vector<int> &indice_ids,
                             const std::vector<std::pair<int, T *>> &indice_to_grad, const int *all_indice,
                             const size_t segment_size, T *gradient, int *indice);
  void ReduceSparseGradient(T *gradients, int *indices, const size_t indices_size, size_t segment_size,
                            const size_t first_dim_size, const size_t outer_dim_size,
                            mindspore::kernel::SparseGradient &unique_sparse_grad);
  void BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index, const size_t indice_index,
                        const T *original_data, const T *grads, int *indices, ::ps::SArray<T> &reduced_data);

  int server_num_;                                    // total number of servers
  std::unique_ptr<::ps::Customer> lookup_customer_;   // embedding-lookup traffic
  std::unique_ptr<::ps::Customer> general_customer_;  // all other traffic
  // key -> per-server row ranges of each registered embedding table
  std::unordered_map<::ps::Key, std::shared_ptr<std::vector<::ps::Range>>> embedding_table_ranges_;
  // timestamp -> per-server lookup replies, gathered by ProcessLookupResult
  std::unordered_map<int, std::vector<::ps::KVPairs<T>>> lookup_results_;
  // timestamp -> concatenated pull response, gathered by ProcessResponse
  std::unordered_map<int, ::ps::KVPairs<T>> gathered_response_;
  std::mutex mutex_;  // guards lookup_results_ and gathered_response_
  Slicer lookup_slicer_;
  Slicer sparse_slicer_;
  Slicer broadcast_slicer_;
  Slicer round_robin_slicer_;
  Slicer worker_init_embedding_slicer_;
  std::unordered_map<int, Callback> lookup_callbacks_;
  std::unordered_map<int, Callback> general_callbacks_;
  // timestamp -> number of servers expected to answer that request
  std::unordered_map<int, int> expected_result_count_;
  std::unordered_map<::ps::Key, int> key_to_server_id_;
  std::unordered_map<::ps::Key, size_t> embedding_row_cnt_;
};
  122. template <typename T>
  123. void WorkerProxy<T>::AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count) {
  124. uint64_t begin = 0;
  125. uint64_t end = 0;
  126. for (int i = 0; i < server_num_; i++) {
  127. int local_row_cnt = Util::LocalShard(row_count, i, server_num_);
  128. if (i == 0) {
  129. end = local_row_cnt - 1;
  130. } else {
  131. begin = end + 1;
  132. end += local_row_cnt;
  133. }
  134. ::ps::Range range(begin, end);
  135. if (embedding_table_ranges_.count(key) == 0) {
  136. embedding_table_ranges_[key] = std::make_shared<std::vector<::ps::Range>>();
  137. }
  138. embedding_table_ranges_[key]->push_back(range);
  139. }
  140. embedding_row_cnt_[key] = row_count;
  141. }
  142. template <typename T>
  143. void WorkerProxy<T>::AddKeyByHashMod(const ::ps::Key &key) {
  144. if (server_num_ == 0) {
  145. MS_LOG(EXCEPTION) << "Server number is invalid:0";
  146. }
  147. key_to_server_id_[key] = static_cast<int>(key % server_num_);
  148. MS_LOG(INFO) << "The server id of key " << key << " is " << key_to_server_id_[key];
  149. }
  150. template <typename T>
  151. void WorkerProxy<T>::AddKeyToServerId(const ::ps::Key &key) {
  152. AddKeyByHashMod(key);
  153. }
  154. template <typename T>
  155. void WorkerProxy<T>::EmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
  156. const ::ps::SArray<int> &lens, ::ps::SArray<T> *outs, int cmd, const Callback &cb,
  157. int priority) {
  158. int ts = AddLookupCB(keys, lookup_ids, outs, cmd, cb);
  159. ::ps::KVPairs<T> kvs;
  160. kvs.keys = keys;
  161. kvs.lens = lookup_ids;
  162. kvs.priority = priority;
  163. expected_result_count_[ts] = 0;
  164. Send(lookup_customer_.get(), ts, true, true, cmd, kvs, lookup_slicer_);
  165. int expect_rt_count = expected_result_count_[ts];
  166. lookup_customer_->AddResponse(ts, server_num_ - expect_rt_count);
  167. lookup_customer_->WaitRequest(ts);
  168. expected_result_count_.erase(ts);
  169. }
  170. template <typename T>
  171. int WorkerProxy<T>::InitEmbeddingTable(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  172. const ::ps::SArray<int> &lens, const Callback &cb, int priority) {
  173. int ts = obj_->NewRequest(::ps::kServerGroup);
  174. ::ps::KVPairs<T> kvs;
  175. kvs.keys = keys;
  176. kvs.vals = vals;
  177. kvs.lens = lens;
  178. kvs.priority = priority;
  179. Send(obj_, ts, true, false, kInitEmbeddingsCmd, kvs, broadcast_slicer_);
  180. return ts;
  181. }
  182. template <typename T>
  183. bool WorkerProxy<T>::IsReadyForPush(const Key &key) {
  184. ::ps::SArray<T> result(1, 0);
  185. PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  186. if (result[0] > 0) {
  187. return true;
  188. } else {
  189. return false;
  190. }
  191. }
  192. template <typename T>
  193. bool WorkerProxy<T>::IsReadyForPull(const Key &key) {
  194. ::ps::SArray<T> result(1, 0);
  195. PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  196. if (result[0] > 0) {
  197. return true;
  198. } else {
  199. return false;
  200. }
  201. }
  202. template <typename T>
  203. void WorkerProxy<T>::PushData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  204. const ::ps::SArray<int> &lens, int cmd, int priority) {
  205. int ts = AddGeneralRspCB(keys, nullptr, nullptr, cmd, nullptr);
  206. ::ps::KVPairs<T> kvs;
  207. kvs.keys = keys;
  208. kvs.vals = vals;
  209. kvs.lens = lens;
  210. kvs.priority = priority;
  211. if (embedding_table_ranges_.count(keys[0])) {
  212. if (cmd == kInitWeightsCmd) {
  213. Send(general_customer_.get(), ts, true, false, cmd, kvs, worker_init_embedding_slicer_);
  214. } else {
  215. Send(general_customer_.get(), ts, true, false, cmd, kvs, broadcast_slicer_);
  216. }
  217. } else {
  218. Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  219. }
  220. if (expected_result_count_[ts] < server_num_) {
  221. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  222. }
  223. general_customer_->WaitRequest(ts);
  224. }
  225. template <typename T>
  226. void WorkerProxy<T>::PushSparseData(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<T> &vals,
  227. const ::ps::SArray<int> &lens, size_t grad_index, size_t indice_index,
  228. size_t first_dim_size, size_t outer_dim_size) {
  229. int ts = AddGeneralRspCB(keys, nullptr, nullptr, 0, nullptr);
  230. ::ps::KVPairs<T> kvs;
  231. kvs.keys = keys;
  232. kvs.vals = vals;
  233. kvs.lens = lens;
  234. int cmd = 0;
  235. if (embedding_table_ranges_.count(keys[0])) {
  236. std::map<int, int> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
  237. Send(general_customer_.get(), ts, true, false, cmd, kvs, sparse_slicer_, attrs);
  238. } else {
  239. Send(general_customer_.get(), ts, true, false, cmd, kvs, round_robin_slicer_);
  240. }
  241. if (expected_result_count_[ts] < server_num_) {
  242. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  243. }
  244. general_customer_->WaitRequest(ts);
  245. }
  246. template <typename T>
  247. void WorkerProxy<T>::PullData(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
  248. int cmd, int priority) {
  249. int ts = AddGeneralRspCB(keys, vals, lens, cmd, nullptr);
  250. ::ps::KVPairs<T> kvs;
  251. kvs.keys = keys;
  252. kvs.priority = priority;
  253. if (embedding_table_ranges_.count(keys[0])) {
  254. Send(general_customer_.get(), ts, false, true, cmd, kvs, broadcast_slicer_);
  255. } else {
  256. Send(general_customer_.get(), ts, false, true, cmd, kvs, round_robin_slicer_);
  257. }
  258. if (expected_result_count_[ts] < server_num_) {
  259. general_customer_->AddResponse(ts, server_num_ - expected_result_count_[ts]);
  260. }
  261. general_customer_->WaitRequest(ts);
  262. }
  263. template <typename T>
  264. void WorkerProxy<T>::Finalize() {
  265. int ts = obj_->NewRequest(::ps::kServerGroup);
  266. ::ps::KVPairs<T> kvs;
  267. kvs.keys.push_back(0);
  268. kvs.vals.push_back(0.0f);
  269. Send(obj_, ts, true, false, kFinalizeCmd, kvs, broadcast_slicer_);
  270. obj_->WaitRequest(ts);
  271. ::ps::Finalize(0, true);
  272. }
// Allocates a request timestamp on the lookup customer and registers the
// callback that, once all expected server replies for `ts` have arrived,
// assembles them into the caller-provided `lookup_result` buffer in the
// order the ids were requested. Returns the timestamp.
template <typename T>
template <typename C>
int WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray<int> &lookup_ids,
                                C *lookup_result, int cmd, const Callback &cb) {
  int ts = lookup_customer_->NewRequest(::ps::kServerGroup);
  const auto &callback = [this, ts, keys, lookup_ids, lookup_result, cb]() mutable {
    mutex_.lock();
    auto &kvs = lookup_results_[ts];
    mutex_.unlock();
    // NOTE(review): `kvs` is a reference into lookup_results_ used after the
    // lock is released; this relies on no concurrent insert/erase touching
    // that entry until the erase below — confirm the threading model.
    // Map each returned key to the (address, row length) of its slice.
    std::unordered_map<Key, std::shared_ptr<std::pair<T *, int>>> id_addr_map;
    for (const auto &s : kvs) {
      int offset = 0;
      // Every key in one reply carries the same number of values.
      int len = s.vals.size() / s.keys.size();
      for (size_t i = 0; i < s.keys.size(); i++) {
        const Key &key = s.keys[i];
        T *addr = s.vals.data() + offset;
        offset += len;
        id_addr_map[key] = std::make_shared<std::pair<T *, int>>(std::make_pair(addr, len));
      }
    }
    // Copy the rows into the output buffer in requested-id order
    // (duplicated ids get their row copied once per occurrence).
    T *result_addr = lookup_result->data();
    int offset = 0;
    for (size_t i = 0; i < lookup_ids.size(); i++) {
      auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])];
      int size = pair->second * sizeof(T);
      auto ret = memcpy_s(result_addr + offset, size, pair->first, size);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
      }
      offset += pair->second;
    }
    mutex_.lock();
    lookup_results_.erase(ts);
    mutex_.unlock();
    if (cb) cb();
  };
  lookup_callbacks_[ts] = callback;
  return ts;
}
  312. template <typename T>
  313. int WorkerProxy<T>::AddGeneralRspCB(const ::ps::SArray<::ps::Key> &keys, ::ps::SArray<T> *vals, ::ps::SArray<int> *lens,
  314. int cmd, const Callback &cb) {
  315. int ts = general_customer_->NewRequest(::ps::kServerGroup);
  316. const auto &callback = [this, ts, keys, vals, lens, cb]() mutable {
  317. mutex_.lock();
  318. auto &kvs = gathered_response_[ts];
  319. mutex_.unlock();
  320. *vals = kvs.vals;
  321. if (lens) {
  322. *lens = kvs.lens;
  323. }
  324. mutex_.lock();
  325. gathered_response_.erase(ts);
  326. mutex_.unlock();
  327. if (cb) {
  328. cb();
  329. }
  330. };
  331. general_callbacks_[ts] = callback;
  332. return ts;
  333. }
  334. template <typename T>
  335. void WorkerProxy<T>::LookupIdSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  336. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs) {
  337. int *lookup_ids = send.lens.data();
  338. size_t id_size = send.lens.size();
  339. const Key &key = send.keys[0];
  340. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  341. sliced->resize(ranges.size());
  342. for (size_t i = 0; i < ranges.size(); i++) {
  343. const ::ps::Range &range = ranges[i];
  344. const auto &begin = range.begin();
  345. const auto &end = range.end();
  346. std::unordered_set<int> unique_ids;
  347. auto &kvs = sliced->at(i).second;
  348. kvs.keys.push_back(key);
  349. kvs.vals.push_back(0.0f);
  350. for (size_t j = 0; j < id_size; j++) {
  351. auto lookup_id = static_cast<uint64_t>(lookup_ids[j]);
  352. if (lookup_id >= begin && lookup_id <= end) {
  353. unique_ids.insert(lookup_id);
  354. }
  355. }
  356. for (const auto &lookup_id : unique_ids) {
  357. kvs.keys.push_back(lookup_id);
  358. kvs.vals.push_back(0.0f);
  359. }
  360. if (kvs.keys.size() <= 1) {
  361. sliced->at(i).first = false;
  362. } else {
  363. sliced->at(i).first = true;
  364. expected_result_count_[timestamp] += 1;
  365. }
  366. }
  367. }
  368. template <typename T>
  369. void WorkerProxy<T>::SparseSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  370. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attrs) {
  371. // Init variables
  372. T *data = send.vals.data();
  373. size_t grad_index = static_cast<size_t>(attrs[0]);
  374. size_t indice_index = static_cast<size_t>(attrs[1]);
  375. size_t first_dim_size = static_cast<size_t>(attrs[2]);
  376. size_t outer_dim_size = static_cast<size_t>(attrs[3]);
  377. int grad_size = send.lens[grad_index];
  378. int indice_size = send.lens[indice_index];
  379. int segment_size = grad_size / indice_size;
  380. int grad_offset = 0;
  381. int indice_offset = 0;
  382. for (size_t i = 0; i < grad_index; i++) {
  383. grad_offset += send.lens[i];
  384. }
  385. for (size_t j = 0; j < indice_index; j++) {
  386. indice_offset += send.lens[j];
  387. }
  388. T *grad_data = data + grad_offset;
  389. int *indice_data = reinterpret_cast<int *>(data) + indice_offset;
  390. // Build the mappings of indice to gradient
  391. std::vector<std::pair<int, T *>> indice_to_grads;
  392. for (int i = 0; i < indice_size; i++) {
  393. int indice = indice_data[i];
  394. T *grad = grad_data + i * segment_size;
  395. indice_to_grads.push_back(std::make_pair(indice, grad));
  396. }
  397. const Key &key = send.keys[0];
  398. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[key]);
  399. sliced->resize(ranges.size());
  400. // Construct reduced sparse data for each server
  401. for (size_t i = 0; i < ranges.size(); i++) {
  402. const ::ps::Range &range = ranges[i];
  403. const auto &begin = range.begin();
  404. const auto &end = range.end();
  405. auto &kvs = sliced->at(i).second;
  406. kvs.keys = send.keys;
  407. kvs.lens = send.lens;
  408. // Prepare the sparse gradient and indice
  409. std::vector<int> indice_ids;
  410. for (int j = 0; j < indice_size; j++) {
  411. size_t indice = static_cast<size_t>(indice_data[j]);
  412. if (indice >= begin && indice <= end) {
  413. indice_ids.push_back(indice);
  414. }
  415. }
  416. size_t indices_size = indice_ids.size();
  417. int slice_segment_size = indices_size * segment_size;
  418. T *src_grad_data = new T[slice_segment_size];
  419. int *src_indice_data = new int[indices_size];
  420. PrepareSparseGradient(begin, end, indice_ids, indice_to_grads, indice_data, segment_size, src_grad_data,
  421. src_indice_data);
  422. // Reduce the sparse gradient and indice
  423. T *new_grad = new T[slice_segment_size];
  424. int *new_indices = new int[indices_size];
  425. mindspore::kernel::SparseGradient unique_sparse_grad({new_grad, new_indices, indices_size});
  426. ReduceSparseGradient(src_grad_data, src_indice_data, indices_size, segment_size, first_dim_size, outer_dim_size,
  427. unique_sparse_grad);
  428. // Update the length of reduce sparse gradient and indice
  429. kvs.lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
  430. kvs.lens[indice_index] = unique_sparse_grad.indices_size_;
  431. // Build the sparse value to be sent
  432. size_t total_size = 0;
  433. for (auto size : kvs.lens) {
  434. total_size += size;
  435. }
  436. ::ps::SArray<T> reduced_data(total_size, 0);
  437. BuildSparseValue(kvs.lens, grad_index, indice_index, data, unique_sparse_grad.value_, unique_sparse_grad.indices_,
  438. reduced_data);
  439. kvs.vals = reduced_data;
  440. if (indices_size <= 0) {
  441. sliced->at(i).first = false;
  442. } else {
  443. sliced->at(i).first = true;
  444. expected_result_count_[timestamp] += 1;
  445. }
  446. }
  447. }
  448. template <typename T>
  449. void WorkerProxy<T>::PrepareSparseGradient(const size_t begin, const size_t end, const std::vector<int> &indice_ids,
  450. const std::vector<std::pair<int, T *>> &indice_to_grads,
  451. const int *all_indice, const size_t segment_size, T *gradient,
  452. int *indices) {
  453. int offset = 0;
  454. int index = 0;
  455. size_t segment_data_size = segment_size * sizeof(T);
  456. for (auto &pair : indice_to_grads) {
  457. indices[index++] = pair.first;
  458. auto ret = memcpy_s(gradient + offset, segment_data_size, pair.second, segment_data_size);
  459. if (ret != 0) {
  460. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  461. }
  462. offset += segment_size;
  463. }
  464. }
  465. template <typename T>
  466. void WorkerProxy<T>::ReduceSparseGradient(T *gradients, int *indices, const size_t indices_size, size_t segment_size,
  467. const size_t first_dim_size, const size_t outer_dim_size,
  468. mindspore::kernel::SparseGradient &unique_sparse_grad) {
  469. size_t slice_segment_size = indices_size * segment_size;
  470. auto workspace_grad = new T[slice_segment_size];
  471. auto workspace_indices = new int[indices_size];
  472. MS_EXCEPTION_IF_NULL(gradients);
  473. MS_EXCEPTION_IF_NULL(indices);
  474. MS_EXCEPTION_IF_NULL(workspace_grad);
  475. MS_EXCEPTION_IF_NULL(workspace_indices);
  476. mindspore::kernel::SparseGradient workspace_sparse_grad({workspace_grad, workspace_indices, indices_size});
  477. mindspore::kernel::SparseGradient input_sparse_grad({gradients, indices, indices_size});
  478. ReduceSparseGradientParam param;
  479. param.input_grad_ = &input_sparse_grad;
  480. param.workspace_grad_ = &workspace_sparse_grad;
  481. param.output_grad_ = &unique_sparse_grad;
  482. param.max_index_ = first_dim_size;
  483. param.value_stride_ = outer_dim_size;
  484. BucketReduceSparseGradient(param);
  485. }
  486. template <typename T>
  487. void WorkerProxy<T>::BuildSparseValue(const ::ps::SArray<int> &lengths, const size_t grad_index,
  488. const size_t indice_index, const T *original_data, const T *grads, int *indices,
  489. ::ps::SArray<T> &reduced_data) {
  490. int offset = 0;
  491. for (size_t i = 0; i < lengths.size(); i++) {
  492. if (i != grad_index && i != indice_index) {
  493. int data_size = lengths[i] * sizeof(T);
  494. auto ret = memcpy_s(reduced_data.data() + offset, data_size, original_data + offset, data_size);
  495. if (ret != 0) {
  496. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  497. }
  498. }
  499. offset += lengths[i];
  500. }
  501. // Fill the reduced gradient
  502. int grad_offset = 0;
  503. for (size_t i = 0; i < grad_index; i++) {
  504. grad_offset += lengths[i];
  505. }
  506. int data_size = lengths[grad_index] * sizeof(T);
  507. auto ret = memcpy_s(reduced_data.data() + grad_offset, data_size, grads, data_size);
  508. if (ret != 0) {
  509. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  510. }
  511. // Fill the reduced indice
  512. data_size = lengths[indice_index] * sizeof(T);
  513. int indice_offset = grad_offset + data_size;
  514. T *indice_data = reduced_data.data() + indice_offset;
  515. T *convert = new T[lengths[indice_index]];
  516. for (int i = 0; i < lengths[indice_index]; i++) {
  517. convert[i] = static_cast<T>(indices[i]);
  518. }
  519. ret = memcpy_s(indice_data, data_size, convert, data_size);
  520. if (ret != 0) {
  521. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  522. }
  523. delete[] convert;
  524. }
  525. template <typename T>
  526. void WorkerProxy<T>::BroadcastSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  527. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced, std::map<int, int> &attr) {
  528. sliced->resize(server_num_);
  529. for (int i = 0; i < server_num_; i++) {
  530. sliced->at(i).first = true;
  531. sliced->at(i).second = send;
  532. expected_result_count_[timestamp] += 1;
  533. }
  534. }
  535. template <typename T>
  536. void WorkerProxy<T>::RoundRobinSlicer(int timestamp, const ::ps::KVPairs<T> &send, const std::vector<::ps::Range> &,
  537. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  538. std::map<int, int> &attr) {
  539. sliced->resize(server_num_);
  540. auto keys = send.keys;
  541. auto vals = send.vals;
  542. auto lens = send.lens;
  543. int server_id, len;
  544. ::ps::Key param_key;
  545. for (size_t i = 0; i < keys.size(); i++) {
  546. param_key = keys[i];
  547. server_id = key_to_server_id_[param_key];
  548. if (!sliced->at(server_id).first) {
  549. sliced->at(server_id).first = true;
  550. expected_result_count_[timestamp] += 1;
  551. }
  552. ::ps::KVPairs<T> &server_kv_pairs = sliced->at(server_id).second;
  553. server_kv_pairs.keys.push_back(param_key);
  554. if (vals.empty()) {
  555. continue;
  556. }
  557. len = lens[i];
  558. int offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
  559. auto val_begin = vals.begin() + offset;
  560. auto val_end = val_begin + len;
  561. for (auto iter = val_begin; iter != val_end; iter++) {
  562. server_kv_pairs.vals.push_back(*iter);
  563. }
  564. server_kv_pairs.lens.push_back(len);
  565. }
  566. }
  567. template <typename T>
  568. void WorkerProxy<T>::WorkerInitEmbeddingSlicer(int timestamp, const ::ps::KVPairs<T> &send,
  569. const std::vector<::ps::Range> &,
  570. std::vector<std::pair<bool, ::ps::KVPairs<T>>> *sliced,
  571. std::map<int, int> &attrs) {
  572. sliced->resize(server_num_);
  573. auto keys = send.keys;
  574. auto vals = send.vals;
  575. auto lens = send.lens;
  576. size_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  577. const std::vector<::ps::Range> &ranges = *(embedding_table_ranges_[keys[0]]);
  578. for (size_t i = 0; i < ranges.size(); i++) {
  579. size_t offset_begin = ranges[i].begin() * col_cnt;
  580. size_t offset_end = (ranges[i].end() + 1) * col_cnt;
  581. ::ps::KVPairs<T> kvs;
  582. kvs.keys = keys;
  583. kvs.vals = vals.segment(offset_begin, offset_end);
  584. kvs.lens.push_back(offset_end - offset_begin);
  585. sliced->at(i).first = true;
  586. sliced->at(i).second = kvs;
  587. }
  588. }
// Receive handler of the lookup customer. Records each pull reply under its
// timestamp; once the last expected reply arrives, runs and removes the
// registered lookup callback (which assembles the final result).
template <typename T>
void WorkerProxy<T>::ProcessLookupResult(const ::ps::Message &msg) {
  int ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> kvs;
    kvs.keys = msg.data[0];
    kvs.vals = msg.data[1];
    if (msg.data.size() > (size_t)2) {
      kvs.lens = msg.data[2];
    }
    mutex_.lock();
    lookup_results_[ts].push_back(kvs);
    mutex_.unlock();
  }
  // NOTE(review): assumes Customer::NumResponse does not yet include the
  // message being processed, hence the `- 1` — confirm against the ps-lite
  // Customer implementation.
  if (lookup_customer_->NumResponse(ts) == expected_result_count_[ts] - 1) {
    const auto &cb = lookup_callbacks_[ts];
    cb();
    lookup_callbacks_.erase(ts);
  }
}
// Receive handler of the general customer. For pull replies, appends the
// server's keys/vals/lens into the gathered response for the timestamp;
// when every server has answered, runs and removes the general callback.
template <typename T>
void WorkerProxy<T>::ProcessResponse(const ::ps::Message &msg) {
  int ts = msg.meta.timestamp;
  if (msg.meta.pull) {
    CHECK_GE(msg.data.size(), (size_t)2);
    ::ps::KVPairs<T> kvs;
    kvs.keys = msg.data[0];
    kvs.vals = msg.data[1];
    if (msg.data.size() > (size_t)2) {
      kvs.lens = msg.data[2];
    }
    // Concatenate this server's reply onto the gathered response.
    mutex_.lock();
    for (auto key : kvs.keys) {
      gathered_response_[ts].keys.push_back(key);
    }
    for (auto val : kvs.vals) {
      gathered_response_[ts].vals.push_back(val);
    }
    for (auto len : kvs.lens) {
      gathered_response_[ts].lens.push_back(len);
    }
    mutex_.unlock();
    // NOTE(review): assumes Customer::NumResponse does not yet count the
    // message being processed, so `+ 1` compares against the full server
    // count — confirm against the ps-lite Customer implementation.
    if (general_customer_->NumResponse(ts) + 1 == server_num_) {
      const auto &cb = general_callbacks_[ts];
      cb();
      general_callbacks_.erase(ts);
    }
  }
}
  639. template <typename T>
  640. void WorkerProxy<T>::Send(::ps::Customer *customer, int timestamp, bool push, bool pull, int cmd,
  641. const ::ps::KVPairs<T> &kvs, const Slicer &slicer, std::map<int, int> attrs) {
  642. SlicedKVs sliced;
  643. slicer(timestamp, kvs, ::ps::Postoffice::Get()->GetServerKeyRanges(), &sliced, attrs);
  644. for (size_t i = 0; i < sliced.size(); i++) {
  645. const auto &s = sliced[i];
  646. if (!s.first) continue;
  647. ::ps::Message msg;
  648. msg.meta.app_id = customer->app_id();
  649. msg.meta.customer_id = customer->customer_id();
  650. msg.meta.request = true;
  651. msg.meta.push = push;
  652. msg.meta.pull = pull;
  653. msg.meta.head = cmd;
  654. msg.meta.timestamp = timestamp;
  655. msg.meta.recver = ::ps::Postoffice::Get()->ServerRankToID(i);
  656. msg.meta.priority = kvs.priority;
  657. const auto &kvs = s.second;
  658. if (kvs.keys.size()) {
  659. msg.AddData(kvs.keys);
  660. msg.AddData(kvs.vals);
  661. if (kvs.lens.size()) {
  662. msg.AddData(kvs.lens);
  663. }
  664. }
  665. ::ps::Postoffice::Get()->van()->Send(msg);
  666. }
  667. }
  668. } // namespace ps
  669. } // namespace parallel
  670. } // namespace mindspore
  671. #endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PS_WORKER_PROXY_H_