You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

worker.cc 39 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/worker.h"
  17. #include "pipeline/jit/pipeline.h"
  18. namespace mindspore {
  19. namespace ps {
  20. void Worker::Run() {
  21. std::lock_guard<std::mutex> lock(running_mutex_);
  22. server_num_ = PSContext::instance()->initial_server_num();
  23. if (running_) {
  24. MS_LOG(INFO) << "'Worker is already running.";
  25. return;
  26. }
  27. if (!PSContext::instance()->is_worker()) {
  28. MS_LOG(EXCEPTION) << "The role is not worker.";
  29. }
  30. Initialize();
  31. worker_node_.RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
  32. MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT begin to exit the system!";
  33. this->Finalize();
  34. exit(0);
  35. });
  36. worker_node_.RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
  37. MS_LOG(ERROR) << "Trigger timeout event: NODE_TIMEOUT begin to exit the system!";
  38. this->Finalize();
  39. exit(0);
  40. });
  41. MS_LOG(INFO) << "Worker starts connecting to scheduler and server...";
  42. worker_node_.Start();
  43. MS_LOG(INFO) << "Worker connected successfully.";
  44. running_ = true;
  45. }
  46. void Worker::Push(const std::vector<size_t> &keys, std::vector<uintptr_t> addrs, const ShapeVector &sizes) {
  47. if (keys.size() == 0) {
  48. MS_LOG(EXCEPTION) << "key size should be greater than zero";
  49. }
  50. if (key_to_optimId_.count(keys[0]) == 0) {
  51. MS_LOG(EXCEPTION) << "no optim id found for key" << keys[0];
  52. }
  53. Key key = keys[0];
  54. int64_t optim_id = key_to_optimId_[key];
  55. MS_LOG(INFO) << "The key is:" << key << " the optim_id:" << optim_id;
  56. bool is_sparse = false;
  57. if (optim_id == 1 || optim_id == kSparseLazyAdamIndex || optim_id == kSparseFtrlIndex) {
  58. is_sparse = true;
  59. }
  60. int64_t grad_index = -1;
  61. int64_t indice_index = -1;
  62. // Sparse adam gradient
  63. if (optim_id == 1 || optim_id == kSparseLazyAdamIndex) {
  64. grad_index = kSparseGradIndex;
  65. indice_index = kSparseIndiceIndex;
  66. // Sparse ftrl gradient
  67. } else if (optim_id == kSparseFtrlIndex) {
  68. grad_index = 0;
  69. indice_index = 1;
  70. }
  71. size_t total_size = std::accumulate(sizes.begin(), sizes.end(), 0, std::plus<int64_t>());
  72. std::vector<float> total_buffer(total_size, 0);
  73. size_t offset = 0;
  74. for (size_t i = 0; i < sizes.size(); i++) {
  75. void *dst_data = total_buffer.data() + offset / sizeof(float);
  76. void *src_data = reinterpret_cast<void *>(addrs[i]);
  77. MS_EXCEPTION_IF_NULL(dst_data);
  78. MS_EXCEPTION_IF_NULL(src_data);
  79. size_t size = sizes[i] * sizeof(float);
  80. size_t dest_size = size;
  81. size_t src_size = size;
  82. auto ret = memcpy_s(dst_data, dest_size, src_data, src_size);
  83. if (ret != 0) {
  84. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  85. return;
  86. }
  87. offset += size;
  88. }
  89. MS_LOG(INFO) << "The total size is:" << total_size;
  90. while (running_ && (!IsReadyForPush(keys[0]))) {
  91. continue;
  92. }
  93. std::vector<int> sizes_int;
  94. (void)std::transform(sizes.begin(), sizes.end(), std::back_inserter(sizes_int),
  95. [](const int64_t &value) { return static_cast<int>(value); });
  96. if (!is_sparse) {
  97. PushData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), kPushCmd);
  98. } else {
  99. std::vector<int64_t> &var_shape = key_to_optim_shapes_[key][0];
  100. int64_t first_dim_size = var_shape[0];
  101. int64_t outer_dim_size = std::accumulate(var_shape.begin() + 1, var_shape.end(), 1, std::multiplies<int64_t>());
  102. MS_LOG(DEBUG) << "The keys:" << keys << " the total_buffer:" << total_buffer << " the sizes_int:" << sizes_int
  103. << " the grad_index:" << grad_index << " the indice_index:" << indice_index
  104. << " the first_dim_size:" << first_dim_size << " the outer_dim_size" << outer_dim_size;
  105. PushSparseData(std::vector<Key>(keys), total_buffer, std::vector<int>(sizes_int), LongToSize(grad_index),
  106. LongToSize(indice_index), LongToSize(first_dim_size), LongToSize(outer_dim_size));
  107. }
  108. }
  109. void Worker::Pull(const size_t key, void *dev_addr, const size_t size) {
  110. MS_EXCEPTION_IF_NULL(dev_addr);
  111. std::vector<float> variables(size / sizeof(float), 0);
  112. while (running_ && (!IsReadyForPull(key))) {
  113. continue;
  114. }
  115. PullData({key}, &variables, nullptr, kPullCmd);
  116. MS_LOG(DEBUG) << "The variables:" << variables << " the size is:" << size;
  117. size_t dst_size = size;
  118. size_t src_size = size;
  119. auto ret = memcpy_s(dev_addr, dst_size, variables.data(), src_size);
  120. if (ret != 0) {
  121. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  122. return;
  123. }
  124. }
  125. size_t Worker::SetParamKey(const std::string &param_name) {
  126. size_t key = UINT64_MAX;
  127. if (param_to_key_.count(param_name)) {
  128. key = param_to_key_[param_name];
  129. MS_LOG(INFO) << param_name << " key is already set: key value is " << key;
  130. } else {
  131. key = key_cnt_++;
  132. param_to_key_[param_name] = key;
  133. MS_LOG(INFO) << "Set key " << key << " for parameter " << param_name;
  134. }
  135. return key;
  136. }
  137. size_t Worker::GetParamKey(const std::string &param_name) {
  138. size_t key = kInvalidKey;
  139. if (param_to_key_.find(param_name) != param_to_key_.end()) {
  140. key = param_to_key_[param_name];
  141. MS_LOG(DEBUG) << "Get key of parameter " << param_name << " key is " << key;
  142. }
  143. return key;
  144. }
  145. void Worker::SetParamInitInServer(const std::string &param_name, bool init_in_server) {
  146. MS_LOG(DEBUG) << "Set parameter " << param_name << " init_in_server:" << init_in_server;
  147. param_to_init_in_server_[param_name] = init_in_server;
  148. }
  149. bool Worker::GetParamInitInServer(const std::string &param_name) {
  150. if (param_to_init_in_server_.count(param_name) == 0) {
  151. return false;
  152. }
  153. return param_to_init_in_server_[param_name];
  154. }
  155. void Worker::SetKeyOptimId(size_t key, const std::string &optimizer_name) {
  156. MS_LOG(INFO) << "SetKeyOptimId key is:" << key << " optimizer_name:" << optimizer_name;
  157. key_to_optimId_[key] = Util::optimizer_id(optimizer_name);
  158. }
  159. void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) {
  160. if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) {
  161. key_to_optim_shapes_[key] = {shape};
  162. } else {
  163. key_to_optim_shapes_[key].push_back(shape);
  164. }
  165. }
  166. void Worker::AddEmbeddingTable(const Key &key, const size_t &row_count) {
  167. bool has_init = IsKeyInit(key);
  168. if (has_init) {
  169. return;
  170. }
  171. uint64_t begin = 0;
  172. uint64_t end = 0;
  173. for (int64_t i = 0; i < server_num_; i++) {
  174. size_t local_row_cnt = LongToSize(Util::LocalShard(row_count, i, server_num_));
  175. MS_LOG(DEBUG) << "The row_count:" << row_count << " the local_row_cnt:" << local_row_cnt;
  176. if (i == 0) {
  177. end = local_row_cnt - 1;
  178. } else {
  179. begin = end + 1;
  180. end += local_row_cnt;
  181. }
  182. EmbeddingTableShardMetadata range(begin, end);
  183. if (embedding_table_ranges_.count(key) == 0) {
  184. embedding_table_ranges_[key] = std::make_shared<std::vector<EmbeddingTableShardMetadata>>();
  185. MS_EXCEPTION_IF_NULL(embedding_table_ranges_[key]);
  186. }
  187. embedding_table_ranges_[key]->push_back(range);
  188. }
  189. embedding_row_cnt_[key] = row_count;
  190. }
  191. bool Worker::InitPSEmbeddingTable(const size_t &key, const std::vector<size_t> &input_shape,
  192. const std::vector<size_t> &indices_shape, const std::vector<size_t> &output_shape,
  193. const ParamInitInfoMessage &info) {
  194. bool has_init = IsKeyInit(key);
  195. if (has_init) {
  196. MS_LOG(DEBUG) << "The key embedding table of key " << key << " is initialized.";
  197. return true;
  198. }
  199. EmbeddingTableMeta embedding_table_meta;
  200. embedding_table_meta.set_key(key);
  201. *embedding_table_meta.mutable_input_shape() = {input_shape.begin(), input_shape.end()};
  202. *embedding_table_meta.mutable_indices_shape() = {indices_shape.begin(), indices_shape.end()};
  203. *embedding_table_meta.mutable_output_shape() = {output_shape.begin(), output_shape.end()};
  204. *embedding_table_meta.mutable_info() = info;
  205. std::string kv_data = embedding_table_meta.SerializeAsString();
  206. #ifdef __APPLE__
  207. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  208. #else
  209. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  210. #endif
  211. size_t dest_size = kv_data.length();
  212. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  213. if (ret != 0) {
  214. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  215. return false;
  216. }
  217. return worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kInitEmbeddingsCmd);
  218. }
  219. void Worker::InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor) {
  220. MS_EXCEPTION_IF_NULL(tensor);
  221. MS_EXCEPTION_IF_NULL(input_node);
  222. auto pk_node = input_node->cast<ParameterPtr>();
  223. MS_EXCEPTION_IF_NULL(pk_node);
  224. const std::string &param_name = pk_node->fullname_with_scope();
  225. void *param_data = tensor->data_c();
  226. size_t param_size = LongToSize(tensor->data().nbytes());
  227. size_t param_key = GetParamKey(param_name);
  228. if (param_key == kInvalidKey) {
  229. MS_LOG(DEBUG) << "Parameter " << param_name << " has no key assigned.";
  230. return;
  231. }
  232. bool init_in_server = false;
  233. auto param_info_ptr = pk_node->param_info();
  234. if (param_info_ptr != nullptr && param_info_ptr->init_in_server()) {
  235. init_in_server = true;
  236. }
  237. SetParamInitInServer(param_name, init_in_server);
  238. bool init = IsKeyInit(param_key);
  239. if (!init) {
  240. MS_LOG(DEBUG) << "Init parameter key " << param_key << " and optimizer in parameter server side for " << param_name
  241. << ", whether init in server: " << init_in_server;
  242. AddKeyToServerId(param_key);
  243. if (!PsDataPrefetch::GetInstance().cache_enable()) {
  244. if (!init_in_server) {
  245. if (param_size > INT_MAX) {
  246. MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is "
  247. << param_size;
  248. }
  249. InitPSParamData({param_key}, param_data, param_size);
  250. }
  251. InitPSOptimId(param_key);
  252. InitPSOptimInputShapes(param_key);
  253. }
  254. }
  255. }
  256. bool Worker::DoPSEmbeddingLookup(const Key &key, const std::vector<int> &lookup_ids, std::vector<float> *lookup_result,
  257. int64_t cmd) {
  258. MS_EXCEPTION_IF_NULL(lookup_result);
  259. EmbeddingTableLookup embedding_table_lookup;
  260. embedding_table_lookup.set_key(key);
  261. *embedding_table_lookup.mutable_keys() = {lookup_ids.begin(), lookup_ids.end()};
  262. PartitionEmbeddingMessages messages;
  263. lookup_partitioner_(embedding_table_lookup, &messages, {});
  264. std::vector<uint32_t> rank_ids;
  265. std::vector<DataPtr> data;
  266. std::vector<size_t> sizes;
  267. for (size_t i = 0; i < messages.size(); i++) {
  268. if (messages.at(i).first) {
  269. rank_ids.push_back(i);
  270. std::string kv_data = messages.at(i).second.SerializeAsString();
  271. #ifdef __APPLE__
  272. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  273. #else
  274. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  275. #endif
  276. size_t dest_size = kv_data.length();
  277. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  278. if (ret != 0) {
  279. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  280. return false;
  281. }
  282. data.push_back(res);
  283. sizes.push_back(kv_data.length());
  284. }
  285. }
  286. std::vector<VectorPtr> resp;
  287. if (!worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(cmd), &resp)) {
  288. MS_LOG(ERROR) << "Worker send failed!";
  289. return false;
  290. }
  291. int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size());
  292. mindspore::HashMap<Key, std::shared_ptr<std::pair<float *, int64_t>>> id_addr_map;
  293. std::shared_ptr<std::vector<float>> values = std::make_shared<std::vector<float>>();
  294. std::shared_ptr<std::vector<Key>> keys = std::make_shared<std::vector<Key>>();
  295. int64_t value_offset = 0;
  296. for (size_t i = 0; i < resp.size(); ++i) {
  297. KVMessage message;
  298. CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), resp.at(i)->size()));
  299. for (auto j = 0; j < message.values_size(); j++) {
  300. values->push_back(message.values(j));
  301. }
  302. for (auto k = 0; k < message.keys_size(); k++) {
  303. const Key &message_key = message.keys(k);
  304. keys->push_back(message_key);
  305. }
  306. }
  307. for (size_t i = 0; i < keys->size(); i++) {
  308. const Key &map_key = keys->at(i);
  309. float *addr = values->data() + value_offset;
  310. value_offset += single_id_len;
  311. id_addr_map[map_key] = std::make_shared<std::pair<float *, int64_t>>(std::make_pair(addr, single_id_len));
  312. }
  313. float *result_addr = lookup_result->data();
  314. MS_EXCEPTION_IF_NULL(result_addr);
  315. int64_t offset = 0;
  316. size_t dst_size = 0;
  317. size_t src_size = 0;
  318. void *dst_data = nullptr;
  319. void *src_data = nullptr;
  320. for (size_t i = 0; i < lookup_ids.size(); i++) {
  321. if (id_addr_map.count(lookup_ids[i]) == 0) {
  322. offset += single_id_len;
  323. continue;
  324. }
  325. const Key &id_key = static_cast<Key>(lookup_ids[i]);
  326. auto &pair = id_addr_map[id_key];
  327. size_t size = LongToSize(single_id_len * sizeof(float));
  328. dst_size = size;
  329. src_size = size;
  330. dst_data = result_addr + offset;
  331. src_data = pair->first;
  332. MS_EXCEPTION_IF_NULL(dst_data);
  333. MS_EXCEPTION_IF_NULL(src_data);
  334. auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  335. if (mem_ret != 0) {
  336. MS_LOG(ERROR) << "memcpy_s error, errorno(" << mem_ret << ")";
  337. return false;
  338. }
  339. offset += single_id_len;
  340. }
  341. return true;
  342. }
  343. bool Worker::UpdateEmbeddingTable(const std::vector<Key> &keys, const std::vector<int> &lookup_ids,
  344. const std::vector<float> &vals) {
  345. KVMessage kvs;
  346. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  347. *kvs.mutable_len() = {lookup_ids.begin(), lookup_ids.end()};
  348. *kvs.mutable_values() = {vals.begin(), vals.end()};
  349. PartitionKVMessages messages;
  350. update_embedding_partitioner_(kvs, &messages, {});
  351. std::vector<uint32_t> rank_ids;
  352. std::vector<DataPtr> data;
  353. std::vector<size_t> sizes;
  354. for (size_t i = 0; i < messages.size(); i++) {
  355. if (messages.at(i).first) {
  356. rank_ids.push_back(i);
  357. std::string kv_data = messages.at(i).second.SerializeAsString();
  358. #ifdef __APPLE__
  359. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  360. #else
  361. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  362. #endif
  363. size_t dest_size = kv_data.length();
  364. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  365. if (ret != 0) {
  366. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  367. return false;
  368. }
  369. data.push_back(res);
  370. sizes.push_back(kv_data.length());
  371. }
  372. }
  373. return worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, LongToInt(kUpdateEmbeddingsCmd));
  374. }
  375. void Worker::Finalize() {
  376. if (running_) {
  377. MS_LOG(INFO) << "Worker starts finalizing...";
  378. KVMessage kvs;
  379. kvs.add_keys(0);
  380. kvs.add_values(0.0f);
  381. std::string kv_data = kvs.SerializeAsString();
  382. #ifdef __APPLE__
  383. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  384. #else
  385. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  386. #endif
  387. size_t dest_size = kv_data.length();
  388. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  389. if (ret != 0) {
  390. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  391. return;
  392. }
  393. worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), kFinalizeCmd);
  394. worker_node_.Finish();
  395. worker_node_.Stop();
  396. running_ = false;
  397. MS_LOG(INFO) << "Worker finalized successfully.";
  398. }
  399. }
  400. void Worker::Initialize() {
  401. lookup_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  402. LookupIdPartitioner(send, partition, attrs);
  403. };
  404. worker_init_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  405. WorkerInitEmbeddingPartitioner(send, partition, attrs);
  406. };
  407. round_robin_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  408. RoundRobinPartitioner(send, partition, attrs);
  409. };
  410. sparse_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  411. SparsePartitioner(send, partition, attrs);
  412. };
  413. update_embedding_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  414. UpdateEmbeddingPartitioner(send, partition, attrs);
  415. };
  416. broadcast_partitioner_ = [this](auto &&send, auto &&partition, auto &&attrs) {
  417. BroadcastPartitioner(send, partition, attrs);
  418. };
  419. }
  420. bool Worker::IsKeyInit(const size_t key) {
  421. if (init_keys_.find(key) == init_keys_.end() || !init_keys_[key]) {
  422. return false;
  423. }
  424. return true;
  425. }
  426. void Worker::AddKeyToServerId(const Key &key) { AddKeyByHashMod(key); }
  427. void Worker::AddKeyByHashMod(const Key &key) {
  428. if (server_num_ == 0) {
  429. MS_LOG(EXCEPTION) << "Server number is invalid:0";
  430. }
  431. key_to_server_id_[key] = static_cast<int64_t>(key % server_num_);
  432. MS_LOG(DEBUG) << "The server id of key " << key << " is " << key_to_server_id_[key];
  433. }
  434. void Worker::InitPSOptimId(const size_t param_key) {
  435. MS_LOG(INFO) << "InitPSOptimId key is:" << param_key;
  436. if (key_to_optimId_.count(param_key) == 0) {
  437. MS_LOG(EXCEPTION) << "Can't find optimizer id of parameter key " << param_key;
  438. }
  439. int64_t optim_id = key_to_optimId_[param_key];
  440. std::vector<Key> keys = {param_key};
  441. std::vector<float> optim_id_vals = {static_cast<float>(optim_id)};
  442. std::vector<int> optim_id_lens = {SizeToInt(optim_id_vals.size())};
  443. MS_LOG(INFO) << "The keys is" << keys << " the optim_id_vals is: " << optim_id_vals
  444. << " optim_id_lens is:" << optim_id_lens;
  445. PushData(keys, optim_id_vals, optim_id_lens, kInitWeightToOptimIdCmd);
  446. }
  447. void Worker::InitPSOptimInputShapes(const size_t key) {
  448. std::vector<Key> keys;
  449. std::vector<int> shape_len;
  450. std::vector<float> all_shape;
  451. std::vector<ShapeVector> shapes = key_to_optim_shapes_[key];
  452. for (auto shape : shapes) {
  453. keys.push_back(key);
  454. if (shape.size() == 0) {
  455. shape_len.push_back(1);
  456. all_shape.push_back(1);
  457. } else {
  458. shape_len.push_back(SizeToLong(shape.size()));
  459. std::transform(shape.begin(), shape.end(), std::back_inserter(all_shape),
  460. [](size_t dim) -> float { return static_cast<float>(dim); });
  461. }
  462. }
  463. MS_LOG(INFO) << "keys:" << keys;
  464. MS_LOG(INFO) << "shape_len:" << shape_len;
  465. MS_LOG(INFO) << "all_shape:" << all_shape;
  466. if (!init_keys_[key]) {
  467. init_keys_[key] = true;
  468. }
  469. PushData(keys, all_shape, shape_len, kInitOptimInputsShapeCmd);
  470. }
  471. void Worker::InitPSParamData(const std::vector<size_t> &keys, void *const origin_addr, size_t size) {
  472. MS_EXCEPTION_IF_NULL(origin_addr);
  473. std::vector<float> addr{reinterpret_cast<float *>(origin_addr),
  474. reinterpret_cast<float *>(origin_addr) + size / sizeof(float)};
  475. std::vector<Key> key(keys);
  476. std::vector<int> lens;
  477. lens.push_back(addr.size());
  478. MS_LOG(INFO) << "the keys are:" << keys;
  479. MS_LOG(INFO) << "the values are:" << addr;
  480. PushData(key, addr, lens, kInitWeightsCmd);
  481. init_keys_[key[0]] = true;
  482. }
  483. bool Worker::IsReadyForPush(const Key &key) {
  484. std::vector<float> result(1, 0);
  485. PullData({key}, &result, nullptr, kCheckReadyForPushCmd);
  486. MS_LOG(INFO) << "key:" << key;
  487. if (result[0] > 0) {
  488. MS_LOG(INFO) << "IsReadyForPush:";
  489. return true;
  490. } else {
  491. MS_LOG(INFO) << "IsReadyForPush:";
  492. return false;
  493. }
  494. }
  495. bool Worker::IsReadyForPull(const Key &key) {
  496. std::vector<float> result(1, 0);
  497. PullData({key}, &result, nullptr, kCheckReadyForPullCmd);
  498. if (result[0] > 0) {
  499. MS_LOG(INFO) << "IsReadyForPull";
  500. return true;
  501. } else {
  502. MS_LOG(INFO) << "IsReadyForPull";
  503. return false;
  504. }
  505. }
  506. void Worker::PrepareSparseGradient(const size_t, const size_t, const mindspore::HashSet<int> &distinct_ids,
  507. const std::vector<std::pair<int, float *>> &indice_to_grads, const int *all_indice,
  508. const size_t segment_size, float *gradient, int *indices) {
  509. MS_EXCEPTION_IF_NULL(all_indice);
  510. MS_EXCEPTION_IF_NULL(gradient);
  511. MS_EXCEPTION_IF_NULL(indices);
  512. size_t offset = 0;
  513. int64_t index = 0;
  514. size_t segment_data_size = segment_size * sizeof(float);
  515. size_t dst_size;
  516. size_t src_size;
  517. void *dst_data = nullptr;
  518. void *src_data = nullptr;
  519. for (auto &pair : indice_to_grads) {
  520. if (distinct_ids.count(pair.first) == 0) {
  521. continue;
  522. }
  523. indices[index++] = pair.first;
  524. dst_size = segment_data_size;
  525. src_size = segment_data_size;
  526. dst_data = gradient + offset;
  527. src_data = pair.second;
  528. MS_EXCEPTION_IF_NULL(dst_data);
  529. MS_EXCEPTION_IF_NULL(src_data);
  530. auto ret = memcpy_s(gradient + offset, dst_size, pair.second, src_size);
  531. if (ret != 0) {
  532. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  533. return;
  534. }
  535. offset += segment_size;
  536. }
  537. }
  538. void Worker::BuildSparseValue(const std::vector<int> &lengths, const size_t grad_index, const size_t indice_index,
  539. const float *original_data, const float *grads, int *indices,
  540. std::vector<float> *reduced_data) {
  541. MS_EXCEPTION_IF_NULL(original_data);
  542. MS_EXCEPTION_IF_NULL(grads);
  543. MS_EXCEPTION_IF_NULL(indices);
  544. MS_EXCEPTION_IF_NULL(reduced_data);
  545. int64_t offset = 0;
  546. size_t dst_size = 0;
  547. size_t src_size = 0;
  548. void *dst_data = nullptr;
  549. void *src_data = nullptr;
  550. for (size_t i = 0; i < lengths.size(); i++) {
  551. if (i != grad_index && i != indice_index) {
  552. size_t data_size = lengths[i] * sizeof(float);
  553. dst_size = data_size;
  554. src_size = data_size;
  555. dst_data = reduced_data->data() + offset;
  556. src_data = const_cast<float *>(original_data) + offset;
  557. MS_EXCEPTION_IF_NULL(dst_data);
  558. MS_EXCEPTION_IF_NULL(src_data);
  559. auto mem_ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  560. if (mem_ret != 0) {
  561. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << mem_ret << ")";
  562. return;
  563. }
  564. }
  565. offset += lengths[i];
  566. }
  567. // Fill the reduced gradient
  568. int64_t grad_offset = 0;
  569. for (size_t i = 0; i < grad_index; i++) {
  570. grad_offset += lengths[i];
  571. }
  572. size_t data_size = lengths[grad_index] * sizeof(float);
  573. dst_size = data_size;
  574. src_size = data_size;
  575. dst_data = reduced_data->data() + grad_offset;
  576. src_data = const_cast<float *>(grads);
  577. MS_EXCEPTION_IF_NULL(dst_data);
  578. auto ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  579. if (ret != 0) {
  580. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  581. return;
  582. }
  583. // Fill the reduced indice
  584. int64_t indice_offset = grad_offset + lengths[grad_index];
  585. data_size = lengths[indice_index] * sizeof(float);
  586. float *indice_data = reduced_data->data() + indice_offset;
  587. dst_size = data_size;
  588. src_size = data_size;
  589. dst_data = indice_data;
  590. src_data = indices;
  591. MS_EXCEPTION_IF_NULL(dst_data);
  592. MS_EXCEPTION_IF_NULL(src_data);
  593. ret = memcpy_s(dst_data, dst_size, src_data, src_size);
  594. if (ret != 0) {
  595. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  596. return;
  597. }
  598. }
  599. void Worker::PushData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  600. int cmd, int64_t) {
  601. KVMessage kvs;
  602. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  603. *kvs.mutable_values() = {vals.begin(), vals.end()};
  604. *kvs.mutable_len() = {lens.begin(), lens.end()};
  605. MS_LOG(INFO) << "the result is:" << embedding_table_ranges_.count(keys[0]);
  606. if (embedding_table_ranges_.count(keys[0])) {
  607. if (cmd == kInitWeightsCmd) {
  608. SendForPush(cmd, kvs, worker_init_embedding_partitioner_, {});
  609. } else {
  610. std::string kv_data = kvs.SerializeAsString();
  611. #ifdef __APPLE__
  612. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  613. #else
  614. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  615. #endif
  616. size_t dest_size = kv_data.length();
  617. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  618. if (ret != 0) {
  619. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  620. return;
  621. }
  622. worker_node_.Broadcast(core::NodeRole::SERVER, res, kv_data.length(), cmd);
  623. }
  624. } else {
  625. SendForPush(cmd, kvs, round_robin_partitioner_, {});
  626. }
  627. }
  628. void Worker::PushSparseData(const std::vector<Key> &keys, const std::vector<float> &vals, const std::vector<int> &lens,
  629. size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size) {
  630. KVMessage kvs;
  631. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  632. *kvs.mutable_values() = {vals.begin(), vals.end()};
  633. *kvs.mutable_len() = {lens.begin(), lens.end()};
  634. if (embedding_table_ranges_.count(keys[0])) {
  635. std::map<int64_t, int64_t> attrs{{0, grad_index}, {1, indice_index}, {2, first_dim_size}, {3, outer_dim_size}};
  636. SendForPush(kPushCmd, kvs, sparse_partitioner_, attrs);
  637. } else {
  638. SendForPush(kPushCmd, kvs, round_robin_partitioner_, {});
  639. }
  640. }
  641. void Worker::PullData(const std::vector<Key> &keys, std::vector<float> *const vals, std::vector<int> *lens, int cmd,
  642. int64_t priority) {
  643. MS_EXCEPTION_IF_NULL(vals);
  644. KVMessage kvs;
  645. *kvs.mutable_keys() = {keys.begin(), keys.end()};
  646. if (embedding_table_ranges_.count(keys[0])) {
  647. SendForPull(cmd, kvs, broadcast_partitioner_, {}, vals, lens);
  648. } else {
  649. SendForPull(cmd, kvs, round_robin_partitioner_, {}, vals, lens);
  650. }
  651. }
  652. void Worker::LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition,
  653. const std::map<int64_t, int64_t> &) {
  654. MS_EXCEPTION_IF_NULL(partition);
  655. const Key &key = send.key();
  656. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  657. partition->resize(ranges.size());
  658. for (size_t i = 0; i < ranges.size(); i++) {
  659. const EmbeddingTableShardMetadata &range = ranges[i];
  660. const auto &begin = range.begin();
  661. const auto &end = range.end();
  662. mindspore::HashSet<int32_t> unique_ids;
  663. auto &kvs = partition->at(i).second;
  664. kvs.set_key(key);
  665. std::for_each(send.keys().begin(), send.keys().end(), [&](int32_t lookup_id) {
  666. if (lookup_id >= SizeToInt(begin) && lookup_id <= SizeToInt(end)) {
  667. unique_ids.insert(lookup_id);
  668. }
  669. });
  670. MS_LOG(DEBUG) << "The unique ids size is:" << unique_ids.size();
  671. for (const auto &lookup_id : unique_ids) {
  672. kvs.add_keys(lookup_id);
  673. kvs.add_values(0.0f);
  674. }
  675. if (kvs.keys().empty()) {
  676. partition->at(i).first = false;
  677. } else {
  678. partition->at(i).first = true;
  679. }
  680. }
  681. }
  682. void Worker::SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition,
  683. const std::map<int64_t, int64_t> &attrs) {
  684. MS_EXCEPTION_IF_NULL(partition);
  685. // Init variables
  686. float *data = const_cast<float *>(send.values().data());
  687. if (attrs.count(kGradIndex) == 0 || attrs.count(kIndiceIndex) == 0 || attrs.count(kFirstDimSize) == 0 ||
  688. attrs.count(kOutDimSize) == 0) {
  689. MS_LOG(EXCEPTION) << "Invalid attrs keys";
  690. }
  691. auto iter = attrs.find(kGradIndex);
  692. size_t grad_index = static_cast<size_t>(iter->second);
  693. iter = attrs.find(kIndiceIndex);
  694. size_t indice_index = static_cast<size_t>(iter->second);
  695. iter = attrs.find(kFirstDimSize);
  696. size_t first_dim_size = static_cast<size_t>(iter->second);
  697. iter = attrs.find(kOutDimSize);
  698. size_t outer_dim_size = static_cast<size_t>(iter->second);
  699. size_t grad_size = send.len()[SizeToInt(grad_index)];
  700. size_t indice_size = send.len()[SizeToInt(indice_index)];
  701. size_t segment_size = grad_size / indice_size;
  702. size_t grad_offset = 0;
  703. size_t indice_offset = 0;
  704. for (size_t i = 0; i < grad_index; i++) {
  705. grad_offset += send.len()[i];
  706. }
  707. for (size_t j = 0; j < indice_index; j++) {
  708. indice_offset += send.len()[j];
  709. }
  710. float *grad_data = data + grad_offset;
  711. void *indice_data_temp = data + indice_offset;
  712. int *indice_data = reinterpret_cast<int *>(indice_data_temp);
  713. // Build the mappings of indice to gradient
  714. std::vector<std::pair<int, float *>> indice_to_grads;
  715. for (size_t i = 0; i < indice_size; i++) {
  716. int indice = indice_data[i];
  717. float *grad = grad_data + i * segment_size;
  718. indice_to_grads.push_back(std::make_pair(indice, grad));
  719. }
  720. const Key &key = send.keys()[0];
  721. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  722. partition->resize(ranges.size());
  723. // Construct reduced sparse data for each server
  724. for (size_t i = 0; i < ranges.size(); i++) {
  725. const EmbeddingTableShardMetadata &range = ranges[i];
  726. const auto &begin = range.begin();
  727. const auto &end = range.end();
  728. auto &kvs = partition->at(i).second;
  729. *kvs.mutable_keys() = {send.keys().begin(), send.keys().end()};
  730. *kvs.mutable_len() = {send.len().begin(), send.len().end()};
  731. // Prepare the sparse gradient and indice
  732. std::vector<int> indice_ids;
  733. mindspore::HashSet<int> distinct_ids;
  734. for (size_t j = 0; j < indice_size; j++) {
  735. size_t indice = static_cast<size_t>(indice_data[j]);
  736. if (indice >= begin && indice <= end) {
  737. indice_ids.push_back(indice);
  738. distinct_ids.insert(indice);
  739. }
  740. }
  741. size_t indices_size = indice_ids.size();
  742. if (indices_size > 0) {
  743. size_t partition_segment_size = indices_size * segment_size;
  744. std::vector<float> src_grad_data(partition_segment_size);
  745. std::vector<int> src_indice_data(indices_size);
  746. PrepareSparseGradient(begin, end, distinct_ids, indice_to_grads, indice_data, segment_size, src_grad_data.data(),
  747. src_indice_data.data());
  748. // Reduce the sparse gradient and indice
  749. std::vector<float> new_grad(partition_segment_size);
  750. std::vector<int> new_indices(indices_size);
  751. mindspore::kernel::SparseGradient<int> unique_sparse_grad({new_grad.data(), new_indices.data(), indices_size});
  752. Util::ReduceSparseGradient(src_grad_data.data(), src_indice_data.data(), indices_size, segment_size,
  753. first_dim_size, outer_dim_size, &unique_sparse_grad);
  754. // Update the length of reduce sparse gradient and indice
  755. std::vector<int> reduced_lens = {kvs.len().begin(), kvs.len().end()};
  756. reduced_lens[grad_index] = unique_sparse_grad.indices_size_ * segment_size;
  757. reduced_lens[indice_index] = unique_sparse_grad.indices_size_;
  758. // Build the sparse value to be sent
  759. size_t total_size = std::accumulate(reduced_lens.begin(), reduced_lens.end(), 0, std::plus<int>());
  760. std::vector<float> reduced_data(total_size, 0);
  761. BuildSparseValue(reduced_lens, grad_index, indice_index, data, unique_sparse_grad.value_,
  762. unique_sparse_grad.indices_, &reduced_data);
  763. *kvs.mutable_len() = {reduced_lens.begin(), reduced_lens.end()};
  764. *kvs.mutable_values() = {reduced_data.begin(), reduced_data.end()};
  765. }
  766. if (indices_size == 0) {
  767. std::vector<float> no_keys;
  768. std::vector<float> no_vals;
  769. std::vector<float> no_lens;
  770. no_keys.push_back(key);
  771. no_vals.push_back(kGradValue);
  772. *kvs.mutable_values() = {no_vals.begin(), no_vals.end()};
  773. *kvs.mutable_len() = {no_lens.begin(), no_lens.end()};
  774. }
  775. partition->at(i).first = true;
  776. }
  777. }
  778. void Worker::RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  779. const std::map<int64_t, int64_t> &) {
  780. MS_EXCEPTION_IF_NULL(partition);
  781. partition->resize(LongToSize(server_num_));
  782. auto keys = send.keys();
  783. auto values = send.values();
  784. auto lens = send.len();
  785. MS_LOG(INFO) << "the key size is:" << send.keys_size() << " the values size is:" << send.values_size()
  786. << " the lens:" << send.len_size();
  787. size_t len;
  788. Key param_key;
  789. for (int i = 0; i < send.keys_size(); i++) {
  790. param_key = keys[i];
  791. int64_t server_id = key_to_server_id_[param_key];
  792. if (!partition->at(LongToUlong(server_id)).first) {
  793. partition->at(LongToUlong(server_id)).first = true;
  794. }
  795. KVMessage &server_kv_pairs = partition->at(LongToUlong(server_id)).second;
  796. server_kv_pairs.add_keys(param_key);
  797. if (values.empty()) {
  798. continue;
  799. }
  800. len = lens[i];
  801. int64_t offset = std::accumulate(lens.begin(), lens.begin() + i, 0);
  802. auto val_begin = values.begin() + offset;
  803. auto val_end = val_begin + len;
  804. for (auto it = val_begin; it != val_end; ++it) {
  805. server_kv_pairs.add_values(*it);
  806. }
  807. server_kv_pairs.add_len(len);
  808. }
  809. }
  810. void Worker::WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector<std::pair<bool, KVMessage>> *partition,
  811. const std::map<int64_t, int64_t> &) {
  812. MS_EXCEPTION_IF_NULL(partition);
  813. partition->resize(LongToSize(server_num_));
  814. auto keys = send.keys();
  815. auto values = send.values();
  816. auto lens = send.len();
  817. int32_t col_cnt = lens[0] / embedding_row_cnt_[keys[0]];
  818. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[keys[0]]);
  819. for (size_t i = 0; i < ranges.size(); i++) {
  820. size_t offset_begin = ranges[i].begin() * col_cnt;
  821. size_t offset_end = (ranges[i].end() + 1) * col_cnt;
  822. KVMessage kvs;
  823. *kvs.mutable_keys() = keys;
  824. *kvs.mutable_values() = {values.begin() + offset_begin, values.begin() + offset_end};
  825. kvs.add_len(offset_end - offset_begin);
  826. partition->at(i).first = true;
  827. partition->at(i).second = kvs;
  828. }
  829. }
  830. void Worker::UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  831. const std::map<int64_t, int64_t> &) {
  832. MS_EXCEPTION_IF_NULL(partition);
  833. const float *embedding_vals = send.values().data();
  834. const uint64_t *lookup_ids = send.len().data();
  835. size_t val_size = IntToSize(send.values_size());
  836. size_t id_size = IntToSize(send.len_size());
  837. if (id_size == 0) {
  838. MS_LOG(EXCEPTION) << "The id size is 0.";
  839. return;
  840. }
  841. size_t embedding_dim = val_size / id_size;
  842. const Key &key = send.keys()[0];
  843. const std::vector<EmbeddingTableShardMetadata> &ranges = *(embedding_table_ranges_[key]);
  844. partition->resize(ranges.size());
  845. for (size_t i = 0; i < ranges.size(); i++) {
  846. const EmbeddingTableShardMetadata &range = ranges[i];
  847. const auto &begin = range.begin();
  848. const auto &end = range.end();
  849. auto &kvs = partition->at(i).second;
  850. kvs.add_keys(key);
  851. for (size_t j = 0; j < id_size; j++) {
  852. auto lookup_id = lookup_ids[j];
  853. if (lookup_id >= begin && lookup_id <= end) {
  854. kvs.add_keys(lookup_id);
  855. for (size_t k = 0; k < embedding_dim; k++) {
  856. kvs.add_values(embedding_vals[j * embedding_dim + k]);
  857. }
  858. }
  859. }
  860. if (kvs.keys_size() <= 1) {
  861. partition->at(i).first = false;
  862. } else {
  863. partition->at(i).first = true;
  864. }
  865. }
  866. }
  867. void Worker::BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition,
  868. const std::map<int64_t, int64_t> &) {
  869. MS_EXCEPTION_IF_NULL(partition);
  870. partition->resize(LongToSize(server_num_));
  871. for (size_t i = 0; i < LongToSize(server_num_); i++) {
  872. partition->at(i).first = true;
  873. partition->at(i).second = send;
  874. }
  875. }
  876. void Worker::SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  877. const std::map<int64_t, int64_t> &attrs) {
  878. PartitionKVMessages messages;
  879. partitioner(send, &messages, attrs);
  880. std::vector<uint32_t> rank_ids;
  881. std::vector<DataPtr> data;
  882. std::vector<size_t> sizes;
  883. for (size_t i = 0; i < messages.size(); i++) {
  884. if (messages.at(i).first) {
  885. rank_ids.push_back(i);
  886. std::string kv_data = messages.at(i).second.SerializeAsString();
  887. #ifdef __APPLE__
  888. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  889. #else
  890. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  891. #endif
  892. size_t dest_size = kv_data.length();
  893. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  894. if (ret != 0) {
  895. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  896. return;
  897. }
  898. data.push_back(res);
  899. sizes.push_back(kv_data.length());
  900. }
  901. }
  902. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd);
  903. }
  904. void Worker::SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner,
  905. const std::map<int64_t, int64_t> &, std::vector<float> *vals, std::vector<int> *lens) {
  906. MS_EXCEPTION_IF_NULL(vals);
  907. PartitionKVMessages messages;
  908. partitioner(send, &messages, {});
  909. std::vector<uint32_t> rank_ids;
  910. std::vector<DataPtr> data;
  911. std::vector<size_t> sizes;
  912. for (size_t i = 0; i < messages.size(); i++) {
  913. if (messages.at(i).first) {
  914. rank_ids.push_back(i);
  915. std::string kv_data = messages.at(i).second.SerializeAsString();
  916. #ifdef __APPLE__
  917. std::shared_ptr<unsigned char> res(new unsigned char[kv_data.length()], std::default_delete<unsigned char[]>());
  918. #else
  919. std::shared_ptr<unsigned char[]> res(new unsigned char[kv_data.length()]);
  920. #endif
  921. size_t dest_size = kv_data.length();
  922. int ret = memcpy_s(res.get(), dest_size, kv_data.data(), kv_data.length());
  923. if (ret != 0) {
  924. MS_LOG(ERROR) << "memcpy_s error, errorno(" << ret << ")";
  925. return;
  926. }
  927. data.push_back(res);
  928. sizes.push_back(kv_data.length());
  929. }
  930. }
  931. std::vector<VectorPtr> resp;
  932. worker_node_.Send(core::NodeRole::SERVER, rank_ids, data, sizes, cmd, &resp);
  933. vals->clear();
  934. for (size_t i = 0; i < resp.size(); ++i) {
  935. KVMessage message;
  936. CHECK_RETURN_TYPE(message.ParseFromArray(resp.at(i)->data(), SizeToInt(resp.at(i)->size())));
  937. std::copy(message.values().begin(), message.values().end(), std::back_inserter(*vals));
  938. if (lens) {
  939. lens->clear();
  940. std::copy(message.len().begin(), message.len().end(), std::back_inserter(*lens));
  941. }
  942. }
  943. }
  944. } // namespace ps
  945. } // namespace mindspore