You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

worker.cc 40 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/worker.h"
  17. #include "pipeline/jit/pipeline.h"
  18. namespace mindspore {
  19. namespace ps {
  20. constexpr int kRetryDuration = 2000;
  21. void Worker::Run() {
  22. std::lock_guard<std::mutex> lock(running_mutex_);
  23. server_num_ = PSContext::instance()->initial_server_num();
  24. if (running_) {
  25. MS_LOG(INFO) << "'Worker is already running.";
  26. return;
  27. }
  28. if (!PSContext::instance()->is_worker()) {
  29. MS_LOG(EXCEPTION) << "The role is not worker.";
  30. }
  31. Initialize();
  32. worker_node_.RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
  33. MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
  34. this->Finalize();
  35. exit(0);
  36. });
  37. worker_node_.RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
  38. MS_LOG(ERROR) << "Trigger timeout event: NODE_TIMEOUT begin to exit the system!";
  39. this->Finalize();
  40. exit(0);
  41. });
  42. MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
  43. worker_node_.Start();
  44. MS_LOG(INFO) << "Worker connected successfully.";
  45. running_ = true;
  46. }
  47. void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
  48. if (keys.size() == 0) {
  49. MS_LOG(EXCEPTION) << "key size should be greater than zero";
  50. }
  51. if (key_to_optimId_.count(keys[0]) == 0) {
  52. MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
  53. }
  54. Key key = keys[0];
  55. int64_t optim_id = key_to_optimId_[key];
  56. MS_LOG(INFO) << "The key is:" << key << " the optim_id:" << optim_id;
  57. bool is_sparse = false;
  58. if (optim_id == 1 || optim_id == kSparseLazyAdamIndex || optim_id == kSparseFtrlIndex) {
  59. is_sparse = true;
  60. }
  61. int64_t grad_index = -1;
  62. int64_t indice_index = -1;
  63. // Sparse adam gradient
  64. if (optim_id == 1 || optim_id == kSparseLazyAdamIndex) {
  65. grad_index = kSparseGradIndex;
  66. indice_index = kSparseIndiceIndex;
  67. // Sparse ftrl gradient
  68. } else if (optim_id == kSparseFtrlIndex) {
  69. grad_index = 0;
  70. indice_index = 1;
  71. }
  72. size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
  73. std::vector<float> total_buffer(total_size, 0);
  74. size_t offset = 0;
  75. for (size_t i = 0; i < sizes.size(); i++) {
  76. void *dst_data = total_buffer.data() + offset / sizeof(float);
  77. void *src_data = reinterpret_cast<void *>(addrs[i]);
  78. MS_EXCEPTION_IF_NULL(dst_data);
  79. MS_EXCEPTION_IF_NULL(src_data);
  80. size_t size = sizes[i] * sizeof(float);
  81. size_t dest_size = size;
  82. size_t src_size = size;
  83. auto ret = memcpy_s(dst_data, dest_size, src_data, src_size);
  84. if (ret != 0) {
  85. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  86. return;
  87. }
  88. offset += size;
  89. }
  90. MS_LOG(INFO) << "The total size is:" << total_size;
  91. while (running_ && (!IsReadyForPush(keys[0]))) {
  92. continue;
  93. }
  94. std::vector<int> sizes_int;
  95. (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
  96. [](const int64_t &value) { return static_cast<int>(value); });
  97. if (!is_sparse) {
  98. PushData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), kPushCmd);
  99. } else {
  100. std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
  101. int64_t first_dim_size = var_shape[0];
  102. int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
  103. MS_LOG(DEBUG) << "The keys:" << keys << " the total_buffer:" << total_buffer << " the sizes_int:" << sizes_int
  104. << " the grad_index:" << grad_index << " the indice_index:" << indice_index
  105. << " the first_dim_size:" << first_dim_size << " the outer_dim_size" << outer_dim_size;
  106. PushSparseData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), LongToSize(grad_index),
  107. LongToSize(indice_index), LongToSize(first_dim_size), LongToSize(outer_dim_size));
  108. }
  109. }
  110. void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
  111. MS_EXCEPTION_IF_NULL(dev_addr);
  112. std::vector<float> variables(size / sizeof(float), 0);
  113. while (running_ && (!IsReadyForPull(key))) {
  114. continue;
  115. }
  116. PullData({key}, &variables, nullptr, kPullCmd);
  117. MS_LOG(DEBUG) << "The variables:" << variables << " the size is:" << size;
  118. size_t dst_size = size;
  119. size_t src_size = size;
  120. auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
  121. if (ret != 0) {
  122. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  123. return;
  124. }
  125. }
  126. size_t Worker::SetParamKey(const std::string &param_name) {
  127. size_t key = UINT64_MAX;
  128. if (param_to_key_.count(param_name)) {
  129. key = param_to_key_[param_name];
  130. MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
  131. } else {
  132. key = key_cnt_++;
  133. param_to_key_[param_name] = key;
  134. MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
  135. }
  136. return key;
  137. }
  138. size_t Worker::GetParamKey(const std::string &param_name) {
  139. size_t key = kInvalidKey;
  140. if (param_to_key_.find(param_name) != param_to_key_.end()) {
  141. key = param_to_key_[param_name];
  142. MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
  143. }
  144. return key;
  145. }
  146. void Worker::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
  147. MS_LOG(DEBUG) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
  148. param_to_init_in_server_[param_name] = init_in_server;
  149. }
  150. bool Worker::GetParamInitInServer(const std::string &param_name) {
  151. if (param_to_init_in_server_.count(param_name) == 0) {
  152. return false;
  153. }
  154. return param_to_init_in_server_[param_name];
  155. }
  156. void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
  157. MS_LOG(INFO) << "SetKeyOptimId key is:" << key << " optimizer_name:" << optimizer_name;
  158. key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
  159. }
  160. void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
  161. if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
  162. key_to_optim_shapes_[key] = {shape};
  163. } else {
  164. key_to_optim_shapes_[key].push_back(shape);
  165. }
  166. }
// Registers embedding table `key` with `row_count` rows: splits the rows
// into per-server shards and records one [begin, end] (inclusive) metadata
// entry per server. No-op if the key is already initialized.
// NOTE(review): if Util::LocalShard returns 0 for server 0, the unsigned
// `end = local_row_cnt - 1` underflows — presumably that never happens for
// a non-empty table; confirm.
void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
  bool has_init = IsKeyInit(key);
  if (has_init) {
    return;
  }
  uint64_t begin = 0;
  uint64_t end = 0;
  for (int64_t i = 0; i < server_num_; i++) {
    // Rows owned by server i under the project's sharding policy.
    size_t local_row_cnt = LongToSize(Util::LocalShard(row_count, i, server_num_));
    MS_LOG(DEBUG) << "The row_count:" << row_count << " the local_row_cnt:" << local_row_cnt;
    if (i == 0) {
      end = local_row_cnt - 1;
    } else {
      begin = end + 1;
      end += local_row_cnt;
    }
    EmbeddingTableShardMetadata range(begin, end);
    // Lazily create the per-key shard list on first iteration.
    if (embedding_table_ranges_.count(key) == 0) {
      embedding_table_ranges_[key] = std::make_shared<std::vector<EmbeddingTableShardMetadata>>();
      MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
    }
    embedding_table_ranges_[key]->push_back(range);
  }
  embedding_row_cnt_[key] = row_count;
}
// Broadcasts an "init embedding table" request (key + input/indices/output
// shapes + init info) to all servers, retrying every kRetryDuration ms.
// Returns true on success or if the key was already initialized; false on a
// copy failure or if the worker stops while retrying.
bool Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
                                  const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
                                  const ParamInitInfoMessage &info, uint32_t timeout) {
  bool has_init = IsKeyInit(key);
  if (has_init) {
    MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
    return true;
  }
  // Pack the table metadata into a protobuf message and serialize it.
  EmbeddingTableMeta embedding_table_meta;
  embedding_table_meta.set_key(key);
  *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
  *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
  *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
  *embedding_table_meta.mutable_info() = info;
  std::string kv_data = embedding_table_meta.SerializeAsString();
  // On __APPLE__ the array form shared_ptr<T[]> is avoided — presumably the
  // toolchain there lacks it — so an explicit array deleter is supplied.
#ifdef __APPLE__
  std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
#else
  std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
#endif
  size_t dest_size = kv_data.length();
  int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  if (ret != 0) {
    MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
    return false;
  }
  // Retry the broadcast until it succeeds; give up once the worker stops.
  while (!worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kInitEmbeddingsCmd, timeout)) {
    MS_LOG(INFO) << "Worker Broadcast failed!, retrying.";
    if (!running_) {
      MS_LOG(ERROR) << "Worker Broadcast failed!";
      return false;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDuration));
  }
  return true;
}
// Initializes a parameter and its optimizer state on the parameter server:
// registers the key-to-server mapping and, unless the embedding cache is
// enabled, pushes the weight data (when not init-in-server), the optimizer
// id, and the optimizer input shapes. Skips parameters with no assigned key
// or keys that are already initialized.
void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(input_node);
  auto pk_node = input_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(pk_node);
  const std::string &param_name = pk_node->fullname_with_scope();
  void *param_data = tensor->data_c();
  size_t param_size = LongToSize(tensor->data().nbytes());
  size_t param_key = GetParamKey(param_name);
  if (param_key == kInvalidKey) {
    MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
    return;
  }
  // The parameter info may request server-side initialization; record the
  // choice so later queries (GetParamInitInServer) can see it.
  bool init_in_server = false;
  auto param_info_ptr = pk_node->param_info();
  if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
    init_in_server = true;
  }
  SetParamInitInServer(param_name, init_in_server);
  bool init = IsKeyInit(param_key);
  if (!init) {
    MS_LOG(DEBUG) << "Init parameter key " << param_key << " and optimizer in parameter server side for " << param_name
                  << ", whether init in server: " << init_in_server;
    AddKeyToServerId(param_key);
    // With the PS data cache enabled, only the key/server mapping is set up
    // here — presumably the cache path initializes the rest; confirm.
    if (!PsDataPrefetch::GetInstance().cache_enable()) {
      if (!init_in_server) {
        // Weight size is capped at INT_MAX bytes in PS mode.
        if (param_size > INT_MAX) {
          MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
                            << param_size;
        }
        InitPSParamData({param_key}, param_data, param_size);
      }
      InitPSOptimId(param_key);
      InitPSOptimInputShapes(param_key);
    }
  }
}
  265. bool Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
  266. int64_t cmd) {
  267. MS_EXCEPTION_IF_NULL(lookup_result);
  268. EmbeddingTableLookup embedding_table_lookup;
  269. embedding_table_lookup.set_key(key);
  270. *embedding_table_lookup.mutable_keys() = {lookup_ids.begin(), lookup_ids.end()};
  271. PartitionEmbeddingMessages messages;
  272. lookup_partitioner_(embedding_table_lookup, &messages, {});
  273. std::vector<uint32_t> rank_ids;
  274. std::vector<DataPtr> data;
  275. std::vector<size_t> sizes;
  276. for (size_t i = 0; i < messages.size(); i++) {
  277. if (messages.at(i).first) {
  278. rank_ids.push_back(i);
  279. std::string kv_data = messages.at(i).second.SerializeAsString();
  280. #ifdef __APPLE__
  281. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  282. #else
  283. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  284. #endif
  285. size_t dest_size = kv_data.length();
  286. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  287. if (ret != 0) {
  288. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  289. return false;
  290. }
  291. data.push_back(res);
  292. sizes.push_back(kv_data.length());
  293. }
  294. }
  295. std::vector<VectorPtr> resp;
  296. while (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(cmd), &resp)) {
  297. MS_LOG(INFO) << "Worker send failed!, retrying.";
  298. if (!running_) {
  299. MS_LOG(ERROR) << "Worker send failed!";
  300. return false;
  301. }
  302. std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDuration));
  303. }
  304. int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
  305. mindspore::HashMap<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
  306. std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
  307. std::shared_ptr<std::vector<Key>> keys = std::make_shared<std::vector<Key>>();
  308. int64_t value_offset = 0;
  309. for (size_t i = 0; i < resp.size(); ++i) {
  310. KVMessage message;
  311. CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()));
  312. for (auto j = 0; j < message.values_size(); j++) {
  313. values->push_back(message.values(j));
  314. }
  315. for (auto k = 0; k < message.keys_size(); k++) {
  316. const Key &message_key = message.keys(k);
  317. keys->push_back(message_key);
  318. }
  319. }
  320. for (size_t i = 0; i < keys->size(); i++) {
  321. const Key &map_key = keys->at(i);
  322. float *addr = values->data() + value_offset;
  323. value_offset += single_id_len;
  324. id_addr_map[map_key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
  325. }
  326. float *result_addr = lookup_result->data();
  327. MS_EXCEPTION_IF_NULL(result_addr);
  328. int64_t offset = 0;
  329. size_t dst_size = 0;
  330. size_t src_size = 0;
  331. void *dst_data = nullptr;
  332. void *src_data = nullptr;
  333. for (size_t i = 0; i < lookup_ids.size(); i++) {
  334. if (id_addr_map.count(lookup_ids[i]) == 0) {
  335. offset += single_id_len;
  336. continue;
  337. }
  338. const Key &id_key = static_cast<Key>(lookup_ids[i]);
  339. auto &pair = id_addr_map[id_key];
  340. size_t size = LongToSize(single_id_len * sizeof(float));
  341. dst_size = size;
  342. src_size = size;
  343. dst_data = result_addr + offset;
  344. src_data = pair->first;
  345. MS_EXCEPTION_IF_NULL(dst_data);
  346. MS_EXCEPTION_IF_NULL(src_data);
  347. auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  348. if (mem_ret != 0) {
  349. MS_LOG(ERROR) << "memcpy_s error, errorno(" << mem_ret << ")";
  350. return false;
  351. }
  352. offset += single_id_len;
  353. }
  354. return true;
  355. }
  356. bool Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
  357. const std::vector<float> &vals) {
  358. KVMessage kvs;
  359. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  360. *kvs.mutable_len() = {lookup_ids.begin(), lookup_ids.end()};
  361. *kvs.mutable_values() = {vals.begin(), vals.end()};
  362. PartitionKVMessages messages;
  363. update_embedding_partitioner_(kvs, &messages, {});
  364. std::vector<uint32_t> rank_ids;
  365. std::vector<DataPtr> data;
  366. std::vector<size_t> sizes;
  367. for (size_t i = 0; i < messages.size(); i++) {
  368. if (messages.at(i).first) {
  369. rank_ids.push_back(i);
  370. std::string kv_data = messages.at(i).second.SerializeAsString();
  371. #ifdef __APPLE__
  372. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  373. #else
  374. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  375. #endif
  376. size_t dest_size = kv_data.length();
  377. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  378. if (ret != 0) {
  379. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  380. return false;
  381. }
  382. data.push_back(res);
  383. sizes.push_back(kv_data.length());
  384. }
  385. }
  386. while (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(kUpdateEmbeddingsCmd))) {
  387. MS_LOG(INFO) << "Worker send failed!, retrying.";
  388. if (!running_) {
  389. MS_LOG(ERROR) << "Worker send failed!";
  390. return false;
  391. }
  392. std::this_thread::sleep_for(std::chrono::milliseconds(kRetryDuration));
  393. }
  394. return true;
  395. }
// Gracefully shuts the worker down: broadcasts a finalize command with a
// dummy KV payload to all servers, then finishes and stops the worker node.
// No-op if the worker is not running.
void Worker::Finalize() {
  if (running_) {
    MS_LOG(INFO) << "Worker starts finalizing...";
    // Minimal placeholder payload — the command, not the data, matters here.
    KVMessage kvs;
    kvs.add_keys(0);
    kvs.add_values(0.0f);
    std::string kv_data = kvs.SerializeAsString();
#ifdef __APPLE__
    std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
#else
    std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
#endif
    size_t dest_size = kv_data.length();
    int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
    if (ret != 0) {
      MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
      return;
    }
    // NOTE(review): the Broadcast result is ignored — presumably best-effort
    // during shutdown (unlike the retry loops elsewhere); confirm.
    worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kFinalizeCmd);
    worker_node_.Finish();
    worker_node_.Stop();
    running_ = false;
    MS_LOG(INFO) << "Worker finalized successfully.";
  }
}
  421. void Worker::Initialize() {
  422. lookup_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  423. LookupIdPartitioner(send, partition, attrs);
  424. };
  425. worker_init_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  426. WorkerInitEmbeddingPartitioner(send, partition, attrs);
  427. };
  428. round_robin_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  429. RoundRobinPartitioner(send, partition, attrs);
  430. };
  431. sparse_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  432. SparsePartitioner(send, partition, attrs);
  433. };
  434. update_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  435. UpdateEmbeddingPartitioner(send, partition, attrs);
  436. };
  437. broadcast_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  438. BroadcastPartitioner(send, partition, attrs);
  439. };
  440. }
  441. bool Worker::IsKeyInit(const size_t key) {
  442. if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
  443. return false;
  444. }
  445. return true;
  446. }
  447. void Worker::AddKeyToServerId(const Key &key) { AddKeyByHashMod(key); }
  448. void Worker::AddKeyByHashMod(const Key &key) {
  449. if (server_num_ == 0) {
  450. MS_LOG(EXCEPTION) << "Server number is invalid:0";
  451. }
  452. key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
  453. MS_LOG(DEBUG) << "The server id of key " << key << " is " << key_to_server_id_[key];
  454. }
  455. void Worker::InitPSOptimId(const size_t param_key) {
  456. MS_LOG(INFO) << "InitPSOptimId key is:" << param_key;
  457. if (key_to_optimId_.count(param_key) == 0) {
  458. MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
  459. }
  460. int64_t optim_id = key_to_optimId_[param_key];
  461. std::vector<Key> keys = {param_key};
  462. std::vector<float> optim_id_vals = {static_cast<float>(optim_id)};
  463. std::vector<int> optim_id_lens = {SizeToInt(optim_id_vals.size())};
  464. MS_LOG(INFO) << "The keys is" << keys << " the optim_id_vals is: " << optim_id_vals
  465. << " optim_id_lens is:" << optim_id_lens;
  466. PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
  467. }
  468. void Worker::InitPSOptimInputShapes(const size_t key) {
  469. std::vector<Key> keys;
  470. std::vector<int> shape_len;
  471. std::vector<float> all_shape;
  472. std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
  473. for (auto shape : shapes) {
  474. keys.push_back(key);
  475. if (shape.size() == 0) {
  476. shape_len.push_back(1);
  477. all_shape.push_back(1);
  478. } else {
  479. shape_len.push_back(SizeToLong(shape.size()));
  480. std::transform(shape.begin(), shape.end(), std::back_inserter(all_shape),
  481. [](size_t dim) -> float { return static_cast<float>(dim); });
  482. }
  483. }
  484. MS_LOG(INFO) << "keys:" << keys;
  485. MS_LOG(INFO) << "shape_len:" << shape_len;
  486. MS_LOG(INFO) << "all_shape:" << all_shape;
  487. if (!init_keys_[key]) {
  488. init_keys_[key] = true;
  489. }
  490. PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
  491. }
  492. void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) {
  493. MS_EXCEPTION_IF_NULL(origin_addr);
  494. std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
  495. reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
  496. std::vector<Key> key(keys);
  497. std::vector<int> lens;
  498. lens.push_back(addr.size());
  499. MS_LOG(INFO) << "the keys are:" << keys;
  500. MS_LOG(INFO) << "the values are:" << addr;
  501. PushData(key, addr, lens, kInitWeightsCmd);
  502. init_keys_[key[0]] = true;
  503. }
  504. bool Worker::IsReadyForPush(const Key &key) {
  505. std::vector<float> result(1, 0);
  506. PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  507. MS_LOG(INFO) << "key:" << key;
  508. if (result[0] > 0) {
  509. MS_LOG(INFO) << "IsReadyForPush:";
  510. return true;
  511. } else {
  512. MS_LOG(INFO) << "IsReadyForPush:";
  513. return false;
  514. }
  515. }
  516. bool Worker::IsReadyForPull(const Key &key) {
  517. std::vector<float> result(1, 0);
  518. PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  519. if (result[0] > 0) {
  520. MS_LOG(INFO) << "IsReadyForPull";
  521. return true;
  522. } else {
  523. MS_LOG(INFO) << "IsReadyForPull";
  524. return false;
  525. }
  526. }
  527. void Worker::PrepareSparseGradient(const size_t, const size_t, const mindspore::HashSet<int> &distinct_ids,
  528. const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
  529. const size_t segment_size, float *gradient, int *indices) {
  530. MS_EXCEPTION_IF_NULL(all_indice);
  531. MS_EXCEPTION_IF_NULL(gradient);
  532. MS_EXCEPTION_IF_NULL(indices);
  533. size_t offset = 0;
  534. int64_t index = 0;
  535. size_t segment_data_size = segment_size * sizeof(float);
  536. size_t dst_size;
  537. size_t src_size;
  538. void *dst_data = nullptr;
  539. void *src_data = nullptr;
  540. for (auto &pair : indice_to_grads) {
  541. if (distinct_ids.count(pair.first) == 0) {
  542. continue;
  543. }
  544. indices[index++] = pair.first;
  545. dst_size = segment_data_size;
  546. src_size = segment_data_size;
  547. dst_data = gradient + offset;
  548. src_data = pair.second;
  549. MS_EXCEPTION_IF_NULL(dst_data);
  550. MS_EXCEPTION_IF_NULL(src_data);
  551. auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
  552. if (ret != 0) {
  553. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  554. return;
  555. }
  556. offset += segment_size;
  557. }
  558. }
// Rebuilds the full optimizer input buffer in `reduced_data` after sparse
// reduction: segments other than gradient/indices are copied verbatim from
// `original_data`, then the reduced gradient (`grads`) and the reduced
// indices (`indices`) are written into their segments. `lengths` holds each
// segment's element count; grad_index/indice_index select the gradient and
// indices segments.
void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
                              const float *original_data, const float *grads, int *indices,
                              std::vector<float> *reduced_data) {
  MS_EXCEPTION_IF_NULL(original_data);
  MS_EXCEPTION_IF_NULL(grads);
  MS_EXCEPTION_IF_NULL(indices);
  MS_EXCEPTION_IF_NULL(reduced_data);
  // Copy every non-gradient, non-indices segment straight across.
  int64_t offset = 0;
  size_t dst_size = 0;
  size_t src_size = 0;
  void *dst_data = nullptr;
  void *src_data = nullptr;
  for (size_t i = 0; i < lengths.size(); i++) {
    if (i != grad_index && i != indice_index) {
      size_t data_size = lengths[i] * sizeof(float);
      dst_size = data_size;
      src_size = data_size;
      dst_data = reduced_data->data() + offset;
      src_data = const_cast<float *>(original_data) + offset;
      MS_EXCEPTION_IF_NULL(dst_data);
      MS_EXCEPTION_IF_NULL(src_data);
      auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
      if (mem_ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << mem_ret << ")";
        return;
      }
    }
    offset += lengths[i];
  }
  // Fill the reduced gradient
  int64_t grad_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += lengths[i];
  }
  size_t data_size = lengths[grad_index] * sizeof(float);
  dst_size = data_size;
  src_size = data_size;
  dst_data = reduced_data->data() + grad_offset;
  src_data = const_cast<float *>(grads);
  MS_EXCEPTION_IF_NULL(dst_data);
  auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
  // Fill the reduced indice
  // NOTE(review): the int indices are copied bit-wise into the float buffer
  // (relies on sizeof(int) == sizeof(float)); this offset also assumes the
  // indices segment immediately follows the gradient segment
  // (indice_index == grad_index + 1) — confirm both against the callers.
  int64_t indice_offset = grad_offset + lengths[grad_index];
  data_size = lengths[indice_index] * sizeof(float);
  float *indice_data = reduced_data->data() + indice_offset;
  dst_size = data_size;
  src_size = data_size;
  dst_data = indice_data;
  src_data = indices;
  MS_EXCEPTION_IF_NULL(dst_data);
  MS_EXCEPTION_IF_NULL(src_data);
  ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
    return;
  }
}
// Sends key/value data to the servers with command `cmd`.
// Routing: embedding-table keys use the worker-init-embedding partitioner
// for kInitWeightsCmd, or a raw broadcast to all servers for other commands;
// non-embedding keys go through the round-robin partitioner.
// NOTE(review): `keys[0]` assumes the caller never passes an empty vector —
// all call sites in this file appear to satisfy that; confirm.
void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
                      int cmd, int64_t) {
  KVMessage kvs;
  *kvs.mutable_keys() = {keys.begin(), keys.end()};
  *kvs.mutable_values() = {vals.begin(), vals.end()};
  *kvs.mutable_len() = {lens.begin(), lens.end()};
  MS_LOG(INFO) << "the result is:" << embedding_table_ranges_.count(keys[0]);
  if (embedding_table_ranges_.count(keys[0])) {
    if (cmd == kInitWeightsCmd) {
      SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
    } else {
      // Non-init commands on embedding keys are broadcast verbatim.
      std::string kv_data = kvs.SerializeAsString();
#ifdef __APPLE__
      std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
#else
      std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
#endif
      size_t dest_size = kv_data.length();
      int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
      if (ret != 0) {
        MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
        return;
      }
      worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), cmd);
    }
  } else {
    SendForPush(cmd, kvs, round_robin_partitioner_, {});
  }
}
  649. void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  650. size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size) {
  651. KVMessage kvs;
  652. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  653. *kvs.mutable_values() = {vals.begin(), vals.end()};
  654. *kvs.mutable_len() = {lens.begin(), lens.end()};
  655. if (embedding_table_ranges_.count(keys[0])) {
  656. std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
  657. SendForPush(kPushCmd, kvs, sparse_partitioner_, attrs);
  658. } else {
  659. SendForPush(kPushCmd, kvs, round_robin_partitioner_, {});
  660. }
  661. }
  662. void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd,
  663. int64_t priority) {
  664. MS_EXCEPTION_IF_NULL(vals);
  665. KVMessage kvs;
  666. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  667. if (embedding_table_ranges_.count(keys[0])) {
  668. SendForPull(cmd, kvs, broadcast_partitioner_, {}, vals, lens);
  669. } else {
  670. SendForPull(cmd, kvs, round_robin_partitioner_, {}, vals, lens);
  671. }
  672. }
  673. void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
  674. const std::map<int64_t, int64_t> &) {
  675. MS_EXCEPTION_IF_NULL(partition);
  676. const Key &key = send.key();
  677. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  678. partition->resize(ranges.size());
  679. for (size_t i = 0; i < ranges.size(); i++) {
  680. const EmbeddingTableShardMetadata &range = ranges[i];
  681. const auto &begin = range.begin();
  682. const auto &end = range.end();
  683. mindspore::HashSet<int32_t> unique_ids;
  684. auto &kvs = partition->at(i).second;
  685. kvs.set_key(key);
  686. std::for_each(send.keys().begin(), send.keys().end(), [&](int32_t lookup_id) {
  687. if (lookup_id >= SizeToInt(begin) && lookup_id <= SizeToInt(end)) {
  688. unique_ids.insert(lookup_id);
  689. }
  690. });
  691. MS_LOG(DEBUG) << "The unique ids size is:" << unique_ids.size();
  692. for (const auto &lookup_id : unique_ids) {
  693. kvs.add_keys(lookup_id);
  694. kvs.add_values(0.0f);
  695. }
  696. if (kvs.keys().empty()) {
  697. partition->at(i).first = false;
  698. } else {
  699. partition->at(i).first = true;
  700. }
  701. }
  702. }
void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
                               const std::map<int64_t, int64_t> &attrs) {
  // Partitions a sparse gradient push across embedding-table shards. For each shard it
  // selects the (indice, gradient-segment) pairs that fall in the shard's id range,
  // reduces duplicate indices with Util::ReduceSparseGradient, and packs the reduced
  // gradient + indices into that shard's KVMessage. `attrs` carries the sparse layout:
  // kGradIndex / kIndiceIndex locate the gradient and indice entries inside send.len(),
  // kFirstDimSize / kOutDimSize describe the dense shape used by the reduction.
  MS_EXCEPTION_IF_NULL(partition);
  // Init variables
  float *data = const_cast<float *>(send.values().data());
  if (attrs.count(kGradIndex) == 0 || attrs.count(kIndiceIndex) == 0 || attrs.count(kFirstDimSize) == 0 ||
      attrs.count(kOutDimSize) == 0) {
    MS_LOG(EXCEPTION) << "Invalid attrs keys";
  }
  auto iter = attrs.find(kGradIndex);
  size_t grad_index = static_cast<size_t>(iter->second);
  iter = attrs.find(kIndiceIndex);
  size_t indice_index = static_cast<size_t>(iter->second);
  iter = attrs.find(kFirstDimSize);
  size_t first_dim_size = static_cast<size_t>(iter->second);
  iter = attrs.find(kOutDimSize);
  size_t outer_dim_size = static_cast<size_t>(iter->second);
  size_t grad_size = send.len()[SizeToInt(grad_index)];
  size_t indice_size = send.len()[SizeToInt(indice_index)];
  // NOTE(review): divides by indice_size with no zero check — assumes the caller never
  // sends an empty indice entry; confirm upstream.
  size_t segment_size = grad_size / indice_size;
  // Offsets of the gradient and indice entries inside the flat values buffer:
  // each entry's offset is the sum of the lengths of the entries before it.
  size_t grad_offset = 0;
  size_t indice_offset = 0;
  for (size_t i = 0; i < grad_index; i++) {
    grad_offset += send.len()[i];
  }
  for (size_t j = 0; j < indice_index; j++) {
    indice_offset += send.len()[j];
  }
  float *grad_data = data + grad_offset;
  // The indices were packed into the float buffer; reinterpret them back as ints.
  void *indice_data_temp = data + indice_offset;
  int *indice_data = reinterpret_cast<int *>(indice_data_temp);
  // Build the mappings of indice to gradient
  std::vector<std::pair<int, float *>> indice_to_grads;
  for (size_t i = 0; i < indice_size; i++) {
    int indice = indice_data[i];
    float *grad = grad_data + i * segment_size;
    indice_to_grads.push_back(std::make_pair(indice, grad));
  }
  const Key &key = send.keys()[0];
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());
  // Construct reduced sparse data for each server
  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = partition->at(i).second;
    // Start from the full keys/lens; lens are overwritten below once reduced.
    *kvs.mutable_keys() = {send.keys().begin(), send.keys().end()};
    *kvs.mutable_len() = {send.len().begin(), send.len().end()};
    // Prepare the sparse gradient and indice
    std::vector<int> indice_ids;
    mindspore::HashSet<int> distinct_ids;
    for (size_t j = 0; j < indice_size; j++) {
      size_t indice = static_cast<size_t>(indice_data[j]);
      if (indice >= begin && indice <= end) {
        indice_ids.push_back(indice);
        distinct_ids.insert(indice);
      }
    }
    size_t indices_size = indice_ids.size();
    if (indices_size > 0) {
      size_t partition_segment_size = indices_size * segment_size;
      std::vector<float> src_grad_data(partition_segment_size);
      std::vector<int> src_indice_data(indices_size);
      PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
                            src_indice_data.data());
      // Reduce the sparse gradient and indice
      std::vector<float> new_grad(partition_segment_size);
      std::vector<int> new_indices(indices_size);
      mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
      Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
                                 first_dim_size, outer_dim_size, &unique_sparse_grad);
      // Update the length of reduce sparse gradient and indice
      std::vector<int> reduced_lens = {kvs.len().begin(), kvs.len().end()};
      reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
      reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
      // Build the sparse value to be sent
      size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
      std::vector<float> reduced_data(total_size, 0);
      BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
                       unique_sparse_grad.indices_, &reduced_data);
      *kvs.mutable_len() = {reduced_lens.begin(), reduced_lens.end()};
      *kvs.mutable_values() = {reduced_data.begin(), reduced_data.end()};
    }
    if (indices_size == 0) {
      // No ids for this shard: send a placeholder value and empty lens so the server
      // still receives a well-formed (no-op) message.
      // NOTE(review): `no_keys` is populated but never assigned to kvs — it appears to
      // be dead code; the keys set above are left in place. Confirm intended.
      std::vector<float> no_keys;
      std::vector<float> no_vals;
      std::vector<float> no_lens;
      no_keys.push_back(key);
      no_vals.push_back(kGradValue);
      *kvs.mutable_values() = {no_vals.begin(), no_vals.end()};
      *kvs.mutable_len() = {no_lens.begin(), no_lens.end()};
    }
    // Every shard is marked active: even empty shards get the placeholder message above.
    partition->at(i).first = true;
  }
}
  799. void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  800. const std::map<int64_t, int64_t> &) {
  801. MS_EXCEPTION_IF_NULL(partition);
  802. partition->resize(LongToSize(server_num_));
  803. auto keys = send.keys();
  804. auto values = send.values();
  805. auto lens = send.len();
  806. MS_LOG(INFO) << "the key size is:" << send.keys_size() << " the values size is:" << send.values_size()
  807. << " the lens:" << send.len_size();
  808. size_t len;
  809. Key param_key;
  810. for (int i = 0; i < send.keys_size(); i++) {
  811. param_key = keys[i];
  812. int64_t server_id = key_to_server_id_[param_key];
  813. if (!partition->at(LongToUlong(server_id)).first) {
  814. partition->at(LongToUlong(server_id)).first = true;
  815. }
  816. KVMessage &server_kv_pairs = partition->at(LongToUlong(server_id)).second;
  817. server_kv_pairs.add_keys(param_key);
  818. if (values.empty()) {
  819. continue;
  820. }
  821. len = lens[i];
  822. int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
  823. auto val_begin = values.begin() + offset;
  824. auto val_end = val_begin + len;
  825. for (auto it = val_begin; it != val_end; ++it) {
  826. server_kv_pairs.add_values(*it);
  827. }
  828. server_kv_pairs.add_len(len);
  829. }
  830. }
  831. void Worker::WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
  832. const std::map<int64_t, int64_t> &) {
  833. MS_EXCEPTION_IF_NULL(partition);
  834. partition->resize(LongToSize(server_num_));
  835. auto keys = send.keys();
  836. auto values = send.values();
  837. auto lens = send.len();
  838. int32_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  839. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[keys[0]]);
  840. for (size_t i = 0; i < ranges.size(); i++) {
  841. size_t offset_begin = ranges[i].begin() * col_cnt;
  842. size_t offset_end = (ranges[i].end() + 1) * col_cnt;
  843. KVMessage kvs;
  844. *kvs.mutable_keys() = keys;
  845. *kvs.mutable_values() = {values.begin() + offset_begin, values.begin() + offset_end};
  846. kvs.add_len(offset_end - offset_begin);
  847. partition->at(i).first = true;
  848. partition->at(i).second = kvs;
  849. }
  850. }
void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
                                        const std::map<int64_t, int64_t> &) {
  // Splits an embedding-update message by id range: every server receives the table key
  // followed by the (lookup id, embedding row) pairs whose ids fall inside its shard.
  MS_EXCEPTION_IF_NULL(partition);
  const float *embedding_vals = send.values().data();
  // NOTE(review): the `len` field is read here as the lookup ids themselves, not as value
  // lengths — presumably the caller packs ids into `len`; confirm against the sender.
  const uint64_t *lookup_ids = send.len().data();
  size_t val_size = IntToSize(send.values_size());
  size_t id_size = IntToSize(send.len_size());
  if (id_size == 0) {
    MS_LOG(EXCEPTION) << "The id size is 0.";
    return;
  }
  // Each lookup id owns one contiguous row of `embedding_dim` floats in `values`.
  size_t embedding_dim = val_size / id_size;
  const Key &key = send.keys()[0];
  const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  partition->resize(ranges.size());
  for (size_t i = 0; i < ranges.size(); i++) {
    const EmbeddingTableShardMetadata &range = ranges[i];
    const auto &begin = range.begin();
    const auto &end = range.end();
    auto &kvs = partition->at(i).second;
    // The first key of each sub-message is always the embedding-table key.
    kvs.add_keys(key);
    for (size_t j = 0; j < id_size; j++) {
      auto lookup_id = lookup_ids[j];
      if (lookup_id >= begin && lookup_id <= end) {
        kvs.add_keys(lookup_id);
        // Copy the id's full embedding row into the sub-message.
        for (size_t k = 0; k < embedding_dim; k++) {
          kvs.add_values(embedding_vals[j * embedding_dim + k]);
        }
      }
    }
    // A sub-message holding only the table key carries no updates -> mark inactive.
    if (kvs.keys_size() <= 1) {
      partition->at(i).first = false;
    } else {
      partition->at(i).first = true;
    }
  }
}
  888. void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  889. const std::map<int64_t, int64_t> &) {
  890. MS_EXCEPTION_IF_NULL(partition);
  891. partition->resize(LongToSize(server_num_));
  892. for (size_t i = 0; i < LongToSize(server_num_); i++) {
  893. partition->at(i).first = true;
  894. partition->at(i).second = send;
  895. }
  896. }
  897. void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  898. const std::map<int64_t, int64_t> &attrs) {
  899. PartitionKVMessages messages;
  900. partitioner(send, &messages, attrs);
  901. std::vector<uint32_t> rank_ids;
  902. std::vector<DataPtr> data;
  903. std::vector<size_t> sizes;
  904. for (size_t i = 0; i < messages.size(); i++) {
  905. if (messages.at(i).first) {
  906. rank_ids.push_back(i);
  907. std::string kv_data = messages.at(i).second.SerializeAsString();
  908. #ifdef __APPLE__
  909. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  910. #else
  911. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  912. #endif
  913. size_t dest_size = kv_data.length();
  914. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  915. if (ret != 0) {
  916. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  917. return;
  918. }
  919. data.push_back(res);
  920. sizes.push_back(kv_data.length());
  921. }
  922. }
  923. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd);
  924. }
  925. void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  926. const std::map<int64_t, int64_t> &, std::vector<float> *vals, std::vector<int> *lens) {
  927. MS_EXCEPTION_IF_NULL(vals);
  928. PartitionKVMessages messages;
  929. partitioner(send, &messages, {});
  930. std::vector<uint32_t> rank_ids;
  931. std::vector<DataPtr> data;
  932. std::vector<size_t> sizes;
  933. for (size_t i = 0; i < messages.size(); i++) {
  934. if (messages.at(i).first) {
  935. rank_ids.push_back(i);
  936. std::string kv_data = messages.at(i).second.SerializeAsString();
  937. #ifdef __APPLE__
  938. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  939. #else
  940. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  941. #endif
  942. size_t dest_size = kv_data.length();
  943. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  944. if (ret != 0) {
  945. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  946. return;
  947. }
  948. data.push_back(res);
  949. sizes.push_back(kv_data.length());
  950. }
  951. }
  952. std::vector<VectorPtr> resp;
  953. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
  954. vals->clear();
  955. for (size_t i = 0; i < resp.size(); ++i) {
  956. KVMessage message;
  957. CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
  958. std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));
  959. if (lens) {
  960. lens->clear();
  961. std::copy(message.len().begin(), message.len().end(), std::back_inserter(*lens));
  962. }
  963. }
  964. }
  965. } // namespace ps
  966. } // namespace mindspore